diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index ffd0aa3..b6c5c5b 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -3,12 +3,12 @@ import argparse import toml import discord import logging -from discord import Message, TextChannel +from discord import Message, TextChannel, DMChannel from discord.ext import commands from watchdog.observers import Observer from watchdog.events import FileSystemEventHandler from .ai_responder import AIResponder, AIMessage -from typing import Optional +from typing import Optional, Union class ConfigFileHandler(FileSystemEventHandler): @@ -68,7 +68,7 @@ class FjerkroaBot(commands.Bot): async def on_message(self, message: Message) -> None: if self.user is not None and message.author.id == self.user.id: return - if not isinstance(message.channel, TextChannel): + if not isinstance(message.channel, (TextChannel, DMChannel)): return message_content = str(message.content).strip() if len(message_content) < 1: @@ -82,8 +82,8 @@ class FjerkroaBot(commands.Bot): def channel_by_name(self, channel_name: Optional[str], - fallback_channel: Optional[TextChannel] = None - ) -> Optional[TextChannel]: + fallback_channel: Optional[Union[TextChannel, DMChannel]] = None + ) -> Optional[Union[TextChannel, DMChannel]]: if channel_name is None: return fallback_channel for guild in self.guilds: @@ -92,11 +92,16 @@ class FjerkroaBot(commands.Bot): return channel return fallback_channel - async def respond(self, message: AIMessage, channel: TextChannel) -> None: - try: + async def respond(self, + message: AIMessage, + channel: Union[TextChannel, DMChannel]) -> None: + if isinstance(channel, DMChannel): + if channel.recipient is not None: + channel_name = str(channel.recipient.name) + else: + channel_name = str(channel.id) + else: channel_name = str(channel.name) - except Exception: - channel_name = str(channel.id) if channel_name in self.config.get('ignore-channels', []) and not message.direct: logging.info(f"ignore message {repr(message)} for channel {channel_name}") return