428 lines
20 KiB
Python
428 lines
20 KiB
Python
import os
|
|
import json
|
|
import asyncio
|
|
import random
|
|
import multiline
|
|
import openai
|
|
import aiohttp
|
|
import logging
|
|
import time
|
|
import re
|
|
import pickle
|
|
from pathlib import Path
|
|
from io import BytesIO
|
|
from pprint import pformat
|
|
from functools import lru_cache
|
|
from typing import Optional, List, Dict, Any, Tuple
|
|
|
|
|
|
def pp(*args, **kw):
    """Pretty-format the arguments via pprint.pformat, defaulting to a wide layout (width=300)."""
    kw.setdefault('width', 300)
    return pformat(*args, **kw)
|
|
|
|
|
|
@lru_cache(maxsize=300)
def parse_json(content: str) -> Dict:
    """Parse *content* as JSON, falling back to the lenient `multiline` parser.

    Results are memoized via lru_cache, so callers MUST NOT mutate the
    returned container — it is shared between calls with equal content.

    Raises:
        Whatever the fallback parser raises when both parsers fail.
    """
    content = content.strip()
    try:
        return json.loads(content)
    except Exception:
        # Strict JSON failed (e.g. trailing commas, unquoted keys); let the
        # multiline parser try and propagate its error directly — the old
        # `except ... as err: raise err` wrapper added nothing.
        return multiline.loads(content, multiline=True)
|
|
|
|
|
|
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts=None):
    """Yield successive backoff delays with jitter.

    Each delay is ``factor * base ** attempt`` capped at *max_delay*, plus a
    uniform random offset in ``+/- jitter * delay``.

    Args:
        base: Base of the exponentiation operation.
        max_delay: Upper bound on the raw delay.
        factor: Linear multiplier applied to the exponential term.
        jitter: Relative randomness range (thundering-herd avoidance).
        max_attempts: When given, raise RuntimeError after this many
            increments have been exhausted.

    Yields:
        Delay for backoff as a floating point number.

    Raises:
        RuntimeError: when *max_attempts* is set and exceeded.
    """
    attempt = 0
    while max_attempts is None or attempt <= max_attempts:
        delay = min(max_delay, factor * base ** attempt)
        spread = jitter * delay
        yield delay + random.uniform(-spread, spread)
        attempt += 1
    raise RuntimeError("Max attempts reached in exponential backoff.")
|
|
|
|
|
|
def parse_maybe_json(json_string):
    """Best-effort flattening of a possibly-JSON value into plain text.

    Containers passed in directly are space-joined; containers recovered by
    parsing are newline-joined.  A bracketed string that fails to parse is
    retried with its outer brackets stripped; anything else is returned as-is.
    """
    if json_string is None:
        return None
    if isinstance(json_string, dict):
        return ' '.join(str(value) for value in json_string.values())
    if isinstance(json_string, list):
        return ' '.join(str(value) for value in json_string)
    text = str(json_string).strip()
    try:
        parsed = parse_json(text)
    except Exception:
        # Not parseable: peel one layer of {}/[] and retry, else give up.
        for opener, closer in (('{', '}'), ('[', ']')):
            if text.startswith(opener) and text.endswith(closer):
                return parse_maybe_json(text[1:-1])
        return text
    if isinstance(parsed, str):
        return parsed
    if isinstance(parsed, dict):
        return '\n'.join(str(value) for value in parsed.values())
    if isinstance(parsed, list):
        return '\n'.join(str(value) for value in parsed)
    return str(parsed)
|
|
|
|
|
|
def same_channel(item1: Dict[str, Any], item2: Dict[str, Any]) -> bool:
    """Return True when both history items carry the same 'channel' in their JSON content."""
    channel1 = parse_json(item1['content']).get('channel')
    channel2 = parse_json(item2['content']).get('channel')
    return channel1 == channel2
|
|
|
|
|
|
class AIMessageBase(object):
    """Base class for AI chat payloads; renders instance attributes as a JSON object."""

    def __init__(self) -> None:
        pass

    def __str__(self) -> str:
        # Serialize whatever attributes the subclass assigned; key order
        # follows attribute assignment order (dict insertion semantics).
        return json.dumps(self.__dict__)
|
|
|
|
|
|
class AIMessage(AIMessageBase):
    """An incoming user message to be routed to the AI.

    Attribute assignment order must stay stable: the base ``__str__``
    serializes ``vars(self)`` in insertion order.
    """

    def __init__(self, user: str, message: str, channel: str = "chat", direct: bool = False, historise_question: bool = True) -> None:
        self.user = user                              # author of the message
        self.message = message                        # raw message text
        self.channel = channel                        # originating channel name
        self.direct = direct                          # True for a direct (private) message
        self.historise_question = historise_question  # whether to keep the question in history
|
|
|
|
|
|
class AIResponse(AIMessageBase):
    """Structured reply produced by the responder.

    Attribute assignment order must stay stable: the base ``__str__``
    serializes ``vars(self)`` in insertion order.
    """

    def __init__(self,
                 answer: Optional[str],
                 answer_needed: bool,
                 channel: Optional[str],
                 staff: Optional[str],
                 picture: Optional[str],
                 hack: bool) -> None:
        self.answer = answer                # reply text, or None for silence
        self.answer_needed = answer_needed  # whether the reply should actually be sent
        self.channel = channel              # destination channel
        self.staff = staff                  # staff member to notify, if any
        self.picture = picture              # picture description to draw, if any
        self.hack = hack                    # message was flagged as a manipulation attempt
|
|
|
|
|
|
class AIResponder(object):
|
|
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
    """Create a responder bound to one channel.

    Args:
        config: Application configuration; must contain 'openai-token' and
            may contain 'history-directory' for persisted chat history.
        channel: Channel name; defaults to 'system'.
    """
    self.config = config
    self.history: List[Dict[str, Any]] = []
    self.channel = channel if channel is not None else 'system'
    self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
    # Fresh backoff generator; reset again after every successful completion.
    self.rate_limit_backoff = exponential_backoff()
    self.history_file: Optional[Path] = None
    if 'history-directory' in self.config:
        # One history file per channel under the configured directory.
        self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
        if self.history_file.exists():
            # NOTE(review): pickle.load on a disk file is only safe if the
            # history directory is trusted — confirm it is not user-writable.
            with open(self.history_file, 'rb') as fd:
                self.history = pickle.load(fd)
|
|
|
|
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
|
messages = []
|
|
system = self.config.get(self.channel, self.config['system'])
|
|
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
|
.replace('{time}', time.strftime('%H:%M:%S'))
|
|
news_feed = self.config.get('news')
|
|
if news_feed and os.path.exists(news_feed):
|
|
with open(news_feed) as fd:
|
|
news_feed = fd.read().strip()
|
|
system = system.replace('{news}', news_feed)
|
|
messages.append({"role": "system", "content": system})
|
|
if limit is not None:
|
|
while len(self.history) > limit:
|
|
self.shrink_history_by_one()
|
|
for msg in self.history:
|
|
messages.append(msg)
|
|
messages.append({"role": "user", "content": str(message)})
|
|
return messages
|
|
|
|
async def draw(self, description: str) -> BytesIO:
    """Render *description* as an image, preferring Leonardo when a token is configured."""
    use_leonardo = self.config.get('leonardo-token') is not None
    if use_leonardo:
        return await self._draw_leonardo(description)
    return await self._draw_openai(description)
|
|
|
|
async def _draw_openai(self, description: str) -> BytesIO:
    """Generate an image with DALL-E 3 and download its bytes, retrying up to three times.

    Raises:
        RuntimeError: when every attempt fails.
    """
    for attempt in range(3):
        try:
            result = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3")
            # The API returns a URL; fetch the actual image bytes ourselves.
            async with aiohttp.ClientSession() as session:
                async with session.get(result.data[0].url) as image:
                    logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
                    return BytesIO(await image.read())
        except Exception as err:
            logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
    raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
|
|
|
async def _draw_leonardo(self, description: str) -> BytesIO:
    """Generate an image via the Leonardo AI REST API.

    Drives a three-phase state machine inside a retry loop: (1) submit a
    generation job, (2) poll until an image URL appears, (3) download the
    bytes and delete the generation server-side.  Each phase's result is
    kept in a local so a transient failure resumes where it left off.

    Raises:
        RuntimeError: from exponential_backoff (max_attempts=12) when the
            API keeps failing; the trailing raise below is never reached.
    """
    # Backoff raises RuntimeError after 12 increments, bounding the loop.
    error_backoff = exponential_backoff(max_attempts=12)
    generation_id = None   # phase 1 result: job id
    image_url = None       # phase 2 result: URL of the rendered image
    image_bytes = None     # phase 3 result: downloaded payload
    while True:
        error_sleep = next(error_backoff)
        try:
            async with aiohttp.ClientSession() as session:
                if generation_id is None:
                    # Phase 1: submit the generation job.
                    async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
                                            json={"prompt": description,
                                                  "modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
                                                  "num_images": 1,
                                                  "sd_version": "v2",
                                                  "promptMagic": True,
                                                  "unzoomAmount": 1,
                                                  "width": 512,
                                                  "height": 512},
                                            headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
                                                     "Accept": "application/json",
                                                     "Content-Type": "application/json"},
                                            ) as response:
                        response = await response.json()
                        if "sdGenerationJob" not in response:
                            logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
                            await asyncio.sleep(error_sleep)
                            continue
                        generation_id = response["sdGenerationJob"]["generationId"]
                if image_url is None:
                    # Phase 2: poll the job until a generated image shows up.
                    async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
                                           headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
                                                    "Accept": "application/json"},
                                           ) as response:
                        response = await response.json()
                        if "generations_by_pk" not in response:
                            logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
                            await asyncio.sleep(error_sleep)
                            continue
                        if len(response["generations_by_pk"]["generated_images"]) == 0:
                            # Job accepted but not finished yet; wait and poll again.
                            await asyncio.sleep(error_sleep)
                            continue
                        image_url = response["generations_by_pk"]["generated_images"][0]["url"]
                if image_bytes is None:
                    # Phase 3: download the rendered image.
                    async with session.get(image_url) as response:
                        image_bytes = BytesIO(await response.read())
                # Best-effort server-side cleanup of the generation.
                async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
                                          headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
                                          ) as response:
                    await response.json()
                logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
                return image_bytes
        except Exception as err:
            logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
        else:
            # NOTE(review): this branch looks unreachable — every successful
            # pass through the try either returns or continues. Confirm
            # before relying on it.
            logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
        await asyncio.sleep(error_sleep)
    # NOTE(review): unreachable — the `while True` only exits via return or
    # the RuntimeError raised by the backoff generator.
    raise RuntimeError(f"Failed to generate image {repr(description)}")
|
|
|
|
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
    """Normalize the model's raw JSON reply into an AIResponse.

    None-like strings become None, booleans are coerced from the literal
    'true', markdown mentions/links in the answer are flattened to plain
    text, and answer_needed is forced on for direct messages, mentions of
    the asking user, or staff escalations.  A missing channel falls back
    to the channel the question came from.
    """
    none_like = ('none', '', 'null', '"none"', '"null"', "'none'", "'null'")
    for key in ('answer', 'channel', 'staff', 'picture', 'hack'):
        if str(response.get(key)).strip().lower() in none_like:
            response[key] = None
    for key in ('answer_needed', 'hack'):
        response[key] = str(response.get(key)).strip().lower() == 'true'
    if response['answer'] is None:
        response['answer_needed'] = False
    else:
        text = str(response['answer'])
        # Flatten @[name](id) mentions to the name and [label](url) links to the url.
        text = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', text)
        text = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', text)
        response['answer'] = text
        if message.direct or message.user in message.message:
            response['answer_needed'] = True
    result = AIResponse(response['answer'],
                        response['answer_needed'],
                        parse_maybe_json(response['channel']),
                        parse_maybe_json(response['staff']),
                        parse_maybe_json(response['picture']),
                        response['hack'])
    if result.staff is not None and result.answer is not None:
        result.answer_needed = True
    if result.channel is None:
        result.channel = message.channel
    return result
|
|
|
|
def short_path(self, message: AIMessage, limit: int) -> bool:
|
|
if message.direct or 'short-path' not in self.config:
|
|
return False
|
|
for chan_re, user_re in self.config['short-path']:
|
|
chan_ma = re.match(chan_re, message.channel)
|
|
user_ma = re.match(user_re, message.user)
|
|
if chan_ma and user_ma:
|
|
self.history.append({"role": "user", "content": str(message)})
|
|
while len(self.history) > limit:
|
|
self.shrink_history_by_one()
|
|
if self.history_file is not None:
|
|
with open(self.history_file, 'wb') as fd:
|
|
pickle.dump(self.history, fd)
|
|
return True
|
|
return False
|
|
|
|
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[openai.types.chat.ChatCompletionMessage], int]:
    """Issue one chat-completion request.

    Returns:
        (answer, limit): the completion message and the (possibly reduced)
        history limit; (None, limit) on any recoverable failure, letting
        the caller retry.

    Raises:
        openai.BadRequestError: re-raised when not a context-length issue
            or when the limit cannot be reduced further.
    """
    model = self.config["model"]
    try:
        result = await self.client.chat.completions.create(model=model,
                                                           messages=messages,
                                                           temperature=self.config["temperature"],
                                                           max_tokens=self.config["max-tokens"],
                                                           top_p=self.config["top-p"],
                                                           presence_penalty=self.config["presence-penalty"],
                                                           frequency_penalty=self.config["frequency-penalty"])
        answer = result.choices[0].message
        # Success: reset the rate-limit backoff for the next incident.
        self.rate_limit_backoff = exponential_backoff()
        logging.info(f"generated response {result.usage}: {repr(answer)}")
        return answer, limit
    except openai.BadRequestError as err:
        if 'maximum context length is' in str(err) and limit > 4:
            # Context overflow: shrink the history limit and let the caller retry.
            logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
            limit -= 1
            return None, limit
        raise err
    except openai.RateLimitError as err:
        rate_limit_sleep = next(self.rate_limit_backoff)
        if "retry-model" in self.config:
            # NOTE(review): this assignment has no effect — the function
            # returns (None, limit) after sleeping and the caller rebuilds
            # `model` from config on the next call, so 'retry-model' is
            # effectively dead. Confirm intent before removing or wiring it up.
            model = self.config["retry-model"]
        logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
        await asyncio.sleep(rate_limit_sleep)
    except Exception as err:
        logging.warning(f"failed to generate response: {repr(err)}")
    return None, limit
|
|
|
|
async def fix(self, answer: str) -> str:
    """Ask the configured fix model to repair a malformed JSON answer.

    Returns the substring between the outermost braces of the model's
    reply, or the original *answer* when no fix model is configured, no
    plausible JSON object is found, or the call fails.
    """
    if 'fix-model' not in self.config:
        return answer
    conversation = [{"role": "system", "content": self.config["fix-description"]},
                    {"role": "user", "content": answer}]
    try:
        result = await self.client.chat.completions.create(model=self.config["fix-model"],
                                                           messages=conversation,
                                                           temperature=0.2,
                                                           max_tokens=2048)
        candidate = result.choices[0].message.content
        logging.info(f"got this message as fix:\n{pp(candidate)}")
        first, last = candidate.find("{"), candidate.rfind("}")
        # Require a brace pair enclosing at least a couple of characters.
        if first == -1 or last == -1 or (first + 3) >= last:
            return answer
        fixed = candidate[first:last + 1]
        logging.info(f"fixed answer:\n{pp(fixed)}")
        return fixed
    except Exception as err:
        logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
        return answer
|
|
|
|
async def translate(self, text: str, language: str = "english") -> str:
    """Translate *text* to *language* via the fix model.

    Returns *text* unchanged when no fix model is configured or the
    request fails.
    """
    if 'fix-model' not in self.config:
        return text
    system_prompt = (f"You are an professional translator to {language} language,"
                     f" you translate everything you get directly to {language}"
                     f" if it is not already in {language}, otherwise you just copy it.")
    conversation = [{"role": "system", "content": system_prompt},
                    {"role": "user", "content": text}]
    try:
        result = await self.client.chat.completions.create(model=self.config["fix-model"],
                                                           messages=conversation,
                                                           temperature=0.2,
                                                           max_tokens=2048)
        translated = result.choices[0].message.content
        logging.info(f"got this translated message:\n{pp(translated)}")
        return translated
    except Exception as err:
        logging.warning(f"failed to translate the text: {repr(err)}")
        return text
|
|
|
|
def shrink_history_by_one(self, index: int = 0) -> None:
|
|
if index >= len(self.history):
|
|
del self.history[0]
|
|
else:
|
|
current = self.history[index]
|
|
count = sum(1 for item in self.history if same_channel(item, current))
|
|
if count > self.config.get('history-per-channel', 3):
|
|
del self.history[index]
|
|
else:
|
|
self.shrink_history_by_one(index + 1)
|
|
|
|
def update_history(self,
|
|
question: Dict[str, Any],
|
|
answer: Dict[str, Any],
|
|
limit: int,
|
|
historise_question: bool = True) -> None:
|
|
if historise_question:
|
|
self.history.append(question)
|
|
self.history.append(answer)
|
|
while len(self.history) > limit:
|
|
self.shrink_history_by_one()
|
|
if self.history_file is not None:
|
|
with open(self.history_file, 'wb') as fd:
|
|
pickle.dump(self.history, fd)
|
|
|
|
async def send(self, message: AIMessage) -> AIResponse:
    """Run one full question/answer round-trip with the model.

    Builds the prompt, calls the chat API, parses (repairing with the fix
    model when necessary) the JSON answer, post-processes it, and records
    the exchange in history.

    Args:
        message: The incoming user message.

    Returns:
        The post-processed AIResponse, or an empty one when the short
        path swallowed the message.

    Raises:
        RuntimeError: when no parseable answer was produced after all retries.
    """
    limit = self.config["history-limit"]

    # Short path: record the message without querying the model.
    if self.short_path(message, limit):
        return AIResponse(None, False, None, None, None, False)

    retries = 3
    while retries > 0:
        messages = self._message(message, limit)
        logging.info(f"try to send this messages:\n{pp(messages)}")

        # _acreate returns (None, limit) on transient failures (and a
        # reduced limit on context overflow); retry without consuming a
        # retry in that case — _acreate itself sleeps on rate limits.
        answer, limit = await self._acreate(messages, limit)
        if answer is None:
            continue

        try:
            response = parse_json(answer.content)
        except Exception as err:
            logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer.content)}")
            # BUG FIX: the completion message object is not subscriptable;
            # read the attribute instead of answer['content'].
            answer.content = await self.fix(answer.content)

            try:
                response = parse_json(answer.content)
            except Exception as err:
                logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer.content)}")
                retries -= 1
                continue

        # BUG FIX: parse_json is lru_cached — shallow-copy before the
        # mutations below so we don't corrupt the shared cache entry.
        if isinstance(response, dict):
            response = dict(response)

        # The picture field must be absent or a plain string.
        if not isinstance(response.get("picture"), (type(None), str)):
            logging.warning(f"picture key is wrong in response: {pp(response)}")
            retries -= 1
            continue

        if response.get("picture") is not None:
            response["picture"] = await self.translate(response["picture"])

        # Post-process the model reply and store its normalized form.
        answer_message = await self.post_process(message, response)
        answer.content = str(answer_message)

        self.update_history(messages[-1], answer, limit, message.historise_question)
        logging.info(f"got this answer:\n{str(answer_message)}")

        return answer_message

    raise RuntimeError("Failed to generate answer after multiple retries")
|