363 lines
14 KiB
Python
363 lines
14 KiB
Python
import os
|
|
import json
|
|
import random
|
|
import multiline
|
|
import logging
|
|
import time
|
|
import re
|
|
import pickle
|
|
from pathlib import Path
|
|
from io import BytesIO
|
|
from pprint import pformat
|
|
from functools import lru_cache, wraps
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
|
|
|
|
def pp(*args, **kw):
    """Pretty-format objects with ``pprint.pformat`` using a wide default line.

    All positional and keyword arguments are forwarded to ``pformat``.
    When the caller does not pass ``width``, a width of 300 is used.

    Returns:
        The formatted string produced by ``pformat``.
    """
    kw.setdefault('width', 300)
    return pformat(*args, **kw)
|
|
|
|
|
|
@lru_cache(maxsize=300)
def _parse_json_cached(content: str) -> Any:
    """Parse *content* and memoise the result (private; do not mutate the return value).

    Strict ``json.loads`` is tried first; when it fails the lenient
    ``multiline.loads(..., multiline=True)`` parser is used and any exception
    it raises propagates to the caller.
    """
    content = content.strip()
    try:
        return json.loads(content)
    except Exception:
        return multiline.loads(content, multiline=True)


def parse_json(content: str) -> Dict:
    """Parse a JSON document, tolerating multi-line string values.

    Parsing results are cached (up to 300 distinct inputs). Because the
    parsed value is a mutable container and callers in this module modify the
    returned dict in place, every call returns an independent deep copy so
    that one caller's mutations can never leak into the cache.

    Args:
        content: The text to parse; surrounding whitespace is ignored.

    Returns:
        The parsed value (a private copy owned by the caller).

    Raises:
        Exception: Whatever the fallback ``multiline`` parser raises when the
            text cannot be parsed by either parser.
    """
    from copy import deepcopy

    return deepcopy(_parse_json_cached(content))


# Keep the cache-introspection helpers that the previously decorated function exposed.
parse_json.cache_info = _parse_json_cached.cache_info
parse_json.cache_clear = _parse_json_cached.cache_clear
|
|
|
|
|
|
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts=None):
    """Generate sleep intervals for exponential backoff with jitter.

    Args:
        base: Base of the exponentiation operation
        max_delay: Maximum delay (before jitter is applied)
        factor: Multiplication factor for each increase in backoff
        jitter: Additional randomness range to prevent thundering herd problem,
            expressed as a fraction of the current delay
        max_attempts: Optional bound on the attempt counter; ``None`` (the
            default) lets the generator run forever

    Yields:
        Delay for backoff as a floating point number.

    Raises:
        RuntimeError: When the generator is resumed after the attempt counter
            has exceeded ``max_attempts`` (i.e. after ``max_attempts + 1``
            delays have been yielded).
    """
    attempt = 0
    while True:
        try:
            uncapped = factor * base ** attempt
        except OverflowError:
            # A float power overflowed: the delay is far beyond the cap anyway,
            # so keep yielding the capped delay instead of killing the generator.
            uncapped = max_delay
        sleep = min(max_delay, uncapped)
        jitter_amount = jitter * sleep
        sleep += random.uniform(-jitter_amount, jitter_amount)
        yield sleep
        attempt += 1
        if max_attempts is not None and attempt > max_attempts:
            raise RuntimeError("Max attempts reached in exponential backoff.")
|
|
|
|
|
|
def async_cache_to_file(filename):
    """Build a decorator that memoises an async callable in a pickle file.

    The cache is only active when ``filename`` already exists at decoration
    time: its pickled content is loaded as the cache, or an empty dict is
    used when loading fails. When the file does not exist the wrapped
    coroutine is always awaited directly and nothing is written.

    The cache key is the JSON dump of the function name, the positional
    arguments after the first one, and the keyword arguments.

    NOTE(review): ``pickle.load`` must only be used on trusted files.
    """
    path = Path(filename)
    store = None
    if path.exists():
        try:
            with path.open('rb') as handle:
                store = pickle.load(handle)
        except Exception:
            store = {}

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Caching disabled: the cache file was absent when the decorator was built.
            if store is None:
                return await func(*args, **kwargs)
            # The first positional argument is left out of the key.
            key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
            if key not in store:
                store[key] = await func(*args, **kwargs)
                # Persist the whole cache after every miss.
                with path.open('wb') as handle:
                    pickle.dump(store, handle)
            return store[key]

        return wrapper

    return decorator
|
|
|
|
|
|
def parse_maybe_json(json_string):
    """Flatten a value that may or may not be JSON into plain text.

    * ``None`` is returned unchanged.
    * A dict or list passed in directly is rendered as its values/items
      joined by single spaces.
    * Anything else is converted to a stripped string and parsed with
      ``parse_json``. A parsed string is returned as is, a parsed dict/list
      is rendered as its values/items joined by newlines, and any other
      parsed value is returned via ``str``.
    * When parsing fails and the text is wrapped in ``{}`` or ``[]``, the
      wrapper is removed and the remainder is processed again; otherwise the
      stripped text itself is returned.
    """
    if json_string is None:
        return None
    if isinstance(json_string, dict):
        return ' '.join(map(str, json_string.values()))
    if isinstance(json_string, list):
        return ' '.join(map(str, json_string))

    text = str(json_string).strip()
    try:
        parsed = parse_json(text)
    except Exception:
        # Not parseable: peel one layer of brackets and retry, else give the text back.
        for opener, closer in (('{', '}'), ('[', ']')):
            if text.startswith(opener) and text.endswith(closer):
                return parse_maybe_json(text[1:-1])
        return text

    if isinstance(parsed, str):
        return parsed
    if isinstance(parsed, dict):
        return '\n'.join(map(str, parsed.values()))
    if isinstance(parsed, list):
        return '\n'.join(map(str, parsed))
    return str(parsed)
|
|
|
|
|
|
def same_channel(item1: Dict[str, Any], item2: Dict[str, Any]) -> bool:
    """Tell whether two history items carry the same ``channel`` value.

    Each item's ``'content'`` entry is parsed with ``parse_json`` and the
    ``'channel'`` keys of the two parsed objects are compared (a missing key
    compares as ``None``).
    """
    first_channel = parse_json(item1['content']).get('channel')
    second_channel = parse_json(item2['content']).get('channel')
    return first_channel == second_channel
|
|
|
|
|
|
class AIMessageBase(object):
    """Base for message objects whose string form is a JSON dump of their attributes."""

    def __init__(self) -> None:
        """Create an object without any attributes."""

    def __str__(self) -> str:
        """Return the instance attributes serialised as a JSON object."""
        return json.dumps(self.__dict__)
|
|
|
|
|
|
class AIMessage(AIMessageBase):
    """An incoming message together with its routing flags."""

    def __init__(
        self,
        user: str,
        message: str,
        channel: str = "chat",
        direct: bool = False,
        historise_question: bool = True,
    ) -> None:
        """Store the constructor arguments as attributes of the same name.

        The attributes are assigned in signature order, which is also the key
        order of the JSON produced by ``str()`` on the instance.
        """
        self.user = user
        self.message = message
        self.channel = channel
        self.direct = direct
        self.historise_question = historise_question
|
|
|
|
|
|
class AIResponse(AIMessageBase):
    """A structured reply: answer text plus routing and control fields."""

    def __init__(
        self,
        answer: Optional[str],
        answer_needed: bool,
        channel: Optional[str],
        staff: Optional[str],
        picture: Optional[str],
        hack: bool,
    ) -> None:
        """Store the constructor arguments as attributes of the same name.

        The attributes are assigned in signature order, which is also the key
        order of the JSON produced by ``str()`` on the instance.
        """
        self.answer = answer
        self.answer_needed = answer_needed
        self.channel = channel
        self.staff = staff
        self.picture = picture
        self.hack = hack
|
|
|
|
|
|
class AIResponderBase(object):
    """Holds the configuration and the channel name shared by responders."""

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Remember ``config`` and ``channel``; a missing channel becomes ``'system'``."""
        super().__init__()
        self.config = config
        self.channel = 'system' if channel is None else channel
|
|
|
|
|
|
class AIResponder(AIResponderBase):
    """Conversation driver that keeps a message history and a free-text memory.

    The backend-specific operations (``chat``, ``fix``, ``memory_rewrite``,
    ``translate``, ``draw_leonardo``, ``draw_openai``) raise
    ``NotImplementedError`` here and are expected to be provided by subclasses.

    Configuration keys read by this class: ``'system'`` (or a key named after
    the channel), ``'history-limit'``, and optionally ``'history-directory'``,
    ``'history-per-channel'``, ``'news'``, ``'short-path'`` and
    ``'leonardo-token'``.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Initialise state and load persisted history/memory when configured.

        When ``'history-directory'`` is configured, history is read from
        ``<channel>.dat`` and memory from ``<channel>.memory`` in that
        directory, if those files exist.
        """
        super().__init__(config, channel)
        self.history: List[Dict[str, Any]] = []
        self.memory: str = 'I am an assistant.'
        self.rate_limit_backoff = exponential_backoff()
        self.history_file: Optional[Path] = None
        self.memory_file: Optional[Path] = None
        if 'history-directory' in self.config:
            directory = Path(self.config['history-directory']).expanduser()
            # NOTE(review): pickle.load executes arbitrary code from the file;
            # the history directory must only contain trusted data.
            self.history_file = directory / f'{self.channel}.dat'
            if self.history_file.exists():
                with open(self.history_file, 'rb') as fd:
                    self.history = pickle.load(fd)
            self.memory_file = directory / f'{self.channel}.memory'
            if self.memory_file.exists():
                with open(self.memory_file, 'rb') as fd:
                    self.memory = pickle.load(fd)
        logging.info(f'memmory:\n{self.memory}')

    def message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Build the message list to send: system prompt, history, then the new message.

        The system prompt is taken from the config entry named after the
        channel (falling back to ``'system'``) with the ``{date}``, ``{time}``,
        ``{news}`` and ``{memory}`` placeholders substituted. ``{news}`` is
        only substituted when the configured news file exists.

        Args:
            message: The incoming message; appended last as a ``user`` entry.
            limit: When given, the stored history is first shrunk to at most
                this many entries.

        Returns:
            The list of role/content dicts.
        """
        messages = []
        system = self.config.get(self.channel, self.config['system'])
        system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
                       .replace('{time}', time.strftime('%H:%M:%S'))
        news_feed = self.config.get('news')
        if news_feed and os.path.exists(news_feed):
            with open(news_feed) as fd:
                news_feed = fd.read().strip()
            system = system.replace('{news}', news_feed)
        system = system.replace('{memory}', self.memory)
        messages.append({"role": "system", "content": system})
        if limit is not None:
            while len(self.history) > limit:
                self.shrink_history_by_one()
        messages.extend(self.history)
        messages.append({"role": "user", "content": str(message)})
        return messages

    async def draw(self, description: str) -> BytesIO:
        """Draw a picture, using Leonardo when ``'leonardo-token'`` is configured, else OpenAI."""
        if self.config.get('leonardo-token') is not None:
            return await self.draw_leonardo(description)
        return await self.draw_openai(description)

    async def draw_leonardo(self, description: str) -> BytesIO:
        """Not implemented in this class."""
        raise NotImplementedError()

    async def draw_openai(self, description: str) -> BytesIO:
        """Not implemented in this class."""
        raise NotImplementedError()

    async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
        """Normalise a parsed response dict (in place) and wrap it in an ``AIResponse``.

        * Textual null markers in the optional fields become ``None``.
        * ``answer_needed`` and ``hack`` become real booleans (only the text
          ``true`` counts as ``True``).
        * Markdown-style mentions/links in the answer are reduced to plain text.
        * ``answer_needed`` is forced on for direct messages, when the user
          name occurs in the message text, and when both staff and answer are set.
        * A missing channel defaults to the channel of the incoming message.
        """
        for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
            if str(response.get(fld)).strip().lower() in \
                    ('none', '', 'null', '"none"', '"null"', "'none'", "'null'"):
                response[fld] = None
        for fld in ('answer_needed', 'hack'):
            if str(response.get(fld)).strip().lower() == 'true':
                response[fld] = True
            else:
                response[fld] = False
        if response['answer'] is None:
            response['answer_needed'] = False
        else:
            response['answer'] = str(response['answer'])
            # "@[name](target)" -> "name"
            response['answer'] = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', response['answer'])
            # "[label](target)" -> "target"
            response['answer'] = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', response['answer'])
            if message.direct or message.user in message.message:
                response['answer_needed'] = True
        response_message = AIResponse(response['answer'],
                                      response['answer_needed'],
                                      parse_maybe_json(response['channel']),
                                      parse_maybe_json(response['staff']),
                                      parse_maybe_json(response['picture']),
                                      response['hack'])
        if response_message.staff is not None and response_message.answer is not None:
            response_message.answer_needed = True
        if response_message.channel is None:
            response_message.channel = message.channel
        return response_message

    def short_path(self, message: AIMessage, limit: int) -> bool:
        """Record a message in history without asking the backend, when configured.

        Applies only to non-direct messages and only when ``'short-path'`` is
        configured as pairs of (channel regex, user regex). On the first pair
        matching the message's channel and user, the message is appended to
        the history, the history is trimmed to ``limit`` and persisted.

        Returns:
            ``True`` when the message was handled here, ``False`` otherwise.
        """
        if message.direct or 'short-path' not in self.config:
            return False
        for chan_re, user_re in self.config['short-path']:
            chan_ma = re.match(chan_re, message.channel)
            user_ma = re.match(user_re, message.user)
            if chan_ma and user_ma:
                self.history.append({"role": "user", "content": str(message)})
                self._trim_and_store_history(limit)
                return True
        return False

    async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Not implemented in this class."""
        raise NotImplementedError()

    async def fix(self, answer: str) -> str:
        """Not implemented in this class."""
        raise NotImplementedError()

    async def memory_rewrite(self, memory: str, message_user: str, answer_user: str, question: str, answer: str) -> str:
        """Not implemented in this class."""
        raise NotImplementedError()

    async def translate(self, text: str, language: str = "english") -> str:
        """Not implemented in this class."""
        raise NotImplementedError()

    def shrink_history_by_one(self, index: int = 0) -> None:
        """Remove exactly one entry from the history.

        Starting at ``index``, the first entry whose channel occurs more than
        ``'history-per-channel'`` times (default 3) in the whole history is
        removed. When no such entry exists the oldest entry is removed.
        """
        per_channel = self.config.get('history-per-channel', 3)
        # Iterative scan instead of recursion, so long histories cannot hit the recursion limit.
        for position in range(index, len(self.history)):
            current = self.history[position]
            count = sum(1 for item in self.history if same_channel(item, current))
            if count > per_channel:
                del self.history[position]
                return
        del self.history[0]

    def _trim_and_store_history(self, limit: int) -> None:
        """Shrink the history to at most ``limit`` entries and persist it when a history file is set."""
        while len(self.history) > limit:
            self.shrink_history_by_one()
        if self.history_file is not None:
            with open(self.history_file, 'wb') as fd:
                pickle.dump(self.history, fd)

    def update_history(self,
                       question: Dict[str, Any],
                       answer: Dict[str, Any],
                       limit: int,
                       historise_question: bool = True) -> None:
        """Append a question/answer pair to the history, then trim and persist it.

        Args:
            question: The user entry; skipped when ``historise_question`` is false.
            answer: The assistant entry; always appended.
            limit: Maximum number of history entries to keep.
            historise_question: Whether the question is recorded.
        """
        if historise_question:
            self.history.append(question)
        self.history.append(answer)
        self._trim_and_store_history(limit)

    def update_memory(self, memory: str) -> None:
        """Adopt ``memory`` as the current memory and persist it when a memory file is set.

        The argument used to be ignored (the previously stored memory was
        written instead); it is now assigned before saving so that the stored
        and the persisted memory always match what the caller passed.
        """
        self.memory = memory
        if self.memory_file is not None:
            with open(self.memory_file, 'wb') as fd:
                pickle.dump(self.memory, fd)

    async def handle_picture(self, response: Dict) -> bool:
        """Validate and translate the ``picture`` entry of a response.

        Returns:
            ``False`` when ``picture`` is neither ``None`` nor a string;
            otherwise ``True``, after replacing a present picture description
            with the result of ``translate``.
        """
        if not isinstance(response.get("picture"), (type(None), str)):
            logging.warning(f"picture key is wrong in response: {pp(response)}")
            return False
        if response.get("picture") is not None:
            response["picture"] = await self.translate(response["picture"])
        return True

    async def memoize(self, message_user: str, answer_user: str, message: str, answer: str) -> None:
        """Rewrite the memory from an exchange via ``memory_rewrite`` and persist it."""
        self.memory = await self.memory_rewrite(self.memory, message_user, answer_user, message, answer)
        self.update_memory(self.memory)

    async def memoize_reaction(self, message_user: str, reaction_user: str, operation: str, reaction: str, message: str) -> None:
        """Memoise a reaction to a message; the message is passed on as a ``> `` quoted block."""
        quoted_message = message.replace('\n', '\n> ')
        await self.memoize(message_user, 'assistant',
                           f'\n> {quoted_message}',
                           f'User {reaction_user} has {operation} this raction: {reaction}')

    async def send(self, message: AIMessage) -> AIResponse:
        """Send a message to the backend and return the post-processed response.

        Returns an empty ``AIResponse`` when the short path applies.
        Otherwise the backend is queried; unparsable answers are passed
        through ``fix`` once, and parse or picture failures consume one of
        three retries.

        Raises:
            RuntimeError: When no usable answer was produced within the retries.
        """
        # Get the history limit from the configuration
        limit = self.config["history-limit"]

        # Check if a short path applies, return an empty AIResponse if it does
        if self.short_path(message, limit):
            return AIResponse(None, False, None, None, None, False)

        # Number of retries for sending the message
        retries = 3

        while retries > 0:
            # Get the message queue
            messages = self.message(message, limit)
            logging.info(f"try to send this messages:\n{pp(messages)}")

            # Attempt to send the message to the AI
            answer, limit = await self.chat(messages, limit)

            # NOTE(review): a missing answer does not consume a retry, so this
            # loops until chat() yields something — presumably chat() handles
            # its own waiting/backoff; confirm against the implementations.
            if answer is None:
                continue

            # Attempt to parse the AI's response
            try:
                response = parse_json(answer['content'])
            except Exception as err:
                logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}")
                answer['content'] = await self.fix(answer['content'])

                # Retry parsing the fixed content
                try:
                    response = parse_json(answer['content'])
                except Exception as err:
                    logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}")
                    retries -= 1
                    continue

            if not await self.handle_picture(response):
                retries -= 1
                continue

            # Post-process the message and update the answer's content
            answer_message = await self.post_process(message, response)
            answer['content'] = str(answer_message)

            # Update message history
            self.update_history(messages[-1], answer, limit, message.historise_question)
            logging.info(f"got this answer:\n{str(answer_message)}")

            # Update memory
            if answer_message.answer is not None:
                await self.memoize(message.user, 'assistant', message.message, answer_message.answer)

            # Return the updated answer message
            return answer_message

        # Raise an error if the process failed after all retries
        raise RuntimeError("Failed to generate answer after multiple retries")
|