From d23e7802483afd342e177aea20ff111d32dddc57 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Wed, 22 Mar 2023 22:16:59 +0100 Subject: [PATCH] Support different system messages for different channels and own history for those channels. --- fjerkroa_bot/ai_responder.py | 7 ++++--- fjerkroa_bot/discord_bot.py | 11 +++++++++-- tests/test_main.py | 1 + 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 27cc6df..479f9be 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -32,15 +32,16 @@ class AIResponse(AIMessageBase): class AIResponder(object): - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None: self.config = config self.history: List[Dict[str, Any]] = [] + self.channel = channel if channel is not None else 'system' openai.api_key = self.config['openai-token'] def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]: messages = [] - system = self.config["system"].replace('{date}', time.strftime('%Y-%m-%d'))\ - .replace('{time}', time.strftime('%H:%M:%S')) + system = self.config[self.channel].replace('{date}', time.strftime('%Y-%m-%d'))\ + .replace('{time}', time.strftime('%H:%M:%S')) messages.append({"role": "system", "content": system}) if limit is None: history = self.history[:] diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index 64b02a2..3c918a0 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -32,6 +32,9 @@ class FjerkroaBot(commands.Bot): self.observer.start() self.airesponder = AIResponder(self.config) + self.aichannels = {} + for chan_name in self.config['additional-responders']: + self.aichannels[chan_name] = AIResponder(self.config, chan_name) super().__init__(command_prefix="!", case_insensitive=True, intents=intents) @@ -73,8 +76,12 @@ class FjerkroaBot(commands.Bot): except Exception: channel_name = str(channel.id) logging.info(f"handle message {str(message)} for channel {channel_name}") + if channel_name in self.aichannels: + airesponder = self.aichannels[channel_name] + else: + airesponder = self.airesponder async with channel.typing(): - response = await self.airesponder.send(message) + response = await airesponder.send(message) if response.hack: logging.warning(f"User {message.user} tried to hack the system.") if response.staff is None: @@ -85,7 +92,7 @@ class FjerkroaBot(commands.Bot): if not response.answer_needed: return if response.picture is not None: - images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")] + images = [discord.File(fp=await airesponder.draw(response.picture), filename="image.png")] await channel.send(response.answer, files=images) else: await channel.send(response.answer) diff --git a/tests/test_main.py b/tests/test_main.py index 138854b..53a4cb0 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -28,6 +28,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase): "frequency-penalty": 1.0, "history-limit": 10, "system": "You are an smart AI", + "additional-responders": [], } self.history_data = [] with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \