diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 7e46eeb..74b5c81 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -4,6 +4,8 @@ import aiohttp import logging import time import re +import pickle +from pathlib import Path from io import BytesIO from pprint import pformat from typing import Optional, List, Dict, Any, Tuple @@ -39,6 +41,12 @@ class AIResponder(object): self.history: List[Dict[str, Any]] = [] self.channel = channel if channel is not None else 'system' openai.api_key = self.config['openai-token'] + self.history_file: Optional[Path] = None + if 'history-directory' in self.config: + self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat' + if self.history_file.exists(): + with open(self.history_file, 'rb') as fd: + self.history = pickle.load(fd) def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]: messages = [] @@ -133,6 +141,15 @@ class AIResponder(object): logging.warning(f"failed to execute a fix for the answer: {repr(err)}") return answer + def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None: + self.history.append(question) + self.history.append(answer) + if len(self.history) > limit: + self.history = self.history[-limit:] + if self.history_file is not None: + with open(self.history_file, 'wb') as fd: + pickle.dump(self.history, fd) + async def send(self, message: AIMessage) -> AIResponse: limit = self.config["history-limit"] if self.short_path(message, limit): @@ -156,9 +173,6 @@ class AIResponder(object): if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str): continue logging.info(f"got this answer:\n{pformat(response)}") - self.history.append(messages[-1]) - self.history.append(answer) - if len(self.history) > limit: - self.history = self.history[-limit:] + self.update_history(messages[-1], answer, limit) return await self.post_process(response) raise RuntimeError("Failed to generate answer after multiple retries") diff --git a/tests/test_main.py b/tests/test_main.py index 62fe0f8..9057ce2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -111,12 +111,8 @@ class TestFunctionality(TestBotBase): async def test_on_message_event3(self) -> None: async def acreate(*a, **kw): return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]} - - def logging_warning(msg): - raise RuntimeError(msg) message = self.create_message("Hello there! How are you?") - with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ - patch.object(logging, 'warning', logging_warning): + with patch.object(openai.ChatCompletion, 'acreate', new=acreate): await self.bot.on_message(message) self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)