Support saving history.

This commit is contained in:
OK 2023-03-24 18:27:30 +01:00
parent 6d2a3d6ac5
commit b82477fc83
2 changed files with 19 additions and 9 deletions

View File

@ -4,6 +4,8 @@ import aiohttp
import logging
import time
import re
import pickle
from pathlib import Path
from io import BytesIO
from pprint import pformat
from typing import Optional, List, Dict, Any, Tuple
@ -39,6 +41,12 @@ class AIResponder(object):
self.history: List[Dict[str, Any]] = []
self.channel = channel if channel is not None else 'system'
openai.api_key = self.config['openai-token']
self.history_file: Optional[Path] = None
if 'history-directory' in self.config:
self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
if self.history_file.exists():
with open(self.history_file, 'rb') as fd:
self.history = pickle.load(fd)
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
messages = []
@ -133,6 +141,15 @@ class AIResponder(object):
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
return answer
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
self.history.append(question)
self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
if self.history_file is not None:
with open(self.history_file, 'wb') as fd:
pickle.dump(self.history, fd)
async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"]
if self.short_path(message, limit):
@ -156,9 +173,6 @@ class AIResponder(object):
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
continue
logging.info(f"got this answer:\n{pformat(response)}")
self.history.append(messages[-1])
self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
self.update_history(messages[-1], answer, limit)
return await self.post_process(response)
raise RuntimeError("Failed to generate answer after multiple retries")

View File

@ -111,12 +111,8 @@ class TestFunctionality(TestBotBase):
async def test_on_message_event3(self) -> None:
async def acreate(*a, **kw):
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
def logging_warning(msg):
raise RuntimeError(msg)
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(logging, 'warning', logging_warning):
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)