Design notes: (1) try to keep at least 3 messages from each channel in the history; (2) store the post-processed messages in the history, instead of the raw messages from the OpenAI API.
255 lines
11 KiB
Python
255 lines
11 KiB
Python
import json
|
|
import multiline
|
|
import openai
|
|
import aiohttp
|
|
import logging
|
|
import time
|
|
import re
|
|
import pickle
|
|
from pathlib import Path
|
|
from io import BytesIO
|
|
from pprint import pformat
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
|
|
|
|
def pp(*args, **kw):
    """Pretty-format the arguments via pformat, defaulting to a wide 300-column layout."""
    kw.setdefault('width', 300)
    return pformat(*args, **kw)
|
|
|
|
|
|
def parse_response(content: str) -> Dict:
    """Parse a model response as JSON, falling back to a lenient multiline parser.

    Tries strict ``json.loads`` first; on failure retries with
    ``multiline.loads`` (which tolerates raw newlines inside JSON strings).
    Whatever the multiline parser raises when both fail propagates to the
    caller unchanged.  (The original wrapped that in ``except ... as err:
    raise err``, a no-op re-raise that only obscured the traceback.)
    """
    content = content.strip()
    try:
        return json.loads(content)
    except Exception:
        # Model output often contains unescaped newlines inside strings;
        # the multiline parser accepts those.  Errors propagate directly.
        return multiline.loads(content, multiline=True)
|
|
|
|
|
|
def parse_maybe_json(json_string):
    """Best-effort conversion of a possibly-JSON string into plain text.

    Dicts and lists are flattened into newline-joined stringified values;
    scalar results are stringified, with empty / "none" / "null" variants
    (quoted or bare) normalised to None.  Input that cannot be parsed is
    returned stripped, after one retry with a single outer brace pair
    removed.  None input yields None.
    """
    if json_string is None:
        return None
    json_string = json_string.strip()

    try:
        parsed = parse_response(json_string)
    except Exception:
        # The payload is sometimes a brace-wrapped fragment; retry once
        # without the outermost braces, otherwise hand back the raw text.
        if json_string.startswith('{') and json_string.endswith('}'):
            return parse_maybe_json(json_string[1:-1])
        return json_string

    if isinstance(parsed, dict):
        return '\n'.join(str(value) for value in parsed.values())
    if isinstance(parsed, list):
        return '\n'.join(str(value) for value in parsed)

    text = str(parsed)
    if text.lower() in ('', 'none', 'null', '"none"', '"null"', "'none'", "'null'"):
        return None
    return text
|
|
|
|
|
|
class AIMessageBase(object):
    """Base class for messages exchanged with the model.

    Subclasses get a JSON string form for free: ``__str__`` serialises the
    instance attribute dict.
    """

    def __init__(self) -> None:
        pass

    def __str__(self) -> str:
        # vars(self) is the instance __dict__, so attribute insertion order
        # is preserved in the emitted JSON.
        return json.dumps(vars(self))
|
|
|
|
|
|
class AIMessage(AIMessageBase):
    """An inbound chat message: author, text, channel and direct-address flag."""

    def __init__(self, user: str, message: str, channel: str = "chat", direct: bool = False) -> None:
        super().__init__()
        # Attribute order matters: AIMessageBase.__str__ serialises vars(self)
        # in insertion order.
        self.user = user
        self.message = message
        self.channel = channel
        self.direct = direct
|
|
|
|
|
|
class AIResponse(AIMessageBase):
    """A normalised model reply: answer text, flags, and optional staff/picture payloads."""

    def __init__(self, answer: Optional[str], answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None:
        super().__init__()
        # Attribute order matters: AIMessageBase.__str__ serialises vars(self)
        # in insertion order.
        self.answer = answer
        self.answer_needed = answer_needed
        self.staff = staff
        self.picture = picture
        self.hack = hack
|
|
|
|
|
|
class AIResponder(object):
    """Drives a chat-completion conversation for one channel, with optional persistent history."""

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Set up config, channel name, the OpenAI API key, and any saved history.

        When 'history-directory' is configured, prior history is restored
        from ``<dir>/<channel>.dat``.
        """
        self.config = config
        self.history: List[Dict[str, Any]] = []
        self.channel = 'system' if channel is None else channel
        openai.api_key = self.config['openai-token']
        self.history_file: Optional[Path] = None
        if 'history-directory' in self.config:
            history_dir = Path(self.config['history-directory']).expanduser()
            self.history_file = history_dir / f'{self.channel}.dat'
            if self.history_file.exists():
                # NOTE(review): pickle is only safe because this file is
                # written by this process; never point it at untrusted data.
                with self.history_file.open('rb') as fd:
                    self.history = pickle.load(fd)
|
|
|
|
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
|
messages = []
|
|
system = self.config.get(self.channel, self.config['system'])
|
|
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
|
.replace('{time}', time.strftime('%H:%M:%S'))
|
|
messages.append({"role": "system", "content": system})
|
|
if limit is not None:
|
|
while len(self.history) > limit:
|
|
self.shrink_history_by_one()
|
|
for msg in self.history:
|
|
messages.append(msg)
|
|
messages.append({"role": "user", "content": str(message)})
|
|
return messages
|
|
|
|
async def draw(self, description: str) -> BytesIO:
|
|
for _ in range(3):
|
|
try:
|
|
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
|
|
async with aiohttp.ClientSession() as session:
|
|
async with session.get(response['data'][0]['url']) as image:
|
|
return BytesIO(await image.read())
|
|
except Exception as err:
|
|
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
|
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
|
|
|
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
|
for fld in ('answer', 'staff', 'picture'):
|
|
if str(response[fld]).strip().lower() in \
|
|
('none', '', 'null', '"none"', '"null"', "'none'", "'null'"):
|
|
response[fld] = None
|
|
for fld in ('answer_needed', 'hack'):
|
|
if str(response[fld]).strip().lower() == 'true':
|
|
response[fld] = True
|
|
else:
|
|
response[fld] = False
|
|
if response['answer'] is None:
|
|
response['answer_needed'] = False
|
|
else:
|
|
response['answer'] = str(response['answer'])
|
|
response['answer'] = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', response['answer'])
|
|
response['answer'] = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', response['answer'])
|
|
if message.direct or message.user in message.message:
|
|
response['answer_needed'] = True
|
|
return AIResponse(response['answer'],
|
|
response['answer_needed'],
|
|
parse_maybe_json(response['staff']),
|
|
parse_maybe_json(response['picture']),
|
|
response['hack'])
|
|
|
|
def short_path(self, message: AIMessage, limit: int) -> bool:
|
|
if message.direct or 'short-path' not in self.config:
|
|
return False
|
|
for chan_re, user_re in self.config['short-path']:
|
|
chan_ma = re.match(chan_re, message.channel)
|
|
user_ma = re.match(user_re, message.user)
|
|
if chan_ma and user_ma:
|
|
self.history.append({"role": "user", "content": str(message)})
|
|
if len(self.history) > limit:
|
|
self.history = self.history[-limit:]
|
|
if self.history_file is not None:
|
|
with open(self.history_file, 'wb') as fd:
|
|
pickle.dump(self.history, fd)
|
|
return True
|
|
return False
|
|
|
|
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
|
try:
|
|
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
|
messages=messages,
|
|
temperature=self.config["temperature"],
|
|
max_tokens=self.config["max-tokens"],
|
|
top_p=self.config["top-p"],
|
|
presence_penalty=self.config["presence-penalty"],
|
|
frequency_penalty=self.config["frequency-penalty"])
|
|
answer = result['choices'][0]['message']
|
|
if type(answer) != dict:
|
|
answer = answer.to_dict()
|
|
return answer, limit
|
|
except openai.error.InvalidRequestError as err:
|
|
if 'maximum context length is' in str(err) and limit > 4:
|
|
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
|
limit -= 1
|
|
return None, limit
|
|
raise err
|
|
except Exception as err:
|
|
logging.warning(f"failed to generate response: {repr(err)}")
|
|
return None, limit
|
|
|
|
async def fix(self, answer: str) -> str:
|
|
if 'fix-model' not in self.config:
|
|
return answer
|
|
messages = [{"role": "system", "content": self.config["fix-description"]},
|
|
{"role": "user", "content": answer}]
|
|
try:
|
|
result = await openai.ChatCompletion.acreate(model=self.config["fix-model"],
|
|
messages=messages,
|
|
temperature=0.2,
|
|
max_tokens=2048)
|
|
logging.info(f"got this message as fix:\n{pp(result['choices'][0]['message']['content'])}")
|
|
response = result['choices'][0]['message']['content']
|
|
start, end = response.find("{"), response.rfind("}")
|
|
if start == -1 or end == -1 or (start + 3) >= end:
|
|
return answer
|
|
response = response[start:end + 1]
|
|
logging.info(f"fixed answer:\n{pp(response)}")
|
|
return response
|
|
except Exception as err:
|
|
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
|
return answer
|
|
|
|
def shrink_history_by_one(self, index: int = 0) -> None:
|
|
if index >= len(self.history):
|
|
del self.history[0]
|
|
else:
|
|
current = self.history[index]
|
|
count = sum(1 for item in self.history[index:] if item.get('channel') == current.get('channel'))
|
|
if count > 3:
|
|
del self.history[index]
|
|
else:
|
|
self.shrink_history_by_one(index + 1)
|
|
|
|
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
|
|
self.history.append(question)
|
|
self.history.append(answer)
|
|
while len(self.history) > limit:
|
|
self.shrink_history_by_one()
|
|
if self.history_file is not None:
|
|
with open(self.history_file, 'wb') as fd:
|
|
pickle.dump(self.history, fd)
|
|
|
|
async def send(self, message: AIMessage) -> AIResponse:
|
|
limit = self.config["history-limit"]
|
|
if self.short_path(message, limit):
|
|
return AIResponse(None, False, None, None, False)
|
|
retries = 3
|
|
while retries > 0:
|
|
messages = self._message(message, limit)
|
|
logging.info(f"try to send this messages:\n{pp(messages)}")
|
|
answer, limit = await self._acreate(messages, limit)
|
|
if answer is None:
|
|
continue
|
|
try:
|
|
response = parse_response(answer['content'])
|
|
except Exception as err:
|
|
logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}")
|
|
answer['content'] = await self.fix(answer['content'])
|
|
try:
|
|
response = parse_response(answer['content'])
|
|
except Exception as err:
|
|
logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}")
|
|
retries -= 1
|
|
continue
|
|
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
|
retries -= 1
|
|
continue
|
|
answer_message = await self.post_process(message, response)
|
|
answer['content'] = str(answer_message)
|
|
self.update_history(messages[-1], answer, limit)
|
|
logging.info(f"got this answer:\n{str(answer_message)}")
|
|
return answer_message
|
|
raise RuntimeError("Failed to generate answer after multiple retries")
|