diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 479f9be..1c28f25 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -17,9 +17,10 @@ class AIMessageBase(object): class AIMessage(AIMessageBase): - def __init__(self, user: str, message: str) -> None: + def __init__(self, user: str, message: str, channel: str = "chat") -> None: self.user = user self.message = message + self.channel = channel class AIResponse(AIMessageBase): diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index 3c918a0..fba2fcc 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -40,10 +40,12 @@ class FjerkroaBot(commands.Bot): @classmethod def load_config(self, config_file: str = "config.toml"): + logging.info(f"config file {config_file} changed, reloading.") with open(config_file, encoding='utf-8') as file: return toml.load(file) def on_config_file_modified(self, event): + logging.info(f"file {event.src_path} modified") if event.src_path == self.config_file: self.config = self.load_config(self.config_file) self.airesponder.config = self.config @@ -60,19 +62,22 @@ class FjerkroaBot(commands.Bot): async def on_member_join(self, member): logging.info(f"User {member.name} joined") - msg = AIMessage(member.name, self.config['join-message'].replace('{name}', member.name)) if self.welcome_channel is not None: + msg = AIMessage(member.name, self.config['join-message'].replace('{name}', member.name), str(self.welcome_channel.name)) await self.respond(msg, self.welcome_channel) async def on_message(self, message: Message) -> None: if self.user is not None and message.author.id == self.user.id: return - msg = AIMessage(message.author.name, str(message.content).strip()) + message_content = str(message.content).strip() + if len(message_content) < 1: + return + msg = AIMessage(message.author.name, message_content, str(message.channel.name)) await self.respond(msg, message.channel) async def respond(self, message: AIMessage, channel: TextChannel) -> None: try: - channel_name = channel.name + channel_name = str(channel.name) except Exception: channel_name = str(channel.id) logging.info(f"handle message {str(message)} for channel {channel_name}") diff --git a/tests/test_main.py b/tests/test_main.py index 53a4cb0..cfeca1a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -35,7 +35,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase): patch.object(FjerkroaBot, 'user', new_callable=PropertyMock) as mock_user: mock_user.return_value = MagicMock(spec=User) mock_user.return_value.id = 12 - self.bot = FjerkroaBot('config.json') + self.bot = FjerkroaBot('config.toml') self.bot.staff_channel = AsyncMock(spec=TextChannel) self.bot.staff_channel.send = AsyncMock() self.bot.welcome_channel = AsyncMock(spec=TextChannel) @@ -58,7 +58,7 @@ class TestFunctionality(TestBotBase): def test_load_config(self) -> None: with patch('builtins.open', mock_open(read_data=toml.dumps(self.config_data))): - result = FjerkroaBot.load_config('config.json') + result = FjerkroaBot.load_config('config.toml') self.assertEqual(result, self.config_data) async def test_on_message_event(self) -> None: