From 2db983c4625ae568ddbf9db0a862d4bb76e25799 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Wed, 12 Apr 2023 12:19:07 +0200 Subject: [PATCH] Improve history handling - Try to keep at least 3 messages from each channel in the history - Use post processed messages for the history, instead of the raw messages from the openai API --- fjerkroa_bot/ai_responder.py | 29 +++++++++++++++------- tests/test_ai.py | 48 ++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 0cf44d2..580b9e8 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -97,13 +97,12 @@ class AIResponder(object): system = system.replace('{date}', time.strftime('%Y-%m-%d'))\ .replace('{time}', time.strftime('%H:%M:%S')) messages.append({"role": "system", "content": system}) - if limit is None: - history = self.history[:] - else: - history = self.history[-limit:] - history.append({"role": "user", "content": str(message)}) - for msg in history: + if limit is not None: + while len(self.history) > limit: + self.shrink_history_by_one() + for msg in self.history: messages.append(msg) + messages.append({"role": "user", "content": str(message)}) return messages async def draw(self, description: str) -> BytesIO: @@ -202,11 +201,22 @@ class AIResponder(object): logging.warning(f"failed to execute a fix for the answer: {repr(err)}") return answer + def shrink_history_by_one(self, index: int = 0) -> None: + if index >= len(self.history): + del self.history[0] + else: + current = self.history[index] + count = sum(1 for item in self.history[index:] if item.get('channel') == current.get('channel')) + if count > 3: + del self.history[index] + else: + self.shrink_history_by_one(index + 1) + def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None: self.history.append(question) self.history.append(answer) - if len(self.history) > limit: - self.history = self.history[-limit:] + while len(self.history) > limit: + self.shrink_history_by_one() if self.history_file is not None: with open(self.history_file, 'wb') as fd: pickle.dump(self.history, fd) @@ -236,8 +246,9 @@ class AIResponder(object): if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str): retries -= 1 continue - self.update_history(messages[-1], answer, limit) answer_message = await self.post_process(message, response) + answer['content'] = str(answer_message) + self.update_history(messages[-1], answer, limit) logging.info(f"got this answer:\n{str(answer_message)}") return answer_message raise RuntimeError("Failed to generate answer after multiple retries") diff --git a/tests/test_ai.py b/tests/test_ai.py index 839436b..0a8d4e9 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -1,4 +1,7 @@ import unittest +import tempfile +import os +import pickle from fjerkroa_bot import AIMessage, AIResponse from .test_main import TestBotBase @@ -73,6 +76,51 @@ You always try to say something positive about the current day and the Fjærkroa self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5) print(f"\n{self.bot.airesponder.history}") + def test_update_history(self) -> None: + updater = self.bot.airesponder + updater.history = [] + updater.history_file = None + + question = {"channel": "test_channel", "content": "What is the meaning of life?"} + answer = {"channel": "test_channel", "content": "42"} + + # Test case 1: Limit set to 2 + updater.update_history(question, answer, 2) + self.assertEqual(updater.history, [question, answer]) + + # Test case 2: Limit set to 4, check limit enforcement (deletion) + new_question = {"channel": "test_channel", "content": "What is AI?"} + new_answer = {"channel": "test_channel", "content": "Artificial Intelligence"} + updater.update_history(new_question, new_answer, 3) + self.assertEqual(updater.history, [answer, new_question, new_answer]) + + # Test case 3: Limit set to 4, check limit enforcement (deletion) + other_question = {"channel": "other_channel", "content": "What is XXX?"} + other_answer = {"channel": "other_channel", "content": "Tripple X"} + updater.update_history(other_question, other_answer, 4) + self.assertEqual(updater.history, [new_question, new_answer, other_question, other_answer]) + + # Test case 4: Limit set to 4, check limit enforcement (deletion) + next_question = {"channel": "other_channel", "content": "What is YYY?"} + next_answer = {"channel": "other_channel", "content": "Tripple Y"} + updater.update_history(next_question, next_answer, 4) + self.assertEqual(updater.history, [new_answer, other_answer, next_question, next_answer]) + + # Test case 5: Limit set to 4, check limit enforcement (deletion) + next_question2 = {"channel": "other_channel", "content": "What is ZZZ?"} + next_answer2 = {"channel": "other_channel", "content": "Tripple Z"} + updater.update_history(next_question2, next_answer2, 4) + self.assertEqual(updater.history, [new_answer, next_answer, next_question2, next_answer2]) + + # Test case 5: Check history file save using mock + with unittest.mock.patch("builtins.open", unittest.mock.mock_open()) as mock_file: + _, temp_path = tempfile.mkstemp() + os.remove(temp_path) + self.bot.airesponder.history_file = temp_path + updater.update_history(question, answer, 2) + mock_file.assert_called_with(temp_path, 'wb') + mock_file().write.assert_called_with(pickle.dumps([question, answer])) + if __name__ == "__mait__": unittest.main()