diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 7b1fcf0..1126fdf 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -21,7 +21,7 @@ class AIMessage(AIMessageBase): class AIResponse(AIMessageBase): - def __init__(self, answer: str, answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None: + def __init__(self, answer: Optional[str], answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None: self.answer = answer self.answer_needed = answer_needed self.staff = staff @@ -58,6 +58,25 @@ class AIResponder(object): logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}") raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries") + async def post_process(self, response: Dict[str, Any]) -> AIResponse: + for fld in ('answer', 'staff', 'picture'): + if str(response[fld]).strip().lower() in ('none', '', 'null'): + response[fld] = None + for fld in ('answer_needed', 'hack'): + if str(response[fld]).strip().lower() == 'true': + response[fld] = True + else: + response[fld] = False + if response['answer'] is None: + response['answer_needed'] = False + else: + response['answer'] = str(response['answer']) + return AIResponse(response['answer'], + response['answer_needed'], + response['staff'], + response['picture'], + response['hack']) + async def send(self, message: AIMessage) -> AIResponse: limit = self.config["history-limit"] for _ in range(14): @@ -81,17 +100,5 @@ class AIResponder(object): continue history.append(answer) self.history = history - for fld in ('answer', 'staff', 'picture'): - if str(response[fld]).strip().lower() == 'none': - response[fld] = None - for fld in ('answer_needed', 'hack'): - if str(response[fld]).strip().lower() == 'true': - response[fld] = True - else: - response[fld] = False - return AIResponse(response['answer'], - response['answer_needed'], - response['staff'], - response['picture'], - response['hack']) + return await self.post_process(response) raise RuntimeError("Failed to generate answer after multiple retries") diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index fb7ab36..2fc0a04 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -2,7 +2,8 @@ import sys import argparse import json import discord -from discord import Message +import logging +from discord import Message, TextChannel from discord.ext import commands from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler @@ -49,25 +50,35 @@ class FjerkroaBot(commands.Bot): self.staff_channel = None self.welcome_channel = None for guild in self.guilds: - if self.staff_channel is None: + if self.staff_channel is None and self.config['staff-channel'] is not None: self.staff_channel = discord.utils.get(guild.channels, name=self.config['staff-channel']) - if self.welcome_channel is None: + if self.welcome_channel is None and self.config['welcome-channel'] is not None: self.welcome_channel = discord.utils.get(guild.channels, name=self.config['welcome-channel']) + async def on_member_join(self, member): + logging.info(f"User {member.name} joined") + msg = AIMessage(member.name, self.config['join-message']) + if self.welcome_channel is not None: + await self.respond(msg, self.welcome_channel) + async def on_message(self, message: Message) -> None: msg = AIMessage(message.author.name, str(message.content).strip()) - response = await self.airesponder.send(msg) - if response.staff is not None: + await self.respond(msg, message.channel) + + async def respond(self, message: AIMessage, channel: TextChannel) -> None: + logging.info(f"handle message {str(message)} for channel {channel.name}") + response = await self.airesponder.send(message) + if response.staff is not None and self.staff_channel is not None: async with self.staff_channel.typing(): await self.staff_channel.send(response.staff) if not response.answer_needed: return - async with message.channel.typing(): + async with channel.typing(): if response.picture is not None: images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")] - await message.channel.send(response.answer, files=images) + await channel.send(response.answer, files=images) else: - await message.channel.send(response.answer) + await channel.send(response.answer) async def close(self): self.observer.stop() diff --git a/tests/test_ai.py b/tests/test_ai.py index 36d6cb4..396feaa 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -38,6 +38,7 @@ You always try to say something positive about the current day and the Fjærkroa async def test_responder1(self) -> None: response = await self.bot.airesponder.send(AIMessage("lala", "who are you?")) + print(response) self.assertAIResponse(response, AIResponse('test', True, None, None, False))