Improve history handling
- Try to keep at least 3 messages from each channel in the history - Use post processed messages for the history, instead of the raw messages from the openai API
This commit is contained in:
parent
d136b0af21
commit
2db983c462
@ -97,13 +97,12 @@ class AIResponder(object):
|
||||
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||
.replace('{time}', time.strftime('%H:%M:%S'))
|
||||
messages.append({"role": "system", "content": system})
|
||||
if limit is None:
|
||||
history = self.history[:]
|
||||
else:
|
||||
history = self.history[-limit:]
|
||||
history.append({"role": "user", "content": str(message)})
|
||||
for msg in history:
|
||||
if limit is not None:
|
||||
while len(self.history) > limit:
|
||||
self.shrink_history_by_one()
|
||||
for msg in self.history:
|
||||
messages.append(msg)
|
||||
messages.append({"role": "user", "content": str(message)})
|
||||
return messages
|
||||
|
||||
async def draw(self, description: str) -> BytesIO:
|
||||
@ -202,11 +201,22 @@ class AIResponder(object):
|
||||
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
||||
return answer
|
||||
|
||||
def shrink_history_by_one(self, index: int = 0) -> None:
|
||||
if index >= len(self.history):
|
||||
del self.history[0]
|
||||
else:
|
||||
current = self.history[index]
|
||||
count = sum(1 for item in self.history[index:] if item.get('channel') == current.get('channel'))
|
||||
if count > 3:
|
||||
del self.history[index]
|
||||
else:
|
||||
self.shrink_history_by_one(index + 1)
|
||||
|
||||
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
|
||||
self.history.append(question)
|
||||
self.history.append(answer)
|
||||
if len(self.history) > limit:
|
||||
self.history = self.history[-limit:]
|
||||
while len(self.history) > limit:
|
||||
self.shrink_history_by_one()
|
||||
if self.history_file is not None:
|
||||
with open(self.history_file, 'wb') as fd:
|
||||
pickle.dump(self.history, fd)
|
||||
@ -236,8 +246,9 @@ class AIResponder(object):
|
||||
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
||||
retries -= 1
|
||||
continue
|
||||
self.update_history(messages[-1], answer, limit)
|
||||
answer_message = await self.post_process(message, response)
|
||||
answer['content'] = str(answer_message)
|
||||
self.update_history(messages[-1], answer, limit)
|
||||
logging.info(f"got this answer:\n{str(answer_message)}")
|
||||
return answer_message
|
||||
raise RuntimeError("Failed to generate answer after multiple retries")
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import pickle
|
||||
from fjerkroa_bot import AIMessage, AIResponse
|
||||
from .test_main import TestBotBase
|
||||
|
||||
@ -73,6 +76,51 @@ You always try to say something positive about the current day and the Fjærkroa
|
||||
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
|
||||
print(f"\n{self.bot.airesponder.history}")
|
||||
|
||||
def test_update_history(self) -> None:
|
||||
updater = self.bot.airesponder
|
||||
updater.history = []
|
||||
updater.history_file = None
|
||||
|
||||
question = {"channel": "test_channel", "content": "What is the meaning of life?"}
|
||||
answer = {"channel": "test_channel", "content": "42"}
|
||||
|
||||
# Test case 1: Limit set to 2
|
||||
updater.update_history(question, answer, 2)
|
||||
self.assertEqual(updater.history, [question, answer])
|
||||
|
||||
# Test case 2: Limit set to 4, check limit enforcement (deletion)
|
||||
new_question = {"channel": "test_channel", "content": "What is AI?"}
|
||||
new_answer = {"channel": "test_channel", "content": "Artificial Intelligence"}
|
||||
updater.update_history(new_question, new_answer, 3)
|
||||
self.assertEqual(updater.history, [answer, new_question, new_answer])
|
||||
|
||||
# Test case 3: Limit set to 4, check limit enforcement (deletion)
|
||||
other_question = {"channel": "other_channel", "content": "What is XXX?"}
|
||||
other_answer = {"channel": "other_channel", "content": "Tripple X"}
|
||||
updater.update_history(other_question, other_answer, 4)
|
||||
self.assertEqual(updater.history, [new_question, new_answer, other_question, other_answer])
|
||||
|
||||
# Test case 4: Limit set to 4, check limit enforcement (deletion)
|
||||
next_question = {"channel": "other_channel", "content": "What is YYY?"}
|
||||
next_answer = {"channel": "other_channel", "content": "Tripple Y"}
|
||||
updater.update_history(next_question, next_answer, 4)
|
||||
self.assertEqual(updater.history, [new_answer, other_answer, next_question, next_answer])
|
||||
|
||||
# Test case 5: Limit set to 4, check limit enforcement (deletion)
|
||||
next_question2 = {"channel": "other_channel", "content": "What is ZZZ?"}
|
||||
next_answer2 = {"channel": "other_channel", "content": "Tripple Z"}
|
||||
updater.update_history(next_question2, next_answer2, 4)
|
||||
self.assertEqual(updater.history, [new_answer, next_answer, next_question2, next_answer2])
|
||||
|
||||
# Test case 5: Check history file save using mock
|
||||
with unittest.mock.patch("builtins.open", unittest.mock.mock_open()) as mock_file:
|
||||
_, temp_path = tempfile.mkstemp()
|
||||
os.remove(temp_path)
|
||||
self.bot.airesponder.history_file = temp_path
|
||||
updater.update_history(question, answer, 2)
|
||||
mock_file.assert_called_with(temp_path, 'wb')
|
||||
mock_file().write.assert_called_with(pickle.dumps([question, answer]))
|
||||
|
||||
|
||||
if __name__ == "__mait__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user