# discord_bot/fjerkroa_bot/ai_responder.py
# (source-viewer header: 473 lines, 21 KiB, Python)

import sys
import os
import json
import asyncio
import random
import multiline
import openai
import aiohttp
import logging
import time
import re
import pickle
from pathlib import Path
from io import BytesIO
from pprint import pformat
from functools import lru_cache, wraps
from typing import Optional, List, Dict, Any, Tuple
def pp(*args, **kw):
    """Pretty-format the arguments with ``pprint.pformat``.

    Uses a default line width of 300 columns; an explicit ``width`` keyword
    argument takes precedence.
    """
    kw.setdefault('width', 300)
    return pformat(*args, **kw)
@lru_cache(maxsize=300)
def parse_json(content: str) -> Dict:
    """Parse *content* as JSON, falling back to the lenient ``multiline`` parser.

    Leading/trailing whitespace is ignored.  If strict JSON parsing fails the
    text is handed to ``multiline.loads``; if that fails too, its exception
    propagates to the caller.

    NOTE: results are memoised by ``lru_cache``, so equal inputs return the
    *same* object.  Callers must copy the result before mutating it.
    """
    text = content.strip()
    try:
        return json.loads(text)
    except Exception:
        pass
    return multiline.loads(text, multiline=True)
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts=None):
    """Yield exponentially growing sleep intervals with random jitter.

    The n-th value (counting from 0) is ``min(max_delay, factor * base ** n)``,
    shifted by a uniformly random amount of at most ``jitter`` times that
    value in either direction (spreads out retries of concurrent clients).

    Args:
        base: base of the exponentiation.
        max_delay: cap applied to the delay before the jitter is added.
        factor: multiplier applied to ``base ** n``.
        jitter: relative size of the random shift (0.1 means +/-10%).
        max_attempts: when not None, ``max_attempts + 1`` delays are yielded
            and the following request raises; None means unlimited.

    Yields:
        The next delay as a float.

    Raises:
        RuntimeError: once more than ``max_attempts`` retries were requested.
    """
    attempt = 0
    while True:
        capped = min(max_delay, factor * base ** attempt)
        deviation = jitter * capped
        yield capped + random.uniform(-deviation, deviation)
        attempt += 1
        exhausted = max_attempts is not None and attempt > max_attempts
        if exhausted:
            raise RuntimeError("Max attempts reached in exponential backoff.")
# Cache dictionaries keyed by the resolved cache-file path.  Every function
# decorated with the same file shares one dictionary, so that a write by one
# function does not drop the entries stored by another (openai_chat and
# openai_image both use 'openai_chat.dat').
_file_caches: Dict[Path, Optional[Dict[str, Any]]] = {}


def async_cache_to_file(filename):
    """Decorate an async function with a pickle-backed result cache.

    Caching is enabled only if *filename* exists when the decorator is first
    applied for that file; otherwise every call is simply forwarded.  A file
    that exists but cannot be unpickled starts an empty cache (logged).

    The cache key is the function name, all positional arguments except the
    first (e.g. an API client object) and the keyword arguments, serialised
    with ``json.dumps`` — so those arguments must be JSON-serialisable.  After
    every cache miss the whole shared cache is written back to *filename*.

    Each call writes a one-line ``@@@ forward|cache|execute`` trace to stderr.

    NOTE: ``pickle.load`` can execute code from the file; only use trusted,
    locally produced cache files.
    """
    cache_file = Path(filename)
    registry_key = cache_file.resolve()
    if registry_key not in _file_caches:
        loaded = None
        if cache_file.exists():
            try:
                with cache_file.open('rb') as fd:
                    loaded = pickle.load(fd)
            except Exception as err:
                logging.warning(f"cannot load cache file {cache_file}, starting with an empty cache: {repr(err)}")
                loaded = {}
        _file_caches[registry_key] = loaded
    cache = _file_caches[registry_key]

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            call = f'{func.__name__}({repr(args)}, {repr(kwargs)})'
            if cache is None:
                sys.stderr.write(f'@@@ forward {call}\n')
                return await func(*args, **kwargs)
            key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
            if key in cache:
                sys.stderr.write(f'@@@ cache {call} -> {cache[key]}\n')
                return cache[key]
            sys.stderr.write(f'@@@ execute {call}\n')
            result = await func(*args, **kwargs)
            cache[key] = result
            with cache_file.open('wb') as fd:
                pickle.dump(cache, fd)
            return result
        return wrapper
    return decorator
@async_cache_to_file('openai_chat.dat')
async def openai_chat(client, *args, **kwargs):
    """Create a chat completion through ``client.chat.completions.create``.

    All arguments after *client* are forwarded unchanged; the call goes
    through the ``openai_chat.dat`` file cache decorator.
    """
    completions = client.chat.completions
    return await completions.create(*args, **kwargs)
@async_cache_to_file('openai_chat.dat')
async def openai_image(client, *args, **kwargs):
    """Generate images through ``client.images.generate``.

    All arguments after *client* are forwarded unchanged; the call goes
    through the same ``openai_chat.dat`` file cache decorator as
    ``openai_chat``.
    """
    images = client.images
    return await images.generate(*args, **kwargs)
def parse_maybe_json(json_string):
    """Flatten a value that may or may not be JSON into plain text.

    - ``None`` is returned as ``None``.
    - A list or dict is rendered as the ``str`` of its items (dict: values)
      joined by single spaces.
    - Anything else is stringified, stripped and parsed with ``parse_json``:
      a JSON string is returned as-is, a JSON array/object has its items
      (object: values) joined by newlines, any other JSON value is
      stringified.
    - Text that fails to parse but is wrapped in ``{...}`` or ``[...]`` has
      the outer brackets removed and is processed again; other unparseable
      text is returned unchanged.

    NOTE(review): containers are joined with spaces when passed directly but
    with newlines when they arrive as JSON text — confirm this is intended.
    """
    if json_string is None:
        return None
    if isinstance(json_string, (list, dict)):
        items = json_string.values() if isinstance(json_string, dict) else json_string
        return ' '.join(str(item) for item in items)
    text = str(json_string).strip()
    try:
        parsed = parse_json(text)
    except Exception:
        if (text[:1], text[-1:]) in (('{', '}'), ('[', ']')):
            return parse_maybe_json(text[1:-1])
        return text
    if isinstance(parsed, str):
        return parsed
    if isinstance(parsed, dict):
        return '\n'.join(str(item) for item in parsed.values())
    if isinstance(parsed, list):
        return '\n'.join(str(item) for item in parsed)
    return str(parsed)
def same_channel(item1: Dict[str, Any], item2: Dict[str, Any]) -> bool:
    """Return True when two history messages belong to the same channel.

    The ``content`` of each message is parsed as JSON and the ``channel``
    entries are compared; a missing entry counts as ``None``.
    """
    first_channel = parse_json(item1['content']).get('channel')
    second_channel = parse_json(item2['content']).get('channel')
    return first_channel == second_channel
class AIMessageBase:
    """Base class for the message objects exchanged with the model.

    ``str()`` renders all instance attributes as a JSON object, in the order
    in which they were assigned.
    """

    def __init__(self) -> None:
        """Create an instance without attributes."""

    def __str__(self) -> str:
        attributes = vars(self)
        return json.dumps(attributes)
class AIMessage(AIMessageBase):
    """A message to be answered by the model.

    Attributes (serialised by ``str()`` in this order):
        user: name of the message author.
        message: the message text.
        channel: channel name (default ``"chat"``).
        direct: when true, AIResponder skips its short path and marks an
            answer as needed.
        historise_question: whether AIResponder stores the question in its
            history.
    """

    def __init__(self, user: str, message: str, channel: str = "chat", direct: bool = False, historise_question: bool = True) -> None:
        # Tuple assignment runs left to right, keeping the attribute order
        # that str() relies on.
        self.user, self.message, self.channel = user, message, channel
        self.direct, self.historise_question = direct, historise_question
class AIResponse(AIMessageBase):
    """The post-processed answer of the model.

    Attributes (serialised by ``str()`` in this order):
        answer: text to send, or None.
        answer_needed: whether the answer should be sent.
        channel: target channel, or None.
        staff: text addressed to staff, or None.
        picture: description of a picture to draw, or None.
        hack: flag reported by the model.
    """

    def __init__(self,
                 answer: Optional[str],
                 answer_needed: bool,
                 channel: Optional[str],
                 staff: Optional[str],
                 picture: Optional[str],
                 hack: bool
                 ) -> None:
        # Assignment order matters: str() serialises vars() in this order.
        self.answer, self.answer_needed = answer, answer_needed
        self.channel, self.staff, self.picture = channel, staff, picture
        self.hack = hack
class AIResponder(object):
    """Chat front-end for the OpenAI API with a per-channel message history.

    Builds the prompt (system text, history, new message), asks the model for
    a JSON answer, validates/repairs it, post-processes it into an
    AIResponse and records the exchange in the history.  The history is
    optionally persisted with pickle under ``history-directory``.  Pictures
    are drawn with Leonardo AI when ``leonardo-token`` is configured,
    otherwise with DALL-E 3.

    Config keys read here: ``openai-token``, ``system`` (optionally
    overridden by a key named after the channel), ``model``, ``retry-model``,
    ``temperature``, ``max-tokens``, ``top-p``, ``presence-penalty``,
    ``frequency-penalty``, ``history-limit``, ``history-per-channel``,
    ``history-directory``, ``short-path``, ``news``, ``fix-model``,
    ``fix-description`` and ``leonardo-token``.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Create a responder for *channel* (``'system'`` when omitted).

        When ``history-directory`` is configured, the history is stored in
        ``<history-directory>/<channel>.dat`` and loaded from there if the
        file already exists.
        """
        self.config = config
        self.history: List[Dict[str, Any]] = []
        self.channel = channel if channel is not None else 'system'
        self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
        self.rate_limit_backoff = exponential_backoff()
        # Set by a rate-limit error so that the next request uses
        # 'retry-model' (if configured); cleared by the next success.
        self._rate_limited = False
        self.history_file: Optional[Path] = None
        if 'history-directory' in self.config:
            self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
            if self.history_file.exists():
                # NOTE: pickle is only safe for trusted, locally written files.
                with open(self.history_file, 'rb') as fd:
                    self.history = pickle.load(fd)

    def _save_history(self) -> None:
        """Pickle the history to ``history_file`` when persistence is enabled."""
        if self.history_file is not None:
            with open(self.history_file, 'wb') as fd:
                pickle.dump(self.history, fd)

    def _trim_history(self, limit: int) -> None:
        """Drop entries with shrink_history_by_one until at most *limit* remain."""
        while len(self.history) > limit:
            self.shrink_history_by_one()

    def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Build the chat message list for *message*.

        The list holds the system prompt, the history and *message* as the
        final user message.  In the system prompt ``{date}`` and ``{time}``
        are replaced by the current local date and time, and ``{news}`` by
        the stripped contents of the file named by ``news`` when that file
        exists.  When *limit* is given the history is shrunk to at most
        *limit* entries first.
        """
        system = self.config.get(self.channel, self.config['system'])
        system = system.replace('{date}', time.strftime('%Y-%m-%d')) \
                       .replace('{time}', time.strftime('%H:%M:%S'))
        news_feed = self.config.get('news')
        if news_feed and os.path.exists(news_feed):
            with open(news_feed) as fd:
                system = system.replace('{news}', fd.read().strip())
        if limit is not None:
            self._trim_history(limit)
        return [{"role": "system", "content": system},
                *self.history,
                {"role": "user", "content": str(message)}]

    async def draw(self, description: str) -> BytesIO:
        """Draw *description* with Leonardo AI if ``leonardo-token`` is set, else DALL-E."""
        if self.config.get('leonardo-token') is not None:
            return await self._draw_leonardo(description)
        return await self._draw_openai(description)

    async def _draw_openai(self, description: str) -> BytesIO:
        """Generate a 1024x1024 DALL-E 3 image and return the downloaded bytes.

        Up to three attempts are made; every failure is logged.

        Raises:
            RuntimeError: if all attempts fail.
        """
        for _ in range(3):
            try:
                response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
                async with aiohttp.ClientSession() as session:
                    async with session.get(response.data[0].url) as image:
                        # Never return an HTTP error page as image data.
                        image.raise_for_status()
                        logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
                        return BytesIO(await image.read())
            except Exception as err:
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")

    async def _draw_leonardo(self, description: str) -> BytesIO:
        """Generate a 512x512 image with Leonardo AI and return its bytes.

        The generation job is started, polled until an image URL is
        available, the image is downloaded and the generation is deleted on
        the server (best effort).  Progress (generation id, image URL, image
        bytes) is kept across retries so completed steps are not repeated.
        Between attempts the method sleeps with exponential backoff and gives
        up when the backoff (``max_attempts=12``) is exhausted.

        Raises:
            RuntimeError: if no image could be produced.
        """
        api = "https://cloud.leonardo.ai/api/rest/v1/generations"
        auth = {"Authorization": f"Bearer {self.config['leonardo-token']}"}
        error_backoff = exponential_backoff(max_attempts=12)
        generation_id = None
        image_url = None
        image_bytes = None
        while True:
            try:
                error_sleep = next(error_backoff)
            except RuntimeError:
                # Backoff exhausted: report the failure below.
                break
            try:
                async with aiohttp.ClientSession() as session:
                    if generation_id is None:
                        async with session.post(api,
                                                json={"prompt": description,
                                                      "modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
                                                      "num_images": 1,
                                                      "sd_version": "v2",
                                                      "promptMagic": True,
                                                      "unzoomAmount": 1,
                                                      "width": 512,
                                                      "height": 512},
                                                headers={**auth,
                                                         "Accept": "application/json",
                                                         "Content-Type": "application/json"},
                                                ) as response:
                            data = await response.json()
                        if "sdGenerationJob" not in data:
                            logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(data)}")
                            await asyncio.sleep(error_sleep)
                            continue
                        generation_id = data["sdGenerationJob"]["generationId"]
                    if image_url is None:
                        async with session.get(f"{api}/{generation_id}",
                                               headers={**auth, "Accept": "application/json"},
                                               ) as response:
                            data = await response.json()
                        if "generations_by_pk" not in data:
                            logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(data)}")
                            await asyncio.sleep(error_sleep)
                            continue
                        images = data["generations_by_pk"]["generated_images"]
                        if not images:
                            # Generation not finished yet: poll again later.
                            await asyncio.sleep(error_sleep)
                            continue
                        image_url = images[0]["url"]
                    if image_bytes is None:
                        async with session.get(image_url) as response:
                            response.raise_for_status()
                            image_bytes = BytesIO(await response.read())
                    try:
                        async with session.delete(f"{api}/{generation_id}", headers=auth) as response:
                            await response.json()
                    except Exception as err:
                        # Cleanup only; the image is already downloaded.
                        logging.warning(f"Failed to delete generation {generation_id}: {repr(err)}")
                    logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
                    return image_bytes
            except Exception as err:
                logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
                await asyncio.sleep(error_sleep)
        raise RuntimeError(f"Failed to generate image {repr(description)}")

    async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
        """Normalise the parsed model *response* (modified in place) into an AIResponse.

        - ``answer``, ``channel``, ``staff``, ``picture`` and ``hack`` values
          that are missing, empty or spell "none"/"null" (optionally quoted)
          become None; ``answer_needed`` and ``hack`` become booleans that are
          True only for the text "true" (case-insensitive).
        - In the answer, ``@[name](...)`` mentions become ``name`` and
          markdown links ``[text](url)`` become ``url``.
        - An existing answer is always needed when the message is direct,
          when its text contains ``message.user``, or when a staff note is
          present.
        - ``channel``/``staff``/``picture`` are flattened with
          parse_maybe_json; the channel defaults to the message's channel.
        """
        for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
            if str(response.get(fld)).strip().lower() in \
                    ('none', '', 'null', '"none"', '"null"', "'none'", "'null'"):
                response[fld] = None
        for fld in ('answer_needed', 'hack'):
            response[fld] = str(response.get(fld)).strip().lower() == 'true'
        if response['answer'] is None:
            response['answer_needed'] = False
        else:
            answer = str(response['answer'])
            answer = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', answer)
            answer = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', answer)
            response['answer'] = answer
            # Only force an answer when there is one to send.
            if message.direct or message.user in message.message:
                response['answer_needed'] = True
        response_message = AIResponse(response['answer'],
                                      response['answer_needed'],
                                      parse_maybe_json(response['channel']),
                                      parse_maybe_json(response['staff']),
                                      parse_maybe_json(response['picture']),
                                      response['hack'])
        if response_message.staff is not None and response_message.answer is not None:
            response_message.answer_needed = True
        if response_message.channel is None:
            response_message.channel = message.channel
        return response_message

    def short_path(self, message: AIMessage, limit: int) -> bool:
        """Record *message* in the history without querying the model, if configured.

        Applies to non-direct messages whose channel and user match (via
        ``re.match``) one of the ``(channel_regex, user_regex)`` pairs in
        ``short-path``.  The history is then trimmed to *limit* and saved.

        Returns:
            True if the message was handled this way, False otherwise.
        """
        if message.direct or 'short-path' not in self.config:
            return False
        for chan_re, user_re in self.config['short-path']:
            chan_ma = re.match(chan_re, message.channel)
            user_ma = re.match(user_re, message.user)
            if chan_ma and user_ma:
                self.history.append({"role": "user", "content": str(message)})
                self._trim_history(limit)
                self._save_history()
                return True
        return False

    async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Request one chat completion for *messages*.

        Returns ``(answer, limit)``.  *answer* is a ``{'content', 'role'}``
        dict, or None when the request failed.  On a context-length overflow
        the returned limit is one lower (while it is above 4) so the caller
        can retry with a shorter history.  After a rate-limit error the
        method sleeps with exponential backoff, and the next call uses
        ``retry-model`` (if configured) until a request succeeds.

        Raises:
            openai.BadRequestError: for other bad requests, or a context
                overflow when the limit cannot be reduced further.
        """
        model = self.config["model"]
        if self._rate_limited and "retry-model" in self.config:
            model = self.config["retry-model"]
        try:
            result = await openai_chat(self.client,
                                       model=model,
                                       messages=messages,
                                       temperature=self.config["temperature"],
                                       max_tokens=self.config["max-tokens"],
                                       top_p=self.config["top-p"],
                                       presence_penalty=self.config["presence-penalty"],
                                       frequency_penalty=self.config["frequency-penalty"])
            answer_obj = result.choices[0].message
            answer = {'content': answer_obj.content, 'role': answer_obj.role}
            self.rate_limit_backoff = exponential_backoff()
            self._rate_limited = False
            logging.info(f"generated response {result.usage}: {repr(answer)}")
            return answer, limit
        except openai.BadRequestError as err:
            if 'maximum context length is' in str(err) and limit > 4:
                logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
                return None, limit - 1
            raise
        except openai.RateLimitError as err:
            rate_limit_sleep = next(self.rate_limit_backoff)
            self._rate_limited = True
            logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
            await asyncio.sleep(rate_limit_sleep)
        except Exception as err:
            logging.warning(f"failed to generate response: {repr(err)}")
        return None, limit

    async def fix(self, answer: str) -> str:
        """Ask ``fix-model`` (instructed by ``fix-description``) to repair *answer*.

        Returns the reply's text from the first ``{`` to the last ``}``.
        *answer* is returned unchanged when ``fix-model`` is not configured,
        the request fails, or the reply has no braces, the closing brace comes
        first, or fewer than three characters lie between them.
        """
        if 'fix-model' not in self.config:
            return answer
        messages = [{"role": "system", "content": self.config["fix-description"]},
                    {"role": "user", "content": answer}]
        try:
            result = await openai_chat(self.client,
                                       model=self.config["fix-model"],
                                       messages=messages,
                                       temperature=0.2,
                                       max_tokens=2048)
            logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
            response = result.choices[0].message.content
            start, end = response.find("{"), response.rfind("}")
            if start == -1 or end == -1 or (start + 3) >= end:
                return answer
            response = response[start:end + 1]
            logging.info(f"fixed answer:\n{pp(response)}")
            return response
        except Exception as err:
            logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
            return answer

    async def translate(self, text: str, language: str = "english") -> str:
        """Translate *text* to *language* with ``fix-model``.

        Returns *text* unchanged when ``fix-model`` is not configured or the
        request fails.
        """
        if 'fix-model' not in self.config:
            return text
        message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
                                                 f" you translate everything you get directly to {language}"
                                                 f" if it is not already in {language}, otherwise you just copy it."},
                   {"role": "user", "content": text}]
        try:
            result = await openai_chat(self.client,
                                       model=self.config["fix-model"],
                                       messages=message,
                                       temperature=0.2,
                                       max_tokens=2048)
            response = result.choices[0].message.content
            logging.info(f"got this translated message:\n{pp(response)}")
            return response
        except Exception as err:
            logging.warning(f"failed to translate the text: {repr(err)}")
            return text

    def shrink_history_by_one(self, index: int = 0) -> None:
        """Remove one entry from the history.

        Removes the first entry at or after *index* whose channel has more
        than ``history-per-channel`` (default 3) entries in the history; if
        there is none, the oldest entry (index 0) is removed.
        """
        per_channel = self.config.get('history-per-channel', 3)
        # Iterative form of the former recursion; avoids RecursionError on
        # long (e.g. loaded) histories.
        for position in range(index, len(self.history)):
            current = self.history[position]
            count = sum(1 for item in self.history if same_channel(item, current))
            if count > per_channel:
                del self.history[position]
                return
        del self.history[0]

    def update_history(self,
                       question: Dict[str, Any],
                       answer: Dict[str, Any],
                       limit: int,
                       historise_question: bool = True) -> None:
        """Append the exchange to the history, trim it to *limit* and save it.

        The question is only recorded when *historise_question* is true.
        """
        if historise_question:
            self.history.append(question)
        self.history.append(answer)
        self._trim_history(limit)
        self._save_history()

    async def _parse_response(self, answer: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Parse ``answer['content']`` into a JSON object the caller may modify.

        Unparseable content is passed through fix() once and
        ``answer['content']`` is replaced by the repaired text.  Returns None
        when no JSON object (dict) could be obtained.
        """
        try:
            parsed = parse_json(answer['content'])
        except Exception as err:
            logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}")
            answer['content'] = await self.fix(answer['content'])
            try:
                parsed = parse_json(answer['content'])
            except Exception as fix_err:
                logging.error(f"failed to parse the fixed answer: {pp(fix_err)}\n{repr(answer['content'])}")
                return None
        if not isinstance(parsed, dict):
            logging.warning(f"answer is not a JSON object: {repr(answer['content'])}")
            return None
        # parse_json is lru_cache'd and returns the same object for equal
        # text; post_process() mutates the dict, so hand out a copy.
        return dict(parsed)

    async def send(self, message: AIMessage) -> AIResponse:
        """Get the model's answer to *message*, post-processed into an AIResponse.

        A message taking the short path is only recorded, and an empty
        AIResponse (no answer, not needed) is returned.  Otherwise the model
        is queried with at most ``history-limit`` history entries.  A context
        overflow is retried immediately with one entry fewer.  Any other
        failure (rate limit, API/network error, unparseable answer, malformed
        ``picture``) uses up one of three retries.  A ``picture`` description
        is passed through translate() before post-processing, and the
        exchange is added to the history.

        Raises:
            RuntimeError: if no usable answer was obtained within the retries.
        """
        limit = self.config["history-limit"]
        if self.short_path(message, limit):
            return AIResponse(None, False, None, None, None, False)
        retries = 3
        while retries > 0:
            messages = self._message(message, limit)
            logging.info(f"try to send this messages:\n{pp(messages)}")
            answer, new_limit = await self._acreate(messages, limit)
            if answer is None:
                # A lowered limit means the context was too long: retry right
                # away with less history.  Every other failure must use up a
                # retry, otherwise persistent API errors would loop forever.
                if new_limit == limit:
                    retries -= 1
                limit = new_limit
                continue
            response = await self._parse_response(answer)
            if response is None:
                retries -= 1
                continue
            if not isinstance(response.get("picture"), (type(None), str)):
                logging.warning(f"picture key is wrong in response: {pp(response)}")
                retries -= 1
                continue
            if response.get("picture") is not None:
                response["picture"] = await self.translate(response["picture"])
            answer_message = await self.post_process(message, response)
            answer['content'] = str(answer_message)
            self.update_history(messages[-1], answer, limit, message.historise_question)
            logging.info(f"got this answer:\n{str(answer_message)}")
            return answer_message
        raise RuntimeError("Failed to generate answer after multiple retries")