From 2508a12b44093c9632d6c5cd35215749f0552a7f Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Thu, 13 Apr 2023 18:36:06 +0200 Subject: [PATCH] Fix history limit handling. --- .gitignore | 1 + fjerkroa_bot/ai_responder.py | 14 +++++++++----- tests/test_ai.py | 20 ++++++++++---------- tests/test_main.py | 8 ++++---- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/.gitignore b/.gitignore index 54ef463..3fc038f 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build/ history/ .config.yaml .db +.env diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index c72b704..95f016b 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -18,7 +18,7 @@ def pp(*args, **kw): return pformat(*args, **kw) -def parse_response(content: str) -> Dict: +def parse_json(content: str) -> Dict: content = content.strip() try: return json.loads(content) @@ -36,7 +36,7 @@ def parse_maybe_json(json_string): return ' '.join(map(str, (json_string.values() if isinstance(json_string, dict) else json_string))) json_string = str(json_string).strip() try: - parsed_json = parse_response(json_string) + parsed_json = parse_json(json_string) except Exception: for b, e in [('{', '}'), ('[', ']')]: if json_string.startswith(b) and json_string.endswith(e): @@ -215,7 +215,11 @@ class AIResponder(object): del self.history[0] else: current = self.history[index] - count = sum(1 for item in self.history if item.get('channel') == current.get('channel')) + + def same_channel(item: Dict[str, Any]) -> bool: + return parse_json(item['content']).get('channel') == parse_json(current['content']).get('channel') + + count = sum(1 for item in self.history if same_channel(item)) if count > self.config.get('history-per-channel', 3): del self.history[index] else: @@ -254,14 +258,14 @@ class AIResponder(object): # Attempt to parse the AI's response try: - response = parse_response(answer['content']) + response = parse_json(answer['content']) except Exception as err: logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}") answer['content'] = await self.fix(answer['content']) # Retry parsing the fixed content try: - response = parse_response(answer['content']) + response = parse_json(answer['content']) except Exception as err: logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}") retries -= 1 diff --git a/tests/test_ai.py b/tests/test_ai.py index 0207197..a01fae6 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -82,34 +82,34 @@ You always try to say something positive about the current day and the Fjærkroa updater.history = [] updater.history_file = None - question = {"channel": "test_channel", "content": "What is the meaning of life?"} - answer = {"channel": "test_channel", "content": "42"} + question = {"content": '{"channel": "test_channel", "message": "What is the meaning of life?"}'} + answer = {"content": '{"channel": "test_channel", "message": "42"}'} # Test case 1: Limit set to 2 updater.update_history(question, answer, 2) self.assertEqual(updater.history, [question, answer]) # Test case 2: Limit set to 4, check limit enforcement (deletion) - new_question = {"channel": "test_channel", "content": "What is AI?"} - new_answer = {"channel": "test_channel", "content": "Artificial Intelligence"} + new_question = {"content": '{"channel": "test_channel", "message": "What is AI?"}'} + new_answer = {"content": '{"channel": "test_channel", "message": "Artificial Intelligence"}'} updater.update_history(new_question, new_answer, 3) self.assertEqual(updater.history, [answer, new_question, new_answer]) # Test case 3: Limit set to 4, check limit enforcement (deletion) - other_question = {"channel": "other_channel", "content": "What is XXX?"} - other_answer = {"channel": "other_channel", "content": "Tripple X"} + other_question = {"content": '{"channel": "other_channel", "message": "What is XXX?"}'} + other_answer = {"content": '{"channel": "other_channel", "message": "Tripple X"}'} updater.update_history(other_question, other_answer, 4) self.assertEqual(updater.history, [new_question, new_answer, other_question, other_answer]) # Test case 4: Limit set to 4, check limit enforcement (deletion) - next_question = {"channel": "other_channel", "content": "What is YYY?"} - next_answer = {"channel": "other_channel", "content": "Tripple Y"} + next_question = {"content": '{"channel": "other_channel", "message": "What is YYY?"}'} + next_answer = {"content": '{"channel": "other_channel", "message": "Tripple Y"}'} updater.update_history(next_question, next_answer, 4) self.assertEqual(updater.history, [new_answer, other_answer, next_question, next_answer]) # Test case 5: Limit set to 4, check limit enforcement (deletion) - next_question2 = {"channel": "other_channel", "content": "What is ZZZ?"} - next_answer2 = {"channel": "other_channel", "content": "Tripple Z"} + next_question2 = {"content": '{"channel": "other_channel", "message": "What is ZZZ?"}'} + next_answer2 = {"content": '{"channel": "other_channel", "message": "Tripple Z"}'} updater.update_history(next_question2, next_answer2, 4) self.assertEqual(updater.history, [new_answer, next_answer, next_question2, next_answer2]) diff --git a/tests/test_main.py b/tests/test_main.py index 6e57881..4c68a84 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -162,13 +162,13 @@ class TestFunctionality(TestBotBase): @patch("builtins.open", new_callable=mock_open) def test_update_history_with_file(self, mock_file): - self.bot.airesponder.update_history({"q": "What's your name?"}, {"a": "AI"}, 10) + self.bot.airesponder.update_history({'content': '{"q": "What\'s your name?"}'}, {'content': '{"a": "AI"}'}, 10) self.assertEqual(len(self.bot.airesponder.history), 2) - self.bot.airesponder.update_history({"q1": "Q1"}, {"a1": "A1"}, 2) - self.bot.airesponder.update_history({"q2": "Q2"}, {"a2": "A2"}, 2) + self.bot.airesponder.update_history({'content': '{"q1": "Q1"}'}, {'content': '{"a1": "A1"}'}, 2) + self.bot.airesponder.update_history({'content': '{"q2": "Q2"}'}, {'content': '{"a2": "A2"}'}, 2) self.assertEqual(len(self.bot.airesponder.history), 2) self.bot.airesponder.history_file = "mock_file.pkl" - self.bot.airesponder.update_history({"q": "What's your favorite color?"}, {"a": "Blue"}, 10) + self.bot.airesponder.update_history({'content': '{"q": "What\'s your favorite color?"}'}, {'content': '{"a": "Blue"}'}, 10) mock_file.assert_called_once_with("mock_file.pkl", "wb") mock_file().write.assert_called_once()