import json
import multiline
import openai
import aiohttp
import logging
import time
import re
import pickle
from pathlib import Path
from io import BytesIO
from pprint import pformat
from typing import Optional, List, Dict, Any, Tuple


def pp(*args, **kw):
    """pformat with a wide default width (300) so log lines stay on one row."""
    if 'width' not in kw:
        kw['width'] = 300
    return pformat(*args, **kw)


def parse_response(content: str) -> Dict:
    """Parse a model reply as JSON.

    Tries strict ``json.loads`` first, then falls back to the more forgiving
    ``multiline`` parser (tolerates literal newlines inside strings).  If both
    fail, the fallback parser's exception propagates unchanged.
    """
    content = content.strip()
    try:
        return json.loads(content)
    except Exception:
        # Let the fallback parser's own error propagate on failure.
        return multiline.loads(content, multiline=True)


def parse_maybe_json(json_string):
    """Best-effort normalization of a possibly-JSON value to a plain string.

    - ``None`` stays ``None``; lists/dicts are flattened to space-joined strings.
    - A parseable JSON list/dict is flattened to newline-joined values.
    - Scalar results spelled like none/null (with or without quotes) become ``None``.
    - Unparseable text is returned as-is, after one retry with a single layer
      of surrounding braces stripped (the model sometimes wraps plain text in {}).
    """
    if json_string is None:
        return None
    if isinstance(json_string, list):
        return ' '.join([str(x) for x in json_string])
    if isinstance(json_string, dict):
        return ' '.join([str(x) for x in json_string.values()])
    json_string = str(json_string).strip()
    try:
        parsed_json = parse_response(json_string)
    except Exception:
        if json_string.startswith('{') and json_string.endswith('}'):
            # Retry without the braces before giving up.
            return parse_maybe_json(json_string[1:-1])
        return json_string
    if isinstance(parsed_json, (list, dict)):
        values = parsed_json.values() if isinstance(parsed_json, dict) else parsed_json
        return '\n'.join(str(value) for value in values)
    result = str(parsed_json)
    if result.lower() in ('', 'none', 'null', '"none"', '"null"', "'none'", "'null'"):
        return None
    return result


class AIMessageBase(object):
    """Base class for chat payloads; ``str()`` renders all attributes as JSON."""

    def __init__(self) -> None:
        pass

    def __str__(self) -> str:
        return json.dumps(vars(self))


class AIMessage(AIMessageBase):
    """An incoming user message routed to the responder."""

    def __init__(self, user: str, message: str, channel: str = "chat",
                 direct: bool = False) -> None:
        self.user = user          # author of the message
        self.message = message    # raw message text
        self.channel = channel    # source channel name
        self.direct = direct      # True for direct/private messages


class AIResponse(AIMessageBase):
    """A post-processed model answer."""

    def __init__(self, answer: Optional[str], answer_needed: bool,
                 channel: Optional[str], staff: Optional[str],
                 picture: Optional[str], hack: bool) -> None:
        self.answer = answer                # reply text, or None for no reply
        self.answer_needed = answer_needed  # whether the reply should be sent
        self.channel = channel              # target channel, if redirected
        self.staff = staff                  # staff-notification text, if any
        self.picture = picture              # image-generation prompt, if any
        self.hack = hack                    # model flagged a prompt-injection attempt


class AIResponder(object):
    """Stateful chat responder around the OpenAI ChatCompletion/Image APIs.

    Keeps a per-channel rolling history, optionally persisted to disk, and
    turns raw model output into :class:`AIResponse` objects.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        self.config = config
        self.history: List[Dict[str, Any]] = []
        self.channel = channel if channel is not None else 'system'
        openai.api_key = self.config['openai-token']
        self.history_file: Optional[Path] = None
        if 'history-directory' in self.config:
            self.history_file = Path(self.config['history-directory']).expanduser() \
                / f'{self.channel}.dat'
            if self.history_file.exists():
                # SECURITY: pickle.load executes arbitrary code from the file;
                # the history directory must be trusted/local-only.
                with open(self.history_file, 'rb') as fd:
                    self.history = pickle.load(fd)

    def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Build the full chat payload: system prompt, trimmed history, new message.

        The system prompt is taken from a channel-specific config entry when
        present, else the global 'system' entry; {date}/{time} placeholders are
        expanded.  When *limit* is given, history is shrunk in place first.
        """
        messages = []
        system = self.config.get(self.channel, self.config['system'])
        system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
                       .replace('{time}', time.strftime('%H:%M:%S'))
        messages.append({"role": "system", "content": system})
        if limit is not None:
            while len(self.history) > limit:
                self.shrink_history_by_one()
        for msg in self.history:
            messages.append(msg)
        messages.append({"role": "user", "content": str(message)})
        return messages

    async def draw(self, description: str) -> BytesIO:
        """Generate a 512x512 image for *description*; up to 3 attempts.

        Returns the image bytes; raises RuntimeError after all attempts fail.
        """
        for _ in range(3):
            try:
                response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
                async with aiohttp.ClientSession() as session:
                    async with session.get(response['data'][0]['url']) as image:
                        return BytesIO(await image.read())
            except Exception as err:
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")

    async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
        """Normalize a parsed model reply dict into an AIResponse.

        None-like spellings collapse to None, booleans are coerced, markdown
        mentions/links in the answer are flattened to plain text, and
        answer_needed is forced on for direct messages, name mentions, or
        when a staff notification accompanies an answer.
        """
        for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
            # NB: a missing key yields str(None) == 'none' and is set to None here.
            if str(response.get(fld)).strip().lower() in \
                    ('none', '', 'null', '"none"', '"null"', "'none'", "'null'"):
                response[fld] = None
        for fld in ('answer_needed', 'hack'):
            if str(response.get(fld)).strip().lower() == 'true':
                response[fld] = True
            else:
                response[fld] = False
        if response['answer'] is None:
            response['answer_needed'] = False
        else:
            response['answer'] = str(response['answer'])
            # @[name](id) mentions -> name; [text](url) links -> url.
            response['answer'] = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', response['answer'])
            response['answer'] = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', response['answer'])
        if message.direct or message.user in message.message:
            response['answer_needed'] = True
        response_message = AIResponse(response['answer'],
                                      response['answer_needed'],
                                      response['channel'],
                                      parse_maybe_json(response['staff']),
                                      parse_maybe_json(response['picture']),
                                      response['hack'])
        if response_message.staff is not None and response_message.answer is not None:
            response_message.answer_needed = True
        return response_message

    def short_path(self, message: AIMessage, limit: int) -> bool:
        """Record the message without calling the model when it matches a
        configured (channel_regex, user_regex) short-path pair.

        Returns True when the message was absorbed into history this way.
        Direct messages never take the short path.
        """
        if message.direct or 'short-path' not in self.config:
            return False
        for chan_re, user_re in self.config['short-path']:
            chan_ma = re.match(chan_re, message.channel)
            user_ma = re.match(user_re, message.user)
            if chan_ma and user_ma:
                self.history.append({"role": "user", "content": str(message)})
                if len(self.history) > limit:
                    self.history = self.history[-limit:]
                if self.history_file is not None:
                    with open(self.history_file, 'wb') as fd:
                        pickle.dump(self.history, fd)
                return True
        return False

    async def _acreate(self, messages: List[Dict[str, Any]],
                       limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Call ChatCompletion once.

        Returns (answer_message_dict, limit).  On a context-length error the
        limit is reduced by one (down to a floor of 4) and (None, new_limit)
        is returned so the caller can retry with less history; any other
        failure is logged and returns (None, limit) unchanged.
        """
        try:
            result = await openai.ChatCompletion.acreate(
                model=self.config["model"],
                messages=messages,
                temperature=self.config["temperature"],
                max_tokens=self.config["max-tokens"],
                top_p=self.config["top-p"],
                presence_penalty=self.config["presence-penalty"],
                frequency_penalty=self.config["frequency-penalty"])
            answer = result['choices'][0]['message']
            if not isinstance(answer, dict):
                answer = answer.to_dict()
            return answer, limit
        except openai.error.InvalidRequestError as err:
            if 'maximum context length is' in str(err) and limit > 4:
                logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
                limit -= 1
                return None, limit
            raise  # bare raise preserves the original traceback
        except Exception as err:
            logging.warning(f"failed to generate response: {repr(err)}")
            return None, limit

    async def fix(self, answer: str) -> str:
        """Ask a secondary 'fix' model to repair a malformed JSON answer.

        Returns the {...} span extracted from the fix model's reply, or the
        original *answer* unchanged when no fix model is configured, the span
        is missing/trivial, or the call fails.
        """
        if 'fix-model' not in self.config:
            return answer
        messages = [{"role": "system", "content": self.config["fix-description"]},
                    {"role": "user", "content": answer}]
        try:
            result = await openai.ChatCompletion.acreate(model=self.config["fix-model"],
                                                         messages=messages,
                                                         temperature=0.2,
                                                         max_tokens=2048)
            logging.info(f"got this message as fix:\n{pp(result['choices'][0]['message']['content'])}")
            response = result['choices'][0]['message']['content']
            start, end = response.find("{"), response.rfind("}")
            # Require a non-trivial object span ('{"a"' at minimum).
            if start == -1 or end == -1 or (start + 3) >= end:
                return answer
            response = response[start:end + 1]
            logging.info(f"fixed answer:\n{pp(response)}")
            return response
        except Exception as err:
            logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
            return answer

    def shrink_history_by_one(self, index: int = 0) -> None:
        """Drop one history entry, preferring channels that exceed their quota.

        Walks forward from *index*; an entry is removed once its channel holds
        more than 'history-per-channel' (default 3) entries from that point on.
        If the walk runs off the end, the oldest entry is dropped as a fallback.
        NOTE(review): standard chat messages carry no 'channel' key, so
        item.get('channel') is typically None for every entry — confirm the
        intended grouping against whatever populates history.
        """
        if index >= len(self.history):
            del self.history[0]
        else:
            current = self.history[index]
            count = sum(1 for item in self.history[index:]
                        if item.get('channel') == current.get('channel'))
            if count > self.config.get('history-per-channel', 3):
                del self.history[index]
            else:
                self.shrink_history_by_one(index + 1)

    def update_history(self, question: Dict[str, Any], answer: Dict[str, Any],
                       limit: int) -> None:
        """Append a Q/A pair, trim history to *limit*, and persist if configured."""
        self.history.append(question)
        self.history.append(answer)
        while len(self.history) > limit:
            self.shrink_history_by_one()
        if self.history_file is not None:
            with open(self.history_file, 'wb') as fd:
                pickle.dump(self.history, fd)

    async def send(self, message: AIMessage) -> AIResponse:
        """Send *message* to the model and return the processed response.

        Retries up to 3 times on unparseable/invalid replies; a context-length
        failure instead shrinks the history limit and retries without cost.
        Raises RuntimeError when all retries are exhausted.
        """
        limit = self.config["history-limit"]
        if self.short_path(message, limit):
            return AIResponse(None, False, None, None, None, False)
        retries = 3
        while retries > 0:
            messages = self._message(message, limit)
            logging.info(f"try to send this messages:\n{pp(messages)}")
            answer, new_limit = await self._acreate(messages, limit)
            if answer is None:
                # BUGFIX: a generic API failure used to loop forever here.
                # Only the context-length path shrinks the limit; any other
                # None answer must consume a retry.
                if new_limit == limit:
                    retries -= 1
                limit = new_limit
                continue
            limit = new_limit
            try:
                response = parse_response(answer['content'])
            except Exception as err:
                logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}")
                answer['content'] = await self.fix(answer['content'])
                try:
                    response = parse_response(answer['content'])
                except Exception as err:
                    logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}")
                    retries -= 1
                    continue
            if not isinstance(response.get('picture'), (type(None), str)):
                logging.warning(f"picture key is wrong in response: {pp(response)}")
                retries -= 1
                continue
            answer_message = await self.post_process(message, response)
            answer['content'] = str(answer_message)
            self.update_history(messages[-1], answer, limit)
            logging.info(f"got this answer:\n{str(answer_message)}")
            return answer_message
        raise RuntimeError("Failed to generate answer after multiple retries")