- Fix infinite retry loop in ai_responder.py that caused test_fix1 to hang
- Add missing picture_edit parameter to all AIResponse constructor calls
- Set up complete development toolchain with Black, isort, Bandit, and MyPy
- Create comprehensive Makefile for development workflows
- Add pre-commit hooks with formatting, linting, security, and type checking
- Update test mocking to provide contextual responses for different scenarios
- Configure all tools for 140 character line length and strict type checking
- Add DEVELOPMENT.md with setup instructions and workflow documentation

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
378 lines
15 KiB
Python
378 lines
15 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import pickle
|
|
import random
|
|
import re
|
|
import time
|
|
from functools import lru_cache, wraps
|
|
from io import BytesIO
|
|
from pathlib import Path
|
|
from pprint import pformat
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import multiline
|
|
|
|
|
|
def pp(*args, **kw):
    """Pretty-format objects via ``pprint.pformat`` with a wide default line.

    All positional and keyword arguments are forwarded to ``pformat``.  When
    the caller does not supply ``width`` it defaults to 300 columns.

    Returns:
        The formatted string produced by ``pformat``.
    """
    kw.setdefault("width", 300)
    return pformat(*args, **kw)
|
|
|
|
|
|
@lru_cache(maxsize=300)
def parse_json(content: str) -> Dict:
    """Parse ``content`` as JSON, falling back to the ``multiline`` parser.

    Surrounding whitespace is stripped before parsing.  Strict ``json.loads``
    is attempted first; if it raises, ``multiline.loads(..., multiline=True)``
    is tried and any exception from that second attempt propagates.

    Results are memoised by ``lru_cache`` (300 entries), so repeated calls
    with the same string return the very same object.

    Args:
        content: Text to parse.

    Returns:
        The decoded value.
    """
    stripped = content.strip()
    try:
        return json.loads(stripped)
    except Exception:
        # Strict JSON failed; let the lenient parser have a go and allow its
        # error (if any) to reach the caller unchanged.
        return multiline.loads(stripped, multiline=True)
|
|
|
|
|
|
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts=None):
    """Generate sleep intervals for exponential backoff with jitter.

    Args:
        base: Base of the exponentiation operation
        max_delay: Maximum delay (applied before jitter is added)
        factor: Multiplication factor for each increase in backoff
        jitter: Additional randomness range to prevent thundering herd problem,
            expressed as a fraction of the current delay
        max_attempts: Optional limit.  ``None`` (default) yields forever.
            Otherwise ``max_attempts + 1`` delays are yielded and the next
            request raises.

    Yields:
        Delay for backoff as a floating point number.

    Raises:
        RuntimeError: When ``max_attempts`` is set and has been exceeded.
    """
    attempt = 0
    # With base >= 1 and a positive factor the un-capped delay never shrinks,
    # so once the cap is hit it stays hit.  Remembering that avoids computing
    # ever larger powers on a long-lived generator (OverflowError for floats,
    # unbounded big integers for ints).
    monotonic = base >= 1 and factor > 0
    capped = False
    while True:
        if capped:
            sleep = max_delay
        else:
            try:
                raw = factor * base**attempt
            except OverflowError:
                if not monotonic:
                    raise
                # The true value is beyond float range, hence above any cap.
                raw = max_delay
            sleep = min(max_delay, raw)
            capped = monotonic and raw >= max_delay
        jitter_amount = jitter * sleep
        sleep += random.uniform(-jitter_amount, jitter_amount)
        yield sleep
        attempt += 1
        if max_attempts is not None and attempt > max_attempts:
            raise RuntimeError("Max attempts reached in exponential backoff.")
|
|
|
|
|
|
def async_cache_to_file(filename):
    """Build a decorator that memoises an async callable in a pickle file.

    The cache is loaded once, when this factory is called:

    * file missing   -> caching is disabled; the wrapped coroutine always runs;
    * file unreadable -> caching starts from an empty dict;
    * file readable  -> its unpickled content is used as the cache.

    The cache key is the JSON dump of the function name, the positional
    arguments after the first one, and the keyword arguments.  After every
    miss the whole cache is written back to ``filename``.

    Args:
        filename: Path of the pickle file backing the cache.

    Returns:
        A decorator for ``async def`` callables.
    """
    cache_path = Path(filename)
    store = None  # None means "do not cache at all"
    if cache_path.exists():
        try:
            with cache_path.open("rb") as stream:
                store = pickle.load(stream)
        except Exception:
            store = {}

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if store is None:
                return await func(*args, **kwargs)
            # args[0] is left out of the key (the first positional argument).
            lookup = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
            if lookup not in store:
                value = await func(*args, **kwargs)
                store[lookup] = value
                with cache_path.open("wb") as stream:
                    pickle.dump(store, stream)
                return value
            return store[lookup]

        return wrapper

    return decorator
|
|
|
|
|
|
def parse_maybe_json(json_string):
    """Flatten a value that may or may not be JSON into plain text.

    * ``None`` is returned unchanged.
    * A ``list`` or ``dict`` argument has its items (dict: values) joined
      with single spaces.
    * Anything else is stringified, stripped and handed to ``parse_json``:
        - a parsed string is returned as is;
        - a parsed list/dict has its items (dict: values) joined with
          newlines;
        - any other parsed value is returned via ``str``.
      When parsing fails and the text is wrapped in ``{...}`` or ``[...]``,
      the outer pair is removed and the remainder is processed again;
      otherwise the stripped text itself is returned.

    Returns:
        ``None`` or a string.
    """
    if json_string is None:
        return None
    if isinstance(json_string, dict):
        return " ".join(str(value) for value in json_string.values())
    if isinstance(json_string, list):
        return " ".join(str(value) for value in json_string)

    text = str(json_string).strip()
    try:
        parsed = parse_json(text)
    except Exception:
        in_braces = text.startswith("{") and text.endswith("}")
        in_brackets = text.startswith("[") and text.endswith("]")
        if in_braces or in_brackets:
            # Peel one layer of brackets and try again on what is inside.
            return parse_maybe_json(text[1:-1])
        return text

    if isinstance(parsed, str):
        return parsed
    if isinstance(parsed, dict):
        return "\n".join(str(value) for value in parsed.values())
    if isinstance(parsed, list):
        return "\n".join(str(value) for value in parsed)
    return str(parsed)
|
|
|
|
|
|
def same_channel(item1: Dict[str, Any], item2: Dict[str, Any]) -> bool:
    """Tell whether two history entries carry the same ``channel`` value.

    Each entry's ``"content"`` is decoded with ``parse_json`` and the
    ``"channel"`` keys of the decoded objects are compared; a missing key
    counts as ``None``.
    """
    first_channel = parse_json(item1["content"]).get("channel")
    second_channel = parse_json(item2["content"]).get("channel")
    return first_channel == second_channel
|
|
|
|
|
|
class AIMessageBase(object):
    """Base for message objects whose string form is a JSON object.

    ``vars`` holds the names of the instance attributes that take part in
    the serialisation performed by ``__str__``.
    """

    def __init__(self) -> None:
        """Start with no attribute selected for serialisation."""
        self.vars: List[str] = []

    def __str__(self) -> str:
        """Return the attributes named in ``self.vars`` as a JSON object.

        Keys appear in the order the attributes were set on the instance.
        """
        exported = {}
        for name, value in vars(self).items():
            if name in self.vars:
                exported[name] = value
        return json.dumps(exported)
|
|
|
|
|
|
class AIMessage(AIMessageBase):
    """A chat message together with its routing flags.

    The attributes named in ``vars`` are the ones ``AIMessageBase.__str__``
    serialises to JSON.  ``urls`` is not in that list, so it is not part of
    the string form.
    """

    def __init__(self, user: str, message: str, channel: str = "chat", direct: bool = False, historise_question: bool = True) -> None:
        """Store the message fields.

        Args:
            user: Name of the message author.
            message: Message text.
            channel: Channel name, ``"chat"`` by default.
            direct: Stored as given; presumably marks a message addressed
                directly to the responder — confirm with callers.
            historise_question: Stored as given; presumably controls whether
                the question itself is kept in history — confirm with callers.
        """
        # NOTE: the assignment order below is also the key order of str(self).
        self.user = user
        self.message = message
        # Not a constructor argument; starts as None and may be set afterwards.
        self.urls: Optional[List[str]] = None
        self.channel = channel
        self.direct = direct
        self.historise_question = historise_question
        self.vars = ["user", "message", "channel", "direct", "historise_question"]
|
|
|
|
|
|
class AIResponse(AIMessageBase):
    """A response record with its routing and picture fields.

    The attributes named in ``vars`` are the ones ``AIMessageBase.__str__``
    serialises to JSON.
    """

    def __init__(
        self,
        answer: Optional[str],
        answer_needed: bool,
        channel: Optional[str],
        staff: Optional[str],
        picture: Optional[str],
        picture_edit: bool,
        hack: bool,
    ) -> None:
        """Store the response fields exactly as given.

        Args:
            answer: Answer text, or ``None``.
            answer_needed: Flag stored as given.
            channel: Channel name, or ``None``.
            staff: Staff text, or ``None``.
            picture: Picture text, or ``None``.
            picture_edit: Flag stored as given.
            hack: Flag stored as given.
        """
        # NOTE: the assignment order below is also the key order of str(self).
        self.answer = answer
        self.answer_needed = answer_needed
        self.channel = channel
        self.staff = staff
        self.picture = picture
        self.picture_edit = picture_edit
        self.hack = hack
        # NOTE(review): "picture_edit" is stored above but not listed here, so
        # it is left out of str(self) — confirm whether that is intentional.
        self.vars = ["answer", "answer_needed", "channel", "staff", "picture", "hack"]
|
|
|
|
|
|
class AIResponderBase(object):
    """Holds the configuration and channel name shared by responders."""

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Remember ``config`` and ``channel``.

        Args:
            config: Settings mapping, stored by reference.
            channel: Channel name; ``None`` selects ``"system"``.
        """
        super().__init__()
        self.config = config
        if channel is None:
            channel = "system"
        self.channel = channel
|
|
|
|
|
|
class AIResponder(AIResponderBase):
    """Responder that keeps a chat history and a free-text memory.

    This class implements prompt assembly, response normalisation, history
    trimming and pickle-based persistence.  The backend hooks ``chat``,
    ``fix``, ``translate``, ``memory_rewrite``, ``draw_leonardo`` and
    ``draw_openai`` raise ``NotImplementedError`` here and have to be
    provided by a subclass.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Initialise state and load persisted history/memory if configured.

        Args:
            config: Settings mapping.  When it contains ``"history-directory"``
                the files ``<channel>.dat`` (history) and ``<channel>.memory``
                (memory) inside that directory are loaded if they exist and
                are later used for saving.
            channel: Channel name, resolved by ``AIResponderBase``.
        """
        super().__init__(config, channel)
        self.history: List[Dict[str, Any]] = []
        self.memory: str = "I am an assistant."
        self.rate_limit_backoff = exponential_backoff()
        self.history_file: Optional[Path] = None
        self.memory_file: Optional[Path] = None
        if "history-directory" in self.config:
            directory = Path(self.config["history-directory"]).expanduser()
            # SECURITY NOTE: pickle.load can execute arbitrary code; these
            # files must only ever be written by this application.
            self.history_file = directory / f"{self.channel}.dat"
            if self.history_file.exists():
                with open(self.history_file, "rb") as fd:
                    self.history = pickle.load(fd)
            self.memory_file = directory / f"{self.channel}.memory"
            if self.memory_file.exists():
                with open(self.memory_file, "rb") as fd:
                    self.memory = pickle.load(fd)
        logging.info(f"memmory:\n{self.memory}")

    def message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """Build the message list to send: system prompt, history, new message.

        The system prompt is ``config[self.channel]`` (falling back to
        ``config["system"]``) with ``{date}``, ``{time}``, ``{memory}`` and,
        when the file named by ``config["news"]`` exists, ``{news}`` filled in.

        Args:
            message: The incoming message; its ``urls`` (if any) are attached
                as ``image_url`` parts.
            limit: When given, ``self.history`` is trimmed to at most this
                many entries before being included (side effect).

        Returns:
            A new list of role/content dicts.
        """
        system = self.config.get(self.channel, self.config["system"])
        system = system.replace("{date}", time.strftime("%Y-%m-%d")).replace("{time}", time.strftime("%H:%M:%S"))
        news_feed = self.config.get("news")
        if news_feed and os.path.exists(news_feed):
            with open(news_feed) as fd:
                news_feed = fd.read().strip()
            system = system.replace("{news}", news_feed)
        system = system.replace("{memory}", self.memory)

        messages: List[Dict[str, Any]] = [{"role": "system", "content": system}]
        if limit is not None:
            while len(self.history) > limit:
                self.shrink_history_by_one()
        messages.extend(self.history)

        if not message.urls:
            messages.append({"role": "user", "content": str(message)})
        else:
            content: List[Dict[str, Union[str, Dict[str, str]]]] = [{"type": "text", "text": str(message)}]
            for url in message.urls:
                content.append({"type": "image_url", "image_url": {"url": url}})
            messages.append({"role": "user", "content": content})
        return messages

    async def draw(self, description: str) -> BytesIO:
        """Draw ``description``, using Leonardo when ``leonardo-token`` is configured, else OpenAI."""
        if self.config.get("leonardo-token") is not None:
            return await self.draw_leonardo(description)
        return await self.draw_openai(description)

    async def draw_leonardo(self, description: str) -> BytesIO:
        """Backend hook used by ``draw``; not implemented in this class."""
        raise NotImplementedError()

    async def draw_openai(self, description: str) -> BytesIO:
        """Backend hook used by ``draw``; not implemented in this class."""
        raise NotImplementedError()

    async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
        """Normalise a parsed response dict and wrap it in an ``AIResponse``.

        ``response`` is modified in place: "none"/"null"/empty spellings become
        ``None``, the boolean fields become real booleans, and markdown-style
        links in the answer are flattened.

        Args:
            message: The message being answered.
            response: Parsed response mapping (mutated).

        Returns:
            The resulting ``AIResponse``; its channel defaults to the channel
            of ``message``.
        """
        empty_spellings = ("none", "", "null", '"none"', '"null"', "'none'", "'null'")
        for fld in ("answer", "channel", "staff", "picture", "hack"):
            if str(response.get(fld)).strip().lower() in empty_spellings:
                response[fld] = None
        for fld in ("answer_needed", "hack", "picture_edit"):
            response[fld] = str(response.get(fld)).strip().lower() == "true"

        if response["answer"] is None:
            response["answer_needed"] = False
        else:
            response["answer"] = str(response["answer"])
            # "@[name](target)" -> "name"
            response["answer"] = re.sub(r"@\[([^\]]*)\]\([^\)]*\)", r"\1", response["answer"])
            # "[label](url)" -> "url"
            response["answer"] = re.sub(r"\[[^\]]*\]\(([^\)]*)\)", r"\1", response["answer"])
            if message.direct or message.user in message.message:
                response["answer_needed"] = True

        response_message = AIResponse(
            response["answer"],
            response["answer_needed"],
            parse_maybe_json(response["channel"]),
            parse_maybe_json(response["staff"]),
            parse_maybe_json(response["picture"]),
            response["picture_edit"],
            response["hack"],
        )
        if response_message.staff is not None and response_message.answer is not None:
            response_message.answer_needed = True
        if response_message.channel is None:
            response_message.channel = message.channel
        return response_message

    def _trim_and_save_history(self, limit: int) -> None:
        """Trim ``self.history`` to ``limit`` entries and pickle it if a history file is set."""
        while len(self.history) > limit:
            self.shrink_history_by_one()
        if self.history_file is not None:
            with open(self.history_file, "wb") as fd:
                pickle.dump(self.history, fd)

    def short_path(self, message: AIMessage, limit: int) -> bool:
        """Record ``message`` in history without querying the backend, if configured.

        Applies when the message is not direct and one of the
        ``(channel_regex, user_regex)`` pairs in ``config["short-path"]``
        matches the message's channel and user.

        Returns:
            ``True`` when the message was handled here, ``False`` otherwise.
        """
        if message.direct or "short-path" not in self.config:
            return False
        for chan_re, user_re in self.config["short-path"]:
            chan_ma = re.match(chan_re, message.channel)
            user_ma = re.match(user_re, message.user)
            if chan_ma and user_ma:
                self.history.append({"role": "user", "content": str(message)})
                self._trim_and_save_history(limit)
                return True
        return False

    async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Backend hook: send ``messages`` and return ``(answer or None, limit)``; not implemented here."""
        raise NotImplementedError()

    async def fix(self, answer: str) -> str:
        """Backend hook: repair answer text that failed to parse; not implemented here."""
        raise NotImplementedError()

    async def memory_rewrite(self, memory: str, message_user: str, answer_user: str, question: str, answer: str) -> str:
        """Backend hook: return the updated memory text; not implemented here."""
        raise NotImplementedError()

    async def translate(self, text: str, language: str = "english") -> str:
        """Backend hook: translate ``text`` into ``language``; not implemented here."""
        raise NotImplementedError()

    def shrink_history_by_one(self, index: int = 0) -> None:
        """Delete exactly one history entry.

        Starting at ``index``, the first entry whose channel occurs more than
        ``config["history-per-channel"]`` (default 3) times in the whole
        history is removed.  If no entry qualifies, the oldest entry goes.
        """
        per_channel = self.config.get("history-per-channel", 3)
        for pos in range(index, len(self.history)):
            current = self.history[pos]
            count = sum(1 for item in self.history if same_channel(item, current))
            if count > per_channel:
                del self.history[pos]
                return
        del self.history[0]

    def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int, historise_question: bool = True) -> None:
        """Append a question/answer pair to history, trim it and persist it.

        Args:
            question: User message dict; multi-part content is reduced in
                place to the text of its first part.
            answer: Assistant message dict.
            limit: Maximum number of history entries to keep.
            historise_question: When ``False`` only the answer is stored.
        """
        if not isinstance(question["content"], str):
            question["content"] = question["content"][0]["text"]
        if historise_question:
            self.history.append(question)
        self.history.append(answer)
        self._trim_and_save_history(limit)

    def update_memory(self, memory: str) -> None:
        """Persist ``memory`` to the memory file, if one is configured.

        Does not assign ``self.memory``; callers do that themselves.
        """
        if self.memory_file is not None:
            with open(self.memory_file, "wb") as fd:
                # Persist the value that was passed in rather than silently
                # ignoring the parameter.
                pickle.dump(memory, fd)

    async def handle_picture(self, response: Dict) -> bool:
        """Validate and translate the ``picture`` entry of ``response``.

        Returns:
            ``False`` when ``picture`` is neither ``None`` nor a string;
            ``True`` otherwise, after replacing a string value with the
            result of ``self.translate``.
        """
        if not isinstance(response.get("picture"), (type(None), str)):
            logging.warning(f"picture key is wrong in response: {pp(response)}")
            return False
        if response.get("picture") is not None:
            response["picture"] = await self.translate(response["picture"])
        return True

    async def memoize(self, message_user: str, answer_user: str, message: str, answer: str) -> None:
        """Rewrite the memory from one exchange via ``memory_rewrite`` and persist it."""
        self.memory = await self.memory_rewrite(self.memory, message_user, answer_user, message, answer)
        self.update_memory(self.memory)

    async def memoize_reaction(self, message_user: str, reaction_user: str, operation: str, reaction: str, message: str) -> None:
        """Feed a reaction event into the memory, quoting the reacted-to message."""
        quoted_message = message.replace("\n", "\n> ")
        await self.memoize(
            message_user, "assistant", f"\n> {quoted_message}", f"User {reaction_user} has {operation} this raction: {reaction}"
        )

    async def send(self, message: AIMessage) -> AIResponse:
        """Send ``message`` to the backend and return the processed response.

        Up to three attempts are made; an attempt fails when the backend
        returns no answer, the answer cannot be parsed even after ``fix``,
        or its picture field is invalid.

        Raises:
            KeyError: If ``config["history-limit"]`` is missing.
            RuntimeError: If every attempt failed.
        """
        # Get the history limit from the configuration
        limit = self.config["history-limit"]

        # Check if a short path applies, return an empty AIResponse if it does
        if self.short_path(message, limit):
            return AIResponse(None, False, None, None, None, False, False)

        # Number of retries for sending the message
        retries = 3

        while retries > 0:
            # Get the message queue
            messages = self.message(message, limit)
            logging.info(f"try to send this messages:\n{pp(messages)}")

            # Attempt to send the message to the AI; chat() also returns the
            # limit to use from here on.
            answer, limit = await self.chat(messages, limit)

            if answer is None:
                retries -= 1
                continue

            # Attempt to parse the AI's response
            try:
                response = parse_json(answer["content"])
            except Exception as err:
                logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}")
                answer["content"] = await self.fix(answer["content"])

                # Retry parsing the fixed content
                try:
                    response = parse_json(answer["content"])
                except Exception as err:
                    logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}")
                    retries -= 1
                    continue

            # parse_json is memoised, so the object it returns is shared with
            # every later call for the same text.  handle_picture() and
            # post_process() modify the dict in place; work on a copy so the
            # cached value stays pristine.
            if isinstance(response, dict):
                response = dict(response)

            if not await self.handle_picture(response):
                retries -= 1
                continue

            # Post-process the message and update the answer's content
            answer_message = await self.post_process(message, response)
            answer["content"] = str(answer_message)

            # Update message history
            self.update_history(messages[-1], answer, limit, message.historise_question)
            logging.info(f"got this answer:\n{str(answer_message)}")

            # Update memory
            if answer_message.answer is not None:
                await self.memoize(message.user, "assistant", message.message, answer_message.answer)

            # Return the updated answer message
            return answer_message

        # Raise an error if the process failed after all retries
        raise RuntimeError("Failed to generate answer after multiple retries")
|