From c0b51f947caed666c6d77dd9117202d50dfbe3ab Mon Sep 17 00:00:00 2001 From: Fjerkroa Auto Date: Tue, 21 Mar 2023 20:43:46 +0100 Subject: [PATCH] Add ai_responder, some more stuff. --- fjerkroa_bot/ai_responder.py | 66 ++++++++++++++++++++++++++++++++++++ fjerkroa_bot/discord_bot.py | 26 ++++++++++---- tests/test_main.py | 44 +++++++++++++----------- 3 files changed, 110 insertions(+), 26 deletions(-) create mode 100644 fjerkroa_bot/ai_responder.py diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py new file mode 100644 index 0000000..c911960 --- /dev/null +++ b/fjerkroa_bot/ai_responder.py @@ -0,0 +1,66 @@ +import json +import openai +from typing import Optional, List, Dict, Any, Tuple + + +class AIMessageBase(object): + def __init__(self) -> None: + pass + + def __str__(self) -> str: + return json.dumps(vars(self)) + + +class AIMessage(AIMessageBase): + def __init__(self, user: str, message: str) -> None: + self.user = user + self.message = message + + +class AIResponse(AIMessageBase): + def __init__(self, answer: str, answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None: + self.answer = answer + self.answer_needed = answer_needed + self.staff = staff + self.picture = picture + self.hack = hack + + +class AIResponder(object): + def __init__(self, config: Dict[str, Any]) -> None: + self.config = config + self.history: List[Dict[str, Any]] = [] + + def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + messages = [] + messages.append({"role": "system", "content": self.config["system"]}) + if limit is None: + history = self.history[:] + else: + history = self.history[-limit:] + history.append({"role": "user", "content": str(message)}) + for msg in history: + messages.append(msg) + return messages, history + + async def send(self, message: AIMessage) -> AIResponse: + limit = self.config["history-limit"] + while True: + messages, history = self._message(message, limit) + try: + result = await openai.ChatCompletion.acreate(model=self.config["model"], + messages=messages, + temperature=self.config["temperature"], + top_p=self.config["top-p"], + presence_penalty=self.config["presence-penalty"], + frequency_penalty=self.config["frequency-penalty"]) + except openai.error.InvalidRequestError as err: + if 'maximum context length is' in str(err) and limit > 2: + limit -= 1 + continue + raise err + answer = result['choices'][0]['message'] + response = json.loads(answer['content']) + history.append(answer) + self.history = history + return response diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index 2a33606..bd1f332 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -6,6 +6,7 @@ from discord import Message from discord.ext import commands from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler +from .ai_responder import AIResponder, AIMessage class ConfigFileHandler(FileSystemEventHandler): @@ -18,11 +19,6 @@ class ConfigFileHandler(FileSystemEventHandler): class FjerkroaBot(commands.Bot): def __init__(self, config_file: str): - assert not hasattr(self, 'config_file') - assert not hasattr(self, 'config') - assert not hasattr(self, 'observer') - assert not hasattr(self, 'file_handler') - self.config_file = config_file self.config = self.load_config(self.config_file) intents = discord.Intents.default() @@ -34,6 +30,8 @@ class FjerkroaBot(commands.Bot): self.observer.schedule(self.file_handler, path=".", recursive=False) self.observer.start() + self.airesponder = AIResponder(self.config) + super().__init__(command_prefix="!", case_insensitive=True, intents=intents) @classmethod @@ -44,12 +42,28 @@ class FjerkroaBot(commands.Bot): def on_config_file_modified(self, event): if event.src_path == self.config_file: self.config = self.load_config(self.config_file) + self.airesponder.config = self.config async def on_ready(self): print(f"We have logged in as {self.user}") + self.staff_channel = None + self.welcome_channel = None + for guild in self.guilds: + if self.staff_channel is None: + self.staff_channel = discord.utils.get(guild.channels, name=self.config['staff-channel']) + if self.welcome_channel is None: + self.welcome_channel = discord.utils.get(guild.channels, name=self.config['welcome-channel']) async def on_message(self, message: Message) -> None: - await message.channel.send("Hello!") + msg = AIMessage(message.author.name, str(message.content).strip()) + response = await self.airesponder.send(msg) + if response.staff is not None: + async with self.staff_channel.typing(): + self.staff_channel.send(response.staff) + if not response.answer_needed: + return + async with message.channel.typing(): + message.channel.send(response.answer) async def close(self): self.observer.stop() diff --git a/tests/test_main.py b/tests/test_main.py index 859a401..f8b6e97 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -20,30 +20,34 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase): "temperature": 0.9, } self.history_data = [] - - -class TestFunctionality(TestBotBase): - - def test_load_config(self): - with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))): - result = FjerkroaBot.load_config('config.json') - self.assertEqual(result, self.config_data) - - async def test_on_message_event(self): with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \ patch.object(FjerkroaBot, 'user', new_callable=PropertyMock) as mock_user: mock_user.return_value = MagicMock(spec=User) mock_user.return_value.id = 12 - bot = FjerkroaBot('config.json') - message = MagicMock(spec=Message) - message.content = "Hello, how are you?" - message.author = AsyncMock(spec=User) - message.author.id = 123 - message.author.bot = False - message.channel = AsyncMock(spec=TextChannel) - message.channel.send = AsyncMock() - await bot.on_message(message) - message.channel.send.assert_called_once_with("Hello!") + self.bot = FjerkroaBot('config.json') + + def create_message(self, message: str) -> Message: + message = MagicMock(spec=Message) + message.content = "Hello, how are you?" + message.author = AsyncMock(spec=User) + message.author.id = 123 + message.author.bot = False + message.channel = AsyncMock(spec=TextChannel) + message.channel.send = AsyncMock() + return message + + +class TestFunctionality(TestBotBase): + + def test_load_config(self) -> None: + with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))): + result = FjerkroaBot.load_config('config.json') + self.assertEqual(result, self.config_data) + + async def test_on_message_event(self) -> None: + message = self.create_message("Hello there! How are you?") + await self.bot.on_message(message) + message.channel.send.assert_called_once_with("Hello!") if __name__ == "__mait__":