Handle messages as direct, when bot is mentioned.

This commit is contained in:
OK 2023-04-10 13:50:35 +02:00
parent 545db1f79d
commit 9a411a3fed
2 changed files with 10 additions and 7 deletions

View File

@ -114,7 +114,7 @@ class AIResponder(object):
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
async def post_process(self, response: Dict[str, Any]) -> AIResponse:
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
for fld in ('answer', 'staff', 'picture'):
if str(response[fld]).strip().lower() in ('none', '', 'null'):
response[fld] = None
@ -129,6 +129,8 @@ class AIResponder(object):
response['answer'] = str(response['answer'])
response['answer'] = re.sub(r'@\[([^\]]*)\]\([^\)]*\)', r'\1', response['answer'])
response['answer'] = re.sub(r'\[[^\]]*\]\(([^\)]*)\)', r'\1', response['answer'])
if message.direct or message.user in message.message:
response['answer_needed'] = True
return AIResponse(response['answer'],
response['answer_needed'],
parse_maybe_json(response['staff']),
@ -232,5 +234,5 @@ class AIResponder(object):
continue
logging.info(f"got this answer:\n{pp(response)}")
self.update_history(messages[-1], answer, limit)
return await self.post_process(response)
return await self.post_process(message, response)
raise RuntimeError("Failed to generate answer after multiple retries")

View File

@ -8,7 +8,7 @@ import logging
import pytest
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
from fjerkroa_bot import FjerkroaBot
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
from discord import User, Message, TextChannel
@ -20,7 +20,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
Mock(text="Nice day today!")
]
self.config_data = {
"openai-token": os.environ['OPENAI_TOKEN'],
"openai-token": os.environ.get('OPENAI_TOKEN', 'test'),
"model": "gpt-4",
"max-tokens": 1024,
"temperature": 0.9,
@ -79,18 +79,19 @@ class TestFunctionality(TestBotBase):
self.assertEqual(parse_maybe_json(json_struct), expected_output)
async def test_message_lings(self) -> None:
request = AIMessage('Lala', 'Hello there!', 'chat', False,)
message = {'answer': 'Test [Link](https://www.example.com/test)',
'answer_needed': True, 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test https://www.example.com/test', True, None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(message)), str(expected))
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
message = {'answer': 'Test @[Link](https://www.example.com/test)',
'answer_needed': True, 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test Link', True, None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(message)), str(expected))
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
'answer_needed': True, 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(message)), str(expected))
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
async def test_on_message_event(self) -> None:
async def acreate(*a, **kw):