discord_bot/fjerkroa_bot/ai_responder.py
2023-03-22 10:46:04 +01:00

87 lines
3.4 KiB
Python

import asyncio
import json
import logging
from io import BytesIO
from typing import Optional, List, Dict, Any, Tuple

import aiohttp
import openai
class AIMessageBase:
    """Common base for the bot's message objects.

    ``str()`` on an instance renders its instance attributes as a JSON
    object (using ``json.dumps`` defaults), in the order they were assigned.
    """

    def __init__(self) -> None:
        # Nothing to initialise here; subclasses assign their own attributes.
        pass

    def __str__(self) -> str:
        attributes = self.__dict__
        return json.dumps(attributes)
class AIMessage(AIMessageBase):
    """A chat message sent by ``user``.

    Via ``AIMessageBase.__str__`` it renders as the JSON object
    ``{"user": ..., "message": ...}``.
    """

    def __init__(self, user: str, message: str) -> None:
        # Tuple assignment binds left to right, so ``user`` is set first and
        # therefore comes first in the JSON rendering.
        self.user, self.message = user, message
class AIResponse(AIMessageBase):
    """Structured reply produced by ``AIResponder.send``.

    Each field is taken from the key of the same name in the JSON object
    returned by the chat model.
    """

    def __init__(
        self,
        answer: str,
        answer_needed: bool,
        staff: Optional[str],
        picture: Optional[str],
        hack: bool,
    ) -> None:
        # Targets are assigned left to right; this order is the key order of
        # the JSON rendering produced by AIMessageBase.__str__.
        (self.answer, self.answer_needed, self.staff, self.picture, self.hack) = (
            answer, answer_needed, staff, picture, hack)
class AIResponder(object):
    """Send chat messages and image prompts to the OpenAI API.

    Keeps the running chat history in ``self.history``. ``config`` must
    provide the keys read below: ``"system"`` (system prompt),
    ``"history-limit"`` (int or None), ``"model"``, ``"temperature"``,
    ``"top-p"``, ``"presence-penalty"`` and ``"frequency-penalty"``.
    """

    # Delay in seconds before retrying a failed API call. It doubles after
    # every consecutive failure, up to _RETRY_MAX_DELAY, so a persistent error
    # (bad key, network outage, model not replying with JSON) does not turn
    # into a tight loop of requests.
    _RETRY_INITIAL_DELAY = 1.0
    _RETRY_MAX_DELAY = 60.0
    # When a request exceeds the model's context length, the history is
    # trimmed one entry at a time, but never below this many entries.
    _MIN_HISTORY_LIMIT = 4

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config
        self.history: List[Dict[str, Any]] = []

    def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Build the request messages and the candidate new history.

        :param message: incoming message; serialized with ``str()`` into the
            content of a new ``"user"`` entry.
        :param limit: keep only the last ``limit`` stored history entries.
            ``None`` keeps all of them; ``0`` or a negative value keeps none.
        :returns: ``(messages, history)``. ``history`` is the trimmed stored
            history with the new user entry appended. ``messages`` is the
            system prompt followed by ``history``. ``self.history`` is not
            modified.
        """
        if limit is None:
            history = self.history[:]
        elif limit > 0:
            history = self.history[-limit:]
        else:
            # self.history[-0:] would return the *whole* list, not an empty one.
            history = []
        history.append({"role": "user", "content": str(message)})
        messages = [{"role": "system", "content": self.config["system"]}] + history
        return messages, history

    async def draw(self, description: str) -> BytesIO:
        """Generate a 512x512 image for ``description`` and return its bytes.

        Retries until it succeeds. Each failure is logged and followed by a
        growing delay. Task cancellation is propagated, not retried.
        """
        delay = self._RETRY_INITIAL_DELAY
        while True:
            try:
                response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
                async with aiohttp.ClientSession() as session:
                    async with session.get(response['data'][0]['url']) as image:
                        # Without this check, an HTTP error body would be
                        # returned as if it were image data.
                        image.raise_for_status()
                        return BytesIO(await image.read())
            except asyncio.CancelledError:
                # On Python < 3.8 this is an Exception subclass; never swallow it.
                raise
            except Exception as err:
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
            await asyncio.sleep(delay)
            delay = min(delay * 2, self._RETRY_MAX_DELAY)

    async def send(self, message: AIMessage) -> AIResponse:
        """Send ``message`` to the chat model and return its parsed reply.

        The model's reply content must be a JSON object containing the keys
        ``answer``, ``answer_needed``, ``staff``, ``picture`` and ``hack``.
        On success, the (possibly trimmed) history plus the user message and
        the model's reply become the new ``self.history``.

        If the request exceeds the model's context length, the history limit
        is lowered by one and the request retried, down to
        ``_MIN_HISTORY_LIMIT`` entries. Past that point the
        ``openai.error.InvalidRequestError`` is re-raised, as is any other
        invalid-request error. Every other failure (including a reply that is
        not the expected JSON object) is logged and retried after a growing
        delay.
        """
        limit = self.config["history-limit"]
        delay = self._RETRY_INITIAL_DELAY
        while True:
            messages, history = self._message(message, limit)
            try:
                result = await openai.ChatCompletion.acreate(model=self.config["model"],
                                                             messages=messages,
                                                             temperature=self.config["temperature"],
                                                             top_p=self.config["top-p"],
                                                             presence_penalty=self.config["presence-penalty"],
                                                             frequency_penalty=self.config["frequency-penalty"])
                answer = result['choices'][0]['message']
                response = json.loads(answer['content'])
                # Built inside the try so that a reply missing an expected key
                # (or not a JSON object) is retried like any other malformed reply.
                reply = AIResponse(response['answer'],
                                   response['answer_needed'],
                                   response['staff'],
                                   response['picture'],
                                   response['hack'])
            except asyncio.CancelledError:
                raise
            except openai.error.InvalidRequestError as err:
                if 'maximum context length is' in str(err):
                    if limit is None:
                        # No configured limit: start trimming from the full history.
                        limit = len(self.history)
                    if limit > self._MIN_HISTORY_LIMIT:
                        limit -= 1
                        continue
                raise
            except Exception as err:
                logging.warning(f"failed to generate response: {repr(err)}")
                await asyncio.sleep(delay)
                delay = min(delay * 2, self._RETRY_MAX_DELAY)
                continue
            history.append(answer)
            self.history = history
            return reply