import json
import logging
import time
from io import BytesIO
from pprint import pformat
from typing import Any, Dict, List, Optional

import aiohttp
import openai


class AIMessageBase(object):
    """Base class for simple record objects exchanged with the model.

    Subclasses keep their data as plain instance attributes; ``str()`` renders
    those attributes as a JSON object (this is what gets sent to the model).
    """

    def __init__(self) -> None:
        pass

    def __str__(self) -> str:
        return json.dumps(vars(self))


class AIMessage(AIMessageBase):
    """An incoming user message to forward to the chat model.

    :param user: sender of the message.
    :param message: message text.
    :param channel: label of the channel the message came from (default ``"chat"``).
    """

    def __init__(self, user: str, message: str, channel: str = "chat") -> None:
        self.user = user
        self.message = message
        self.channel = channel


class AIResponse(AIMessageBase):
    """Normalized model reply as produced by :meth:`AIResponder.post_process`.

    :param answer: reply text, or None when there is nothing to say.
    :param answer_needed: whether a reply should be posted (always False when ``answer`` is None).
    :param staff: value of the model's ``staff`` field, or None.
    :param picture: value of the model's ``picture`` field, or None.
    :param hack: value of the model's ``hack`` flag.
    """

    def __init__(self, answer: Optional[str], answer_needed: bool, staff: Optional[str],
                 picture: Optional[str], hack: bool) -> None:
        self.answer = answer
        self.answer_needed = answer_needed
        self.staff = staff
        self.picture = picture
        self.hack = hack


class AIResponder(object):
    """Conversation wrapper around the legacy (module-level) ``openai`` API.

    Keeps a rolling chat history and expects the model to answer with a JSON
    object; the keys read are ``answer``, ``answer_needed``, ``staff``,
    ``picture`` and ``hack``.

    :param config: settings dict. Keys read here: ``openai-token``, ``system``
        (default system prompt), optionally one keyed by ``channel``,
        ``history-limit``, ``model``, ``temperature``, ``top-p``,
        ``presence-penalty`` and ``frequency-penalty``.
    :param channel: config key holding the system prompt for this responder;
        defaults to ``'system'``.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        self.config = config
        self.history: List[Dict[str, Any]] = []
        self.channel = channel if channel is not None else 'system'
        # NOTE: sets the API key globally on the openai module, affecting all users of it.
        openai.api_key = self.config['openai-token']

    def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Build the chat payload: system prompt, recent history, then ``message``.

        ``{date}`` and ``{time}`` in the system prompt are replaced with the
        current local date/time. ``limit`` keeps only the last ``limit``
        history entries (None keeps all). ``self.history`` is not modified.
        """
        # Look the fallback up lazily so a missing 'system' key only matters when it is needed.
        if self.channel in self.config:
            system = self.config[self.channel]
        else:
            system = self.config['system']
        # Take one timestamp so {date} and {time} cannot straddle midnight.
        now = time.localtime()
        system = system.replace('{date}', time.strftime('%Y-%m-%d', now)) \
                       .replace('{time}', time.strftime('%H:%M:%S', now))
        history = self.history if limit is None else self.history[-limit:]
        return [
            {"role": "system", "content": system},
            *history,
            {"role": "user", "content": str(message)},
        ]

    async def draw(self, description: str, retries: int = 7) -> BytesIO:
        """Generate a 512x512 image for ``description`` and return its bytes.

        Each attempt requests one image and downloads it from the returned URL.
        Any failure (API error, HTTP error status, network error) is logged and
        retried up to ``retries`` times.

        :raises RuntimeError: if every attempt fails.
        """
        for _ in range(retries):
            try:
                response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
                url = response['data'][0]['url']
                async with aiohttp.ClientSession() as session:
                    async with session.get(url) as image:
                        # Without this, an HTTP error body would be returned as "image" bytes.
                        image.raise_for_status()
                        return BytesIO(await image.read())
            except Exception as err:
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")

    async def post_process(self, response: Dict[str, Any]) -> AIResponse:
        """Normalize the model's JSON reply (in place) and wrap it in an :class:`AIResponse`.

        ``answer``/``staff``/``picture`` become None when missing or when they
        read as ``none``/``null``/empty; ``answer_needed``/``hack`` become True
        only for a case-insensitive ``"true"``. A None answer forces
        ``answer_needed`` to False; otherwise the answer is coerced to ``str``.
        """
        for fld in ('answer', 'staff', 'picture'):
            # .get(): a missing key stringifies to 'None' and is normalized like an explicit null.
            if str(response.get(fld)).strip().lower() in ('none', '', 'null'):
                response[fld] = None
        for fld in ('answer_needed', 'hack'):
            response[fld] = str(response.get(fld)).strip().lower() == 'true'
        if response['answer'] is None:
            response['answer_needed'] = False
        else:
            response['answer'] = str(response['answer'])
        return AIResponse(response['answer'], response['answer_needed'], response['staff'],
                          response['picture'], response['hack'])

    async def send(self, message: AIMessage, retries: int = 14) -> AIResponse:
        """Send ``message`` with recent history to the chat model and return its parsed reply.

        Malformed replies and transient errors are retried up to ``retries``
        times. When the API reports the context length was exceeded, the
        history window shrinks by one entry per attempt, but not below 4. On
        success, the user message and the model's reply are appended to
        ``self.history``, which is then trimmed to the current window.

        :raises openai.error.InvalidRequestError: for invalid requests other
            than a context-length overflow that can still be shrunk.
        :raises RuntimeError: if no usable answer was obtained.
        """
        limit = self.config["history-limit"]
        for _ in range(retries):
            messages = self._message(message, limit)
            logging.info(f"try to send this messages:\n{pformat(messages)}")
            try:
                result = await openai.ChatCompletion.acreate(model=self.config["model"],
                                                             messages=messages,
                                                             temperature=self.config["temperature"],
                                                             top_p=self.config["top-p"],
                                                             presence_penalty=self.config["presence-penalty"],
                                                             frequency_penalty=self.config["frequency-penalty"])
                answer = result['choices'][0]['message']
                # Deliberately an exact type check: legacy openai response objects subclass
                # dict, and a plain dict is wanted in the stored history.
                if type(answer) is not dict:
                    answer = answer.to_dict()
                response = json.loads(answer['content'])
                if not isinstance(response, dict) or 'hack' not in response \
                        or not isinstance(response.get('picture'), (str, type(None))):
                    logging.warning(f"unexpected answer format, retrying: {answer['content']!r}")
                    continue
                logging.info(f"got this answer:\n{pformat(response)}")
            except openai.error.InvalidRequestError as err:
                if 'maximum context length is' in str(err) and limit > 4:
                    limit -= 1
                    continue
                raise
            except Exception as err:
                logging.warning(f"failed to generate response: {repr(err)}")
                continue
            self.history.append(messages[-1])
            self.history.append(answer)
            if len(self.history) > limit:
                self.history = self.history[-limit:]
            return await self.post_process(response)
        raise RuntimeError("Failed to generate answer after multiple retries")