Improve history handling

- Try to keep at least 3 messages from each channel in the history
- Use post processed messages for the history, instead of the raw
  messages from the openai API
This commit is contained in:
OK 2023-04-12 12:19:07 +02:00
parent d136b0af21
commit 2db983c462
2 changed files with 68 additions and 9 deletions

View File

@ -97,13 +97,12 @@ class AIResponder(object):
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
.replace('{time}', time.strftime('%H:%M:%S'))
messages.append({"role": "system", "content": system})
if limit is None:
history = self.history[:]
else:
history = self.history[-limit:]
history.append({"role": "user", "content": str(message)})
for msg in history:
if limit is not None:
while len(self.history) > limit:
self.shrink_history_by_one()
for msg in self.history:
messages.append(msg)
messages.append({"role": "user", "content": str(message)})
return messages
async def draw(self, description: str) -> BytesIO:
@ -202,11 +201,22 @@ class AIResponder(object):
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
return answer
def shrink_history_by_one(self, index: int = 0) -> None:
if index >= len(self.history):
del self.history[0]
else:
current = self.history[index]
count = sum(1 for item in self.history[index:] if item.get('channel') == current.get('channel'))
if count > 3:
del self.history[index]
else:
self.shrink_history_by_one(index + 1)
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
self.history.append(question)
self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
while len(self.history) > limit:
self.shrink_history_by_one()
if self.history_file is not None:
with open(self.history_file, 'wb') as fd:
pickle.dump(self.history, fd)
@ -236,8 +246,9 @@ class AIResponder(object):
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
retries -= 1
continue
self.update_history(messages[-1], answer, limit)
answer_message = await self.post_process(message, response)
answer['content'] = str(answer_message)
self.update_history(messages[-1], answer, limit)
logging.info(f"got this answer:\n{str(answer_message)}")
return answer_message
raise RuntimeError("Failed to generate answer after multiple retries")

View File

@ -1,4 +1,7 @@
import unittest
import tempfile
import os
import pickle
from fjerkroa_bot import AIMessage, AIResponse
from .test_main import TestBotBase
@ -73,6 +76,51 @@ You always try to say something positive about the current day and the Fjærkroa
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
print(f"\n{self.bot.airesponder.history}")
def test_update_history(self) -> None:
updater = self.bot.airesponder
updater.history = []
updater.history_file = None
question = {"channel": "test_channel", "content": "What is the meaning of life?"}
answer = {"channel": "test_channel", "content": "42"}
# Test case 1: Limit set to 2
updater.update_history(question, answer, 2)
self.assertEqual(updater.history, [question, answer])
# Test case 2: Limit set to 4, check limit enforcement (deletion)
new_question = {"channel": "test_channel", "content": "What is AI?"}
new_answer = {"channel": "test_channel", "content": "Artificial Intelligence"}
updater.update_history(new_question, new_answer, 3)
self.assertEqual(updater.history, [answer, new_question, new_answer])
# Test case 3: Limit set to 4, check limit enforcement (deletion)
other_question = {"channel": "other_channel", "content": "What is XXX?"}
other_answer = {"channel": "other_channel", "content": "Tripple X"}
updater.update_history(other_question, other_answer, 4)
self.assertEqual(updater.history, [new_question, new_answer, other_question, other_answer])
# Test case 4: Limit set to 4, check limit enforcement (deletion)
next_question = {"channel": "other_channel", "content": "What is YYY?"}
next_answer = {"channel": "other_channel", "content": "Tripple Y"}
updater.update_history(next_question, next_answer, 4)
self.assertEqual(updater.history, [new_answer, other_answer, next_question, next_answer])
# Test case 5: Limit set to 4, check limit enforcement (deletion)
next_question2 = {"channel": "other_channel", "content": "What is ZZZ?"}
next_answer2 = {"channel": "other_channel", "content": "Tripple Z"}
updater.update_history(next_question2, next_answer2, 4)
self.assertEqual(updater.history, [new_answer, next_answer, next_question2, next_answer2])
# Test case 5: Check history file save using mock
with unittest.mock.patch("builtins.open", unittest.mock.mock_open()) as mock_file:
_, temp_path = tempfile.mkstemp()
os.remove(temp_path)
self.bot.airesponder.history_file = temp_path
updater.update_history(question, answer, 2)
mock_file.assert_called_with(temp_path, 'wb')
mock_file().write.assert_called_with(pickle.dumps([question, answer]))
if __name__ == "__mait__":
unittest.main()