From 78591ef13a44af9aef0aabef91dad810f98d1b1d Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 24 Mar 2023 13:42:10 +0100 Subject: [PATCH] Add support of short path messages, which are not send to the AI but just added to the history. --- fjerkroa_bot/ai_responder.py | 15 +++++++++++++++ fjerkroa_bot/discord_bot.py | 2 ++ tests/test_main.py | 22 ++++++++++++++++++++++ 3 files changed, 39 insertions(+) diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index b3f080d..902f17c 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -3,6 +3,7 @@ import openai import aiohttp import logging import time +import re from io import BytesIO from pprint import pformat from typing import Optional, List, Dict, Any @@ -84,8 +85,22 @@ class AIResponder(object): response['picture'], response['hack']) + def short_path(self, message: AIMessage) -> bool: + if 'short-path' not in self.config: + return False + for chan_re, user_re in self.config['short-path']: + chan_ma = re.match(chan_re, message.channel) + user_ma = re.match(user_re, message.user) + if chan_ma and user_ma: + return True + return False + async def send(self, message: AIMessage) -> AIResponse: limit = self.config["history-limit"] + if self.short_path(message): + self.history.append({"role": "user", "content": str(message)}) + self.history = self.history[-limit:] + return AIResponse(None, False, None, None, False) for _ in range(14): messages = self._message(message, limit) logging.info(f"try to send this messages:\n{pformat(messages)}") diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index 292996a..7f3d750 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -48,6 +48,8 @@ class FjerkroaBot(commands.Bot): if event.src_path == self.config_file: self.config = self.load_config(self.config_file) self.airesponder.config = self.config + for responder in self.aichannels.values(): + responder.config = self.config async def on_ready(self): print(f"We have logged in as {self.user}") diff --git a/tests/test_main.py b/tests/test_main.py index 31f300f..219dcfb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -74,6 +74,28 @@ class TestFunctionality(TestBotBase): await self.bot.on_message(message) message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) + async def test_on_message_stort_path(self) -> None: + async def acreate(*a, **kw): + answer = {'answer': 'Hello!', + 'answer_needed': True, + 'staff': None, + 'picture': None, + 'hack': False} + return {'choices': [{'message': {'content': json.dumps(answer)}}]} + message = self.create_message("Hello there! How are you?") + message.author.name = 'madeup_name' + message.channel.name = 'some_channel' + self.bot.config['short-path'] = [[r'some.*', r'madeup.*']] + with patch.object(openai.ChatCompletion, 'acreate', new=acreate): + await self.bot.on_message(message) + self.assertEqual(self.bot.airesponder.history[-1]["content"], + '{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel"}') + message.author.name = 'different_name' + await self.bot.on_message(message) + self.assertEqual(self.bot.airesponder.history[-2]["content"], + '{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel"}') + message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) + async def test_on_message_event2(self) -> None: async def acreate(*a, **kw): answer = {'answer': 'Hello!',