discord_bot/fjerkroa_bot/ai_responder.py

161 lines
7.0 KiB
Python

import json
import openai
import aiohttp
import logging
import time
import re
from io import BytesIO
from pprint import pformat
from typing import Optional, List, Dict, Any, Tuple
class AIMessageBase(object):
    """Base for the plain record objects exchanged with the AI.

    Subclasses keep their fields as instance attributes; ``str()`` serialises
    all of them, in assignment order, as a JSON object (so every attribute
    value must be JSON-serialisable, otherwise ``json.dumps`` raises
    ``TypeError``).
    """

    def __init__(self) -> None:
        super().__init__()

    def __str__(self) -> str:
        attributes = self.__dict__
        return json.dumps(attributes)
class AIMessage(AIMessageBase):
    """A chat message written by *user* in *channel*.

    ``str()`` renders it as a JSON object with the keys ``user``,
    ``message`` and ``channel``, in that order.
    """

    def __init__(self, user: str, message: str, channel: str = "chat") -> None:
        super().__init__()
        # The assignment order below fixes the key order of the JSON from str().
        self.user = user
        self.message = message
        self.channel = channel
class AIResponse(AIMessageBase):
    """The bot's interpretation of one model reply.

    :param answer: text to post, or None when there is nothing to say.
    :param answer_needed: whether an answer should actually be posted.
    :param staff: optional note; on a parse failure it carries the error
        text (presumably meant for staff/moderators — verify with caller).
    :param picture: optional text, presumably an image prompt for
        ``AIResponder.draw`` — verify with caller.
    :param hack: boolean flag reported by the model (meaning defined by the
        system prompt — not visible here).
    """

    def __init__(self, answer: Optional[str], answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None:
        super().__init__()
        # Tuple targets are assigned left to right, keeping the JSON key order.
        self.answer, self.answer_needed, self.staff, self.picture, self.hack = (
            answer, answer_needed, staff, picture, hack)
class AIResponder(object):
    """Turns AIMessages into AIResponses using the OpenAI chat API.

    The responder keeps a rolling conversation history (``self.history``) of
    chat-completion message dicts and prepends a system prompt. It expects
    the model to reply with a JSON object carrying ``answer``,
    ``answer_needed``, ``staff``, ``picture`` and ``hack``.

    Config keys read here: ``openai-token``, ``system`` (or the key named by
    ``channel``), ``history-limit``, ``model``, ``temperature``, ``top-p``,
    ``presence-penalty``, ``frequency-penalty``, and optionally
    ``short-path``, ``fix-model`` and ``fix-description``.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Store *config*, choose the system-prompt key and set the API key.

        :param config: bot configuration mapping (see class docstring).
        :param channel: config key holding this responder's system prompt;
            defaults to ``'system'``. ``config['system']`` is used when the
            key is missing.
        """
        self.config = config
        # Rolling chat history (role/content dicts), trimmed to the history limit.
        self.history: List[Dict[str, Any]] = []
        self.channel = channel if channel is not None else 'system'
        # NOTE: this sets the key on the openai module, i.e. for all instances.
        openai.api_key = self.config['openai-token']

    @staticmethod
    def _trim(items: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
        """Return the last *limit* entries of *items* (none when ``limit <= 0``).

        A bare ``items[-limit:]`` would return the whole list for
        ``limit == 0``, because ``-0 == 0``.
        """
        return items[-limit:] if limit > 0 else []

    def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Build the chat-completion message list for *message*.

        The list contains, in order:
        - the system prompt, with ``{date}``/``{time}`` substituted;
        - the last *limit* history entries (all of them when *limit* is None);
        - *message* serialised as a user turn.

        ``self.history`` is not modified.
        """
        system = self.config.get(self.channel, self.config['system'])
        # One timestamp for both fields so date and time cannot disagree
        # when the call straddles a second (or midnight).
        now = time.localtime()
        system = (system
                  .replace('{date}', time.strftime('%Y-%m-%d', now))
                  .replace('{time}', time.strftime('%H:%M:%S', now)))
        history = self.history if limit is None else self._trim(self.history, limit)
        return [{"role": "system", "content": system},
                *history,
                {"role": "user", "content": str(message)}]

    async def draw(self, description: str) -> BytesIO:
        """Generate a 512x512 image for *description* and return its bytes.

        The image is created through the OpenAI image API and then downloaded
        from the returned URL. Up to 7 attempts are made; each failure is
        logged as a warning.

        :raises RuntimeError: if every attempt failed.
        """
        for _ in range(7):
            try:
                response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
                async with aiohttp.ClientSession() as session:
                    async with session.get(response['data'][0]['url']) as image:
                        # Otherwise an HTTP error page would be returned as image data.
                        image.raise_for_status()
                        return BytesIO(await image.read())
            except Exception as err:
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")

    async def post_process(self, response: Dict[str, Any]) -> AIResponse:
        """Normalise the model's parsed JSON reply into an AIResponse.

        - ``answer``/``staff``/``picture`` become None when they are missing,
          or when their text is empty, ``none`` or ``null``.
        - ``answer_needed``/``hack`` become True only for a case-insensitive
          ``true``.
        - A missing answer forces ``answer_needed`` to False.

        *response* is updated in place.
        """
        for fld in ('answer', 'staff', 'picture'):
            # .get(): an omitted field is treated like an explicit null
            # instead of raising KeyError out of send().
            if str(response.get(fld)).strip().lower() in ('none', '', 'null'):
                response[fld] = None
        for fld in ('answer_needed', 'hack'):
            response[fld] = str(response.get(fld)).strip().lower() == 'true'
        if response['answer'] is None:
            response['answer_needed'] = False
        else:
            response['answer'] = str(response['answer'])
        return AIResponse(response['answer'],
                          response['answer_needed'],
                          response['staff'],
                          response['picture'],
                          response['hack'])

    def short_path(self, message: AIMessage, limit: int) -> bool:
        """Record *message* without querying the model if it matches ``short-path``.

        ``config['short-path']`` is a list of ``(channel_regex, user_regex)``
        pairs, matched with ``re.match`` against the message's channel and
        user. On the first pair that matches both:
        - the message is appended to the history as a user turn;
        - the history is trimmed to *limit* entries;
        - True is returned.

        Otherwise False is returned.
        """
        for chan_re, user_re in self.config.get('short-path', ()):
            if re.match(chan_re, message.channel) and re.match(user_re, message.user):
                self.history.append({"role": "user", "content": str(message)})
                self.history = self._trim(self.history, limit)
                return True
        return False

    async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Request one chat completion for *messages*.

        :returns:
            - ``(message_dict, limit)`` on success.
            - ``(None, limit - 1)`` when the request exceeded the model's
              context length and *limit* is still above 4, so the caller can
              retry with less history.
            - ``(None, limit)`` after any other failure (logged as a warning).
        :raises openai.error.InvalidRequestError: for invalid requests that
            shrinking the history cannot fix.
        """
        try:
            result = await openai.ChatCompletion.acreate(model=self.config["model"],
                                                         messages=messages,
                                                         temperature=self.config["temperature"],
                                                         top_p=self.config["top-p"],
                                                         presence_penalty=self.config["presence-penalty"],
                                                         frequency_penalty=self.config["frequency-penalty"])
            answer = result['choices'][0]['message']
            # Exact type check on purpose: the client's response object is
            # presumably a dict subclass (openai 0.x OpenAIObject), and it is
            # converted to a plain dict before it ends up in the history.
            if type(answer) is not dict:
                answer = answer.to_dict()
            return answer, limit
        except openai.error.InvalidRequestError as err:
            if 'maximum context length is' in str(err) and limit > 4:
                return None, limit - 1
            raise
        except Exception as err:
            logging.warning(f"failed to generate response: {repr(err)}")
            return None, limit

    async def fix(self, answer: str) -> str:
        """Optionally post-edit *answer* with a second model.

        When ``fix-model`` is configured, *answer* is sent as the user turn
        under the ``fix-description`` system prompt, and the model's reply
        text is returned. Up to 4 attempts are made, and each failure is
        logged.

        *answer* is returned unchanged when no fix model is configured or
        when every attempt fails.
        """
        if 'fix-model' not in self.config:
            return answer
        messages = [{"role": "system", "content": self.config["fix-description"]},
                    {"role": "user", "content": answer}]
        for _ in range(4):
            try:
                result = await openai.ChatCompletion.acreate(model=self.config["fix-model"],
                                                             messages=messages,
                                                             temperature=0.2)
                return result['choices'][0]['message']['content']
            except Exception as err:
                logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
        return answer

    @staticmethod
    def _is_valid_response(response: Any) -> bool:
        """Return True if a parsed reply has the minimal expected shape.

        The reply must be a JSON object containing ``hack``, and its
        ``picture`` (if present) must be a string or null.
        """
        return (isinstance(response, dict)
                and 'hack' in response
                and isinstance(response.get('picture'), (str, type(None))))

    async def send(self, message: AIMessage) -> AIResponse:
        """Send *message* to the model and return its processed reply.

        Messages matched by ``short_path`` are only recorded in the history
        and get an AIResponse without an answer. Otherwise:
        - Up to 14 attempts are made. An attempt is retried when the request
          fails or the reply is valid JSON of the wrong shape.
        - On success, the user turn and the model's reply are appended to
          the history, trimmed to the (possibly reduced) history limit.
        - A reply that is not valid JSON yields an AIResponse with no answer
          and the error text in ``staff``; the history is left unchanged.

        :raises RuntimeError: if no usable reply was obtained.
        """
        limit = self.config["history-limit"]
        if self.short_path(message, limit):
            return AIResponse(None, False, None, None, False)
        for _ in range(14):
            messages = self._message(message, limit)
            logging.info(f"try to send this messages:\n{pformat(messages)}")
            # _acreate may lower the limit when the context window overflows.
            answer, limit = await self._acreate(messages, limit)
            if answer is None:
                continue
            answer['content'] = await self.fix(answer['content'])
            try:
                response = json.loads(answer['content'])
            except (ValueError, TypeError) as err:
                logging.error(f"failed to parse the answer: {pformat(err)}\n{repr(answer['content'])}")
                return AIResponse(None, False, f"ERROR: I could not parse this answer: {repr(answer['content'])}", None, False)
            if not self._is_valid_response(response):
                continue
            logging.info(f"got this answer:\n{pformat(response)}")
            self.history.append(messages[-1])
            self.history.append(answer)
            if len(self.history) > limit:
                self.history = self._trim(self.history, limit)
            return await self.post_process(response)
        raise RuntimeError("Failed to generate answer after multiple retries")