Support saving history.
This commit is contained in:
parent
6d2a3d6ac5
commit
b82477fc83
@ -4,6 +4,8 @@ import aiohttp
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
|
import pickle
|
||||||
|
from pathlib import Path
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
@ -39,6 +41,12 @@ class AIResponder(object):
|
|||||||
self.history: List[Dict[str, Any]] = []
|
self.history: List[Dict[str, Any]] = []
|
||||||
self.channel = channel if channel is not None else 'system'
|
self.channel = channel if channel is not None else 'system'
|
||||||
openai.api_key = self.config['openai-token']
|
openai.api_key = self.config['openai-token']
|
||||||
|
self.history_file: Optional[Path] = None
|
||||||
|
if 'history-directory' in self.config:
|
||||||
|
self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
|
||||||
|
if self.history_file.exists():
|
||||||
|
with open(self.history_file, 'rb') as fd:
|
||||||
|
self.history = pickle.load(fd)
|
||||||
|
|
||||||
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||||
messages = []
|
messages = []
|
||||||
@ -133,6 +141,15 @@ class AIResponder(object):
|
|||||||
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
|
||||||
|
self.history.append(question)
|
||||||
|
self.history.append(answer)
|
||||||
|
if len(self.history) > limit:
|
||||||
|
self.history = self.history[-limit:]
|
||||||
|
if self.history_file is not None:
|
||||||
|
with open(self.history_file, 'wb') as fd:
|
||||||
|
pickle.dump(self.history, fd)
|
||||||
|
|
||||||
async def send(self, message: AIMessage) -> AIResponse:
|
async def send(self, message: AIMessage) -> AIResponse:
|
||||||
limit = self.config["history-limit"]
|
limit = self.config["history-limit"]
|
||||||
if self.short_path(message, limit):
|
if self.short_path(message, limit):
|
||||||
@ -156,9 +173,6 @@ class AIResponder(object):
|
|||||||
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
||||||
continue
|
continue
|
||||||
logging.info(f"got this answer:\n{pformat(response)}")
|
logging.info(f"got this answer:\n{pformat(response)}")
|
||||||
self.history.append(messages[-1])
|
self.update_history(messages[-1], answer, limit)
|
||||||
self.history.append(answer)
|
|
||||||
if len(self.history) > limit:
|
|
||||||
self.history = self.history[-limit:]
|
|
||||||
return await self.post_process(response)
|
return await self.post_process(response)
|
||||||
raise RuntimeError("Failed to generate answer after multiple retries")
|
raise RuntimeError("Failed to generate answer after multiple retries")
|
||||||
|
|||||||
@ -111,12 +111,8 @@ class TestFunctionality(TestBotBase):
|
|||||||
async def test_on_message_event3(self) -> None:
|
async def test_on_message_event3(self) -> None:
|
||||||
async def acreate(*a, **kw):
|
async def acreate(*a, **kw):
|
||||||
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
|
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
|
||||||
|
|
||||||
def logging_warning(msg):
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
message = self.create_message("Hello there! How are you?")
|
message = self.create_message("Hello there! How are you?")
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||||
patch.object(logging, 'warning', logging_warning):
|
|
||||||
await self.bot.on_message(message)
|
await self.bot.on_message(message)
|
||||||
self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)
|
self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user