Support saving history.

This commit is contained in:
OK 2023-03-24 18:27:30 +01:00
parent 6d2a3d6ac5
commit b82477fc83
2 changed files with 19 additions and 9 deletions

View File

@@ -4,6 +4,8 @@ import aiohttp
import logging
import time
import re
import pickle
from pathlib import Path
from io import BytesIO
from pprint import pformat
from typing import Optional, List, Dict, Any, Tuple
@@ -39,6 +41,12 @@ class AIResponder(object):
self.history: List[Dict[str, Any]] = []
self.channel = channel if channel is not None else 'system'
openai.api_key = self.config['openai-token']
self.history_file: Optional[Path] = None
if 'history-directory' in self.config:
self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
if self.history_file.exists():
with open(self.history_file, 'rb') as fd:
self.history = pickle.load(fd)
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
messages = []
@@ -133,6 +141,15 @@ class AIResponder(object):
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
return answer
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
self.history.append(question)
self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
if self.history_file is not None:
with open(self.history_file, 'wb') as fd:
pickle.dump(self.history, fd)
async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"]
if self.short_path(message, limit):
@@ -156,9 +173,6 @@ class AIResponder(object):
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
continue
logging.info(f"got this answer:\n{pformat(response)}")
self.history.append(messages[-1]) self.update_history(messages[-1], answer, limit)
self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
return await self.post_process(response)
raise RuntimeError("Failed to generate answer after multiple retries")

View File

@@ -111,12 +111,8 @@ class TestFunctionality(TestBotBase):
async def test_on_message_event3(self) -> None:
async def acreate(*a, **kw):
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
def logging_warning(msg):
raise RuntimeError(msg)
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
patch.object(logging, 'warning', logging_warning):
await self.bot.on_message(message) await self.bot.on_message(message)
self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)