From 9a411a3fed0b7e082428e0dcf110156f96326c68 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Mon, 10 Apr 2023 13:50:35 +0200 Subject: [PATCH] Handle messages as direct, when bot is mentioned. --- fjerkroa_bot/ai_responder.py | 6 ++++-- tests/test_main.py | 11 ++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 9f24d1d..41565d8 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -114,7 +114,7 @@ class AIResponder(object): logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}") raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries") - async def post_process(self, response: Dict[str, Any]) -> AIResponse: + async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse: for fld in ('answer', 'staff', 'picture'): if str(response[fld]).strip().lower() in ('none', '', 'null'): response[fld] = None @@ -129,6 +129,8 @@ class AIResponder(object): response['answer'] = str(response['answer']) response['answer'] = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', response['answer']) response['answer'] = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', response['answer']) + if message.direct or message.user in message.message: + response['answer_needed'] = True return AIResponse(response['answer'], response['answer_needed'], parse_maybe_json(response['staff']), @@ -232,5 +234,5 @@ class AIResponder(object): continue logging.info(f"got this answer:\n{pp(response)}") self.update_history(messages[-1], answer, limit) - return await self.post_process(response) + return await self.post_process(message, response) raise RuntimeError("Failed to generate answer after multiple retries") diff --git a/tests/test_main.py b/tests/test_main.py index 682d993..7b04b7f 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,7 +8,7 @@ import logging import pytest from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY from fjerkroa_bot import FjerkroaBot -from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse +from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage from discord import User, Message, TextChannel @@ -20,7 +20,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase): Mock(text="Nice day today!") ] self.config_data = { - "openai-token": os.environ['OPENAI_TOKEN'], + "openai-token": os.environ.get('OPENAI_TOKEN', 'test'), "model": "gpt-4", "max-tokens": 1024, "temperature": 0.9, @@ -79,18 +79,19 @@ class TestFunctionality(TestBotBase): self.assertEqual(parse_maybe_json(json_struct), expected_output) async def test_message_lings(self) -> None: + request = AIMessage('Lala', 'Hello there!', 'chat', False,) message = {'answer': 'Test [Link](https://www.example.com/test)', 'answer_needed': True, 'staff': None, 'picture': None, 'hack': False} expected = AIResponse('Test https://www.example.com/test', True, None, None, False) - self.assertEqual(str(await self.bot.airesponder.post_process(message)), str(expected)) + self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) message = {'answer': 'Test @[Link](https://www.example.com/test)', 'answer_needed': True, 'staff': None, 'picture': None, 'hack': False} expected = AIResponse('Test Link', True, None, None, False) - self.assertEqual(str(await self.bot.airesponder.post_process(message)), str(expected)) + self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala', 'answer_needed': True, 'staff': None, 'picture': None, 'hack': False} expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, False) - self.assertEqual(str(await self.bot.airesponder.post_process(message)), str(expected)) + self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) async def test_on_message_event(self) -> None: async def acreate(*a, **kw):