diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 2a24b10..cc9e9a4 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -20,10 +20,11 @@ class AIMessageBase(object): class AIMessage(AIMessageBase): - def __init__(self, user: str, message: str, channel: str = "chat") -> None: + def __init__(self, user: str, message: str, channel: str = "chat", direct: bool = False) -> None: self.user = user self.message = message self.channel = channel + self.direct = direct class AIResponse(AIMessageBase): @@ -94,7 +95,7 @@ class AIResponder(object): response['hack']) def short_path(self, message: AIMessage, limit: int) -> bool: - if 'short-path' not in self.config: + if message.direct or 'short-path' not in self.config: return False for chan_re, user_re in self.config['short-path']: chan_ma = re.match(chan_re, message.channel) diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index 7f3d750..d527ab7 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -77,7 +77,7 @@ class FjerkroaBot(commands.Bot): channel_name = str(message.channel.name) else: channel_name = str(message.channel.id) - msg = AIMessage(message.author.name, message_content, channel_name) + msg = AIMessage(message.author.name, message_content, channel_name, self.user in message.mentions) await self.respond(msg, message.channel) async def respond(self, message: AIMessage, channel: TextChannel) -> None: diff --git a/tests/test_main.py b/tests/test_main.py index 9057ce2..6ac3ca0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -88,11 +88,11 @@ class TestFunctionality(TestBotBase): with patch.object(openai.ChatCompletion, 'acreate', new=acreate): await self.bot.on_message(message) self.assertEqual(self.bot.airesponder.history[-1]["content"], - '{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel"}') + '{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}') message.author.name = 'different_name' await self.bot.on_message(message) self.assertEqual(self.bot.airesponder.history[-2]["content"], - '{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel"}') + '{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}') message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) async def test_on_message_event2(self) -> None: