Improve history handling
- Try to keep at least 3 messages from each channel in the history - Use post processed messages for the history, instead of the raw messages from the openai API
This commit is contained in:
parent
d136b0af21
commit
2db983c462
@ -97,13 +97,12 @@ class AIResponder(object):
|
|||||||
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||||
.replace('{time}', time.strftime('%H:%M:%S'))
|
.replace('{time}', time.strftime('%H:%M:%S'))
|
||||||
messages.append({"role": "system", "content": system})
|
messages.append({"role": "system", "content": system})
|
||||||
if limit is None:
|
if limit is not None:
|
||||||
history = self.history[:]
|
while len(self.history) > limit:
|
||||||
else:
|
self.shrink_history_by_one()
|
||||||
history = self.history[-limit:]
|
for msg in self.history:
|
||||||
history.append({"role": "user", "content": str(message)})
|
|
||||||
for msg in history:
|
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
messages.append({"role": "user", "content": str(message)})
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def draw(self, description: str) -> BytesIO:
|
async def draw(self, description: str) -> BytesIO:
|
||||||
@ -202,11 +201,22 @@ class AIResponder(object):
|
|||||||
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
||||||
return answer
|
return answer
|
||||||
|
|
||||||
|
def shrink_history_by_one(self, index: int = 0) -> None:
|
||||||
|
if index >= len(self.history):
|
||||||
|
del self.history[0]
|
||||||
|
else:
|
||||||
|
current = self.history[index]
|
||||||
|
count = sum(1 for item in self.history[index:] if item.get('channel') == current.get('channel'))
|
||||||
|
if count > 3:
|
||||||
|
del self.history[index]
|
||||||
|
else:
|
||||||
|
self.shrink_history_by_one(index + 1)
|
||||||
|
|
||||||
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
|
def update_history(self, question: Dict[str, Any], answer: Dict[str, Any], limit: int) -> None:
|
||||||
self.history.append(question)
|
self.history.append(question)
|
||||||
self.history.append(answer)
|
self.history.append(answer)
|
||||||
if len(self.history) > limit:
|
while len(self.history) > limit:
|
||||||
self.history = self.history[-limit:]
|
self.shrink_history_by_one()
|
||||||
if self.history_file is not None:
|
if self.history_file is not None:
|
||||||
with open(self.history_file, 'wb') as fd:
|
with open(self.history_file, 'wb') as fd:
|
||||||
pickle.dump(self.history, fd)
|
pickle.dump(self.history, fd)
|
||||||
@ -236,8 +246,9 @@ class AIResponder(object):
|
|||||||
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
||||||
retries -= 1
|
retries -= 1
|
||||||
continue
|
continue
|
||||||
self.update_history(messages[-1], answer, limit)
|
|
||||||
answer_message = await self.post_process(message, response)
|
answer_message = await self.post_process(message, response)
|
||||||
|
answer['content'] = str(answer_message)
|
||||||
|
self.update_history(messages[-1], answer, limit)
|
||||||
logging.info(f"got this answer:\n{str(answer_message)}")
|
logging.info(f"got this answer:\n{str(answer_message)}")
|
||||||
return answer_message
|
return answer_message
|
||||||
raise RuntimeError("Failed to generate answer after multiple retries")
|
raise RuntimeError("Failed to generate answer after multiple retries")
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
from fjerkroa_bot import AIMessage, AIResponse
|
from fjerkroa_bot import AIMessage, AIResponse
|
||||||
from .test_main import TestBotBase
|
from .test_main import TestBotBase
|
||||||
|
|
||||||
@ -73,6 +76,51 @@ You always try to say something positive about the current day and the Fjærkroa
|
|||||||
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
|
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
|
||||||
print(f"\n{self.bot.airesponder.history}")
|
print(f"\n{self.bot.airesponder.history}")
|
||||||
|
|
||||||
|
def test_update_history(self) -> None:
|
||||||
|
updater = self.bot.airesponder
|
||||||
|
updater.history = []
|
||||||
|
updater.history_file = None
|
||||||
|
|
||||||
|
question = {"channel": "test_channel", "content": "What is the meaning of life?"}
|
||||||
|
answer = {"channel": "test_channel", "content": "42"}
|
||||||
|
|
||||||
|
# Test case 1: Limit set to 2
|
||||||
|
updater.update_history(question, answer, 2)
|
||||||
|
self.assertEqual(updater.history, [question, answer])
|
||||||
|
|
||||||
|
# Test case 2: Limit set to 4, check limit enforcement (deletion)
|
||||||
|
new_question = {"channel": "test_channel", "content": "What is AI?"}
|
||||||
|
new_answer = {"channel": "test_channel", "content": "Artificial Intelligence"}
|
||||||
|
updater.update_history(new_question, new_answer, 3)
|
||||||
|
self.assertEqual(updater.history, [answer, new_question, new_answer])
|
||||||
|
|
||||||
|
# Test case 3: Limit set to 4, check limit enforcement (deletion)
|
||||||
|
other_question = {"channel": "other_channel", "content": "What is XXX?"}
|
||||||
|
other_answer = {"channel": "other_channel", "content": "Tripple X"}
|
||||||
|
updater.update_history(other_question, other_answer, 4)
|
||||||
|
self.assertEqual(updater.history, [new_question, new_answer, other_question, other_answer])
|
||||||
|
|
||||||
|
# Test case 4: Limit set to 4, check limit enforcement (deletion)
|
||||||
|
next_question = {"channel": "other_channel", "content": "What is YYY?"}
|
||||||
|
next_answer = {"channel": "other_channel", "content": "Tripple Y"}
|
||||||
|
updater.update_history(next_question, next_answer, 4)
|
||||||
|
self.assertEqual(updater.history, [new_answer, other_answer, next_question, next_answer])
|
||||||
|
|
||||||
|
# Test case 5: Limit set to 4, check limit enforcement (deletion)
|
||||||
|
next_question2 = {"channel": "other_channel", "content": "What is ZZZ?"}
|
||||||
|
next_answer2 = {"channel": "other_channel", "content": "Tripple Z"}
|
||||||
|
updater.update_history(next_question2, next_answer2, 4)
|
||||||
|
self.assertEqual(updater.history, [new_answer, next_answer, next_question2, next_answer2])
|
||||||
|
|
||||||
|
# Test case 5: Check history file save using mock
|
||||||
|
with unittest.mock.patch("builtins.open", unittest.mock.mock_open()) as mock_file:
|
||||||
|
_, temp_path = tempfile.mkstemp()
|
||||||
|
os.remove(temp_path)
|
||||||
|
self.bot.airesponder.history_file = temp_path
|
||||||
|
updater.update_history(question, answer, 2)
|
||||||
|
mock_file.assert_called_with(temp_path, 'wb')
|
||||||
|
mock_file().write.assert_called_with(pickle.dumps([question, answer]))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__mait__":
|
if __name__ == "__mait__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user