98 lines
4.0 KiB
Python
98 lines
4.0 KiB
Python
import json
|
|
import openai
|
|
import aiohttp
|
|
import logging
|
|
from io import BytesIO
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
|
|
|
|
class AIMessageBase:
    """Base class whose ``str()`` form is a JSON dump of the instance's attributes.

    Subclasses only need to assign attributes in their ``__init__``; the
    JSON serialisation picks them up automatically via ``vars(self)``.
    """

    def __init__(self) -> None:
        """No state of its own; subclasses define the attributes to serialise."""

    def __str__(self) -> str:
        """Return all instance attributes serialised as a single JSON object."""
        return json.dumps(vars(self))
class AIMessage(AIMessageBase):
    """A user-authored chat message; serialises to JSON via the base class.

    The JSON form (``str(self)``) carries exactly the ``user`` and
    ``message`` attributes set here.
    """

    def __init__(self, user: str, message: str) -> None:
        """Store the message author and body.

        :param user: Identifier/name of the message author.
        :param message: The message text.
        """
        super().__init__()  # keep the initialiser chain intact
        self.user = user
        self.message = message
class AIResponse(AIMessageBase):
    """Structured model reply; serialises to JSON via the base class.

    Mirrors the JSON document the model is expected to produce (see
    ``AIResponder.send``), after field normalisation.
    """

    def __init__(self, answer: str, answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None:
        """Store the parsed reply fields.

        :param answer: The reply text (may be ``None`` after normalisation).
        :param answer_needed: Whether a reply should actually be sent.
        :param staff: Optional staff member to notify, or ``None``.
        :param picture: Optional image description to draw, or ``None``.
        :param hack: Whether the model flagged the input as an abuse attempt.
        """
        super().__init__()  # keep the initialiser chain intact
        self.answer = answer
        self.answer_needed = answer_needed
        self.staff = staff
        self.picture = picture
        self.hack = hack
class AIResponder(object):
    """Chat responder backed by the OpenAI API, with bounded retries.

    Maintains a rolling conversation ``history`` of chat-format dicts,
    builds payloads for the ChatCompletion API, parses the model's JSON
    reply into an ``AIResponse``, and can generate images via the Images
    API.
    """

    def __init__(self, config: Dict[str, Any]) -> None:
        """Store *config* and install the OpenAI API key.

        :param config: Expects keys ``openai_token``, ``system``, ``model``,
            ``temperature``, ``top-p``, ``presence-penalty``,
            ``frequency-penalty`` and ``history-limit``.
        """
        self.config = config
        self.history: List[Dict[str, Any]] = []
        # NOTE: mutates module-global state — every responder shares one key.
        openai.api_key = self.config['openai_token']

    def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Build the chat payload: system prompt + (windowed) history + *message*.

        :param message: The new user message; rendered with ``str()``.
        :param limit: Keep only the last *limit* history entries; ``None``
            keeps everything.  NOTE: ``limit == 0`` also keeps everything,
            because ``history[-0:]`` is the full slice.
        :returns: ``(messages, history)`` — *messages* is the full API
            payload, *history* is a copy with *message* appended.
            ``self.history`` itself is NOT mutated; the caller commits the
            new history only after a successful exchange.
        """
        history = self.history[:] if limit is None else self.history[-limit:]
        history.append({"role": "user", "content": str(message)})
        messages = [{"role": "system", "content": self.config["system"]}]
        messages.extend(history)
        return messages, history

    async def draw(self, description: str) -> BytesIO:
        """Generate a 512x512 image for *description* and download it.

        Retries the whole generate-and-download sequence up to 7 times.

        :returns: The raw image bytes wrapped in a :class:`BytesIO`.
        :raises RuntimeError: When every attempt fails; chains the last error.
        """
        last_err: Optional[Exception] = None
        for _ in range(7):
            try:
                response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
                async with aiohttp.ClientSession() as session:
                    async with session.get(response['data'][0]['url']) as image:
                        return BytesIO(await image.read())
            except Exception as err:
                # Best-effort retry: remember the failure, log, and loop.
                last_err = err
                logging.warning("Failed to generate image %r: %r", description, err)
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries") from last_err

    async def send(self, message: AIMessage) -> AIResponse:
        """Send *message* with the conversation history and parse the reply.

        Up to 14 attempts.  On a context-length error the history window
        is shrunk by one entry (never below 4) and the call retried; any
        other failure (including an unparseable reply) is logged and
        retried as-is.  On success the exchange is committed to
        ``self.history``.

        :returns: The normalised model reply.
        :raises openai.error.InvalidRequestError: For non-recoverable
            request errors.
        :raises RuntimeError: When every attempt fails.
        """
        limit = self.config["history-limit"]
        for _ in range(14):
            messages, history = self._message(message, limit)
            try:
                result = await openai.ChatCompletion.acreate(
                    model=self.config["model"],
                    messages=messages,
                    temperature=self.config["temperature"],
                    top_p=self.config["top-p"],
                    presence_penalty=self.config["presence-penalty"],
                    frequency_penalty=self.config["frequency-penalty"])
                answer = result['choices'][0]['message']
                # The model is instructed (via the system prompt) to reply
                # with a JSON document; a malformed reply triggers a retry.
                response = json.loads(answer['content'])
            except openai.error.InvalidRequestError as err:
                # Context overflow: narrow the history window and retry.
                if 'maximum context length is' in str(err) and limit > 4:
                    limit -= 1
                    continue
                raise  # bare raise preserves the original traceback
            except Exception as err:
                logging.warning("failed to generate response: %r", err)
                continue
            # Commit the exchange only after a successful, parseable reply.
            history.append(answer)
            self.history = history
            # Normalise the model's stringly-typed fields.
            for fld in ('answer', 'staff', 'picture'):
                if str(response[fld]).strip().lower() == 'none':
                    response[fld] = None
            for fld in ('answer_needed', 'hack'):
                response[fld] = str(response[fld]).strip().lower() == 'true'
            return AIResponse(response['answer'],
                              response['answer_needed'],
                              response['staff'],
                              response['picture'],
                              response['hack'])
        raise RuntimeError("Failed to generate answer after multiple retries")