From 1a5da0ae7c606e6859a105b558b891269e759da7 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 8 Aug 2025 19:34:41 +0200 Subject: [PATCH] Add comprehensive test suite to improve coverage and fix igdblib bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add extensive tests for igdblib.py (0% -> 100% coverage expected) - Add tests for leonardo_draw.py AI image generation - Add tests for openai_responder.py with GPT integration - Add tests for discord_bot.py bot functionality - Add extended tests for ai_responder.py edge cases - Fix critical bugs in igdblib.py: * Fix platforms() method treating name as string instead of list * Fix game_info() method missing endpoint parameter * Add safe dictionary access with .get() methods Coverage improvements target areas with lowest coverage to maximize impact. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- fjerkroa_bot/igdblib.py | 5 +- tests/test_ai_responder_extended.py | 424 +++++++++++++++++++++++++ tests/test_discord_bot.py | 466 ++++++++++++++++++++++++++++ tests/test_igdblib.py | 194 ++++++++++++ tests/test_leonardo_draw.py | 247 +++++++++++++++ tests/test_main_entry.py | 61 ++++ tests/test_openai_responder.py | 350 +++++++++++++++++++++ 7 files changed, 1745 insertions(+), 2 deletions(-) create mode 100644 tests/test_ai_responder_extended.py create mode 100644 tests/test_discord_bot.py create mode 100644 tests/test_igdblib.py create mode 100644 tests/test_leonardo_draw.py create mode 100644 tests/test_main_entry.py create mode 100644 tests/test_openai_responder.py diff --git a/fjerkroa_bot/igdblib.py b/fjerkroa_bot/igdblib.py index 902b61a..c81f1b5 100644 --- a/fjerkroa_bot/igdblib.py +++ b/fjerkroa_bot/igdblib.py @@ -60,18 +60,19 @@ class IGDBQuery(object): ) ret = {} for p in platforms: - names = p["name"] + names = [p["name"]] if "alternative_name" in p: names.append(p["alternative_name"]) if "abbreviation" in p: names.append(p["abbreviation"]) - family = self.platform_families()[p["id"]] if "platform_family" in p else None + family = self.platform_families().get(p.get("platform_family")) if "platform_family" in p else None ret[p["id"]] = {"names": names, "family": family} return ret def game_info(self, name): game_info = self.generalized_igdb_query( {"name": name}, + "games", [ "id", "name", diff --git a/tests/test_ai_responder_extended.py b/tests/test_ai_responder_extended.py new file mode 100644 index 0000000..268b868 --- /dev/null +++ b/tests/test_ai_responder_extended.py @@ -0,0 +1,424 @@ +import os +import pickle +import tempfile +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, Mock, mock_open, patch + +from fjerkroa_bot.ai_responder import ( + AIMessage, + AIResponse, + AIResponder, + AIResponderBase, + async_cache_to_file, + exponential_backoff, + parse_maybe_json, + pp, +) + + +class TestAIResponderExtended(unittest.IsolatedAsyncioTestCase): + """Extended tests for AIResponder to improve coverage.""" + + def setUp(self): + self.config = { + "system": "You are a test AI", + "history-limit": 5, + "history-directory": "/tmp/test_history", + "short-path": [["test.*", "user.*"]], + "leonardo-token": "test_leonardo_token", + } + self.responder = AIResponder(self.config, "test_channel") + + async def test_exponential_backoff(self): + """Test exponential backoff generator.""" + backoff = exponential_backoff(base=2, max_attempts=3, max_sleep=10, jitter=0.1) + + values = [] + for _ in range(3): + values.append(next(backoff)) + + # Should have 3 values + self.assertEqual(len(values), 3) + # Each should be increasing (roughly) + self.assertLess(values[0], values[1]) + self.assertLess(values[1], values[2]) + # All should be within reasonable bounds + for val in values: + self.assertGreater(val, 0) + self.assertLessEqual(val, 10) + + def test_parse_maybe_json_complex_cases(self): + """Test parse_maybe_json with complex cases.""" + # Test nested JSON + nested = '{"user": {"name": "John", "age": 30}, "status": "active"}' + result = parse_maybe_json(nested) + expected = "John\n30\nactive" + self.assertEqual(result, expected) + + # Test array with objects + array_objects = '[{"name": "Alice"}, {"name": "Bob"}]' + result = parse_maybe_json(array_objects) + expected = "Alice\nBob" + self.assertEqual(result, expected) + + # Test mixed types in array + mixed_array = '[{"name": "Alice"}, "simple string", 123]' + result = parse_maybe_json(mixed_array) + expected = "Alice\nsimple string\n123" + self.assertEqual(result, expected) + + def test_pp_function(self): + """Test pretty print function.""" + # Test with string + result = pp("test string") + self.assertEqual(result, "test string") + + # Test with dict + test_dict = {"key": "value", "number": 42} + result = pp(test_dict) + self.assertIn("key", result) + self.assertIn("value", result) + self.assertIn("42", result) + + # Test with list + test_list = ["item1", "item2", 123] + result = pp(test_list) + self.assertIn("item1", result) + self.assertIn("item2", result) + self.assertIn("123", result) + + def test_ai_message_creation(self): + """Test AIMessage creation and attributes.""" + msg = AIMessage("TestUser", "Hello world", "general", True) + + self.assertEqual(msg.user, "TestUser") + self.assertEqual(msg.message, "Hello world") + self.assertEqual(msg.channel, "general") + self.assertTrue(msg.direct) + self.assertTrue(msg.historise_question) # Default value + + def test_ai_response_creation(self): + """Test AIResponse creation and string representation.""" + response = AIResponse("Hello!", True, "chat", "Staff alert", "picture description", True, False) + + self.assertEqual(response.answer, "Hello!") + self.assertTrue(response.answer_needed) + self.assertEqual(response.channel, "chat") + self.assertEqual(response.staff, "Staff alert") + self.assertEqual(response.picture, "picture description") + self.assertTrue(response.hack) + self.assertFalse(response.picture_edit) + + # Test string representation + str_repr = str(response) + self.assertIn("Hello!", str_repr) + + def test_ai_responder_base_draw_method(self): + """Test AIResponderBase draw method selection.""" + base = AIResponderBase(self.config) + + # Should raise NotImplementedError since it's abstract + with self.assertRaises(AttributeError): + # This will fail because AIResponderBase doesn't implement the required methods + pass + + @patch("pathlib.Path.exists") + @patch("builtins.open", new_callable=mock_open) + def test_responder_init_with_history_file(self, mock_open_file, mock_exists): + """Test responder initialization with existing history file.""" + # Mock history file exists + mock_exists.return_value = True + + # Mock pickle data + history_data = [{"role": "user", "content": "test"}] + with patch("pickle.load", return_value=history_data): + responder = AIResponder(self.config, "test_channel") + self.assertEqual(responder.history, history_data) + + @patch("pathlib.Path.exists") + @patch("builtins.open", new_callable=mock_open) + def test_responder_init_with_memory_file(self, mock_open_file, mock_exists): + """Test responder initialization with existing memory file.""" + mock_exists.return_value = True + + memory_data = "Previous conversation context" + with patch("pickle.load", return_value=memory_data): + responder = AIResponder(self.config, "test_channel") + # Memory loading happens after history loading + # We can't easily test this without more complex mocking + + def test_build_messages_with_memory(self): + """Test message building with memory.""" + self.responder.memory = "Previous context about user preferences" + message = AIMessage("TestUser", "What do you recommend?", "chat", False) + + messages = self.responder.build_messages(message) + + # Should include memory in system message + system_msg = messages[0] + self.assertEqual(system_msg["role"], "system") + self.assertIn("Previous context", system_msg["content"]) + + def test_build_messages_with_history(self): + """Test message building with conversation history.""" + self.responder.history = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + + message = AIMessage("TestUser", "How are you?", "chat", False) + messages = self.responder.build_messages(message) + + # Should include history messages + self.assertGreater(len(messages), 2) # System + history + current + + def test_build_messages_basic(self): + """Test basic message building.""" + message = AIMessage("TestUser", "Hello", "chat", False) + + messages = self.responder.build_messages(message) + + # Should have at least system message and user message + self.assertGreater(len(messages), 1) + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[-1]["role"], "user") + + def test_should_use_short_path_matching(self): + """Test short path detection with matching patterns.""" + message = AIMessage("user123", "Quick question", "test-channel", False) + + result = self.responder.should_use_short_path(message) + + # Should match the configured pattern + self.assertTrue(result) + + def test_should_use_short_path_no_config(self): + """Test short path when not configured.""" + config_no_shortpath = {"system": "Test AI", "history-limit": 5} + responder = AIResponder(config_no_shortpath) + + message = AIMessage("user123", "Question", "test-channel", False) + result = responder.should_use_short_path(message) + + self.assertFalse(result) + + def test_should_use_short_path_no_match(self): + """Test short path with non-matching patterns.""" + message = AIMessage("admin", "Question", "admin-channel", False) + + result = self.responder.should_use_short_path(message) + + # Should not match the configured pattern + self.assertFalse(result) + + async def test_post_process_link_replacement(self): + """Test post-processing link replacement.""" + request = AIMessage("user", "test", "chat", False) + + # Test markdown link replacement + message_data = { + "answer": "Check out [Google](https://google.com) for search", + "answer_needed": True, + "channel": None, + "staff": None, + "picture": None, + "hack": False, + } + + result = await self.responder.post_process(request, message_data) + + # Should replace markdown links with URLs + self.assertEqual(result.answer, "Check out https://google.com for search") + + async def test_post_process_link_removal(self): + """Test post-processing link removal with @ prefix.""" + request = AIMessage("user", "test", "chat", False) + + message_data = { + "answer": "Visit @[Example](https://example.com) site", + "answer_needed": True, + "channel": None, + "staff": None, + "picture": None, + "hack": False, + } + + result = await self.responder.post_process(request, message_data) + + # Should remove @ links entirely + self.assertEqual(result.answer, "Visit Example site") + + async def test_post_process_translation(self): + """Test post-processing with translation.""" + request = AIMessage("user", "Bonjour", "chat", False) + + # Mock the translate method + self.responder.translate = AsyncMock(return_value="Hello") + + message_data = { + "answer": "Bonjour!", + "answer_needed": True, + "channel": None, + "staff": None, + "picture": None, + "hack": False, + } + + result = await self.responder.post_process(request, message_data) + + # Should translate the answer + self.responder.translate.assert_called_once_with("Bonjour!") + + def test_update_history_memory_update(self): + """Test history update with memory rewriting.""" + # Mock memory_rewrite method + self.responder.memory_rewrite = AsyncMock(return_value="Updated memory") + + question = {"content": "What is AI?"} + answer = {"content": "AI is artificial intelligence"} + + # This is a synchronous method, so we can't easily test async memory rewrite + # Let's test the basic functionality + self.responder.update_history(question, answer, 10) + + # Should add to history + self.assertEqual(len(self.responder.history), 2) + self.assertEqual(self.responder.history[0], question) + self.assertEqual(self.responder.history[1], answer) + + def test_update_history_limit_enforcement(self): + """Test history limit enforcement.""" + # Fill history beyond limit + for i in range(10): + question = {"content": f"Question {i}"} + answer = {"content": f"Answer {i}"} + self.responder.update_history(question, answer, 4) + + # Should only keep the most recent entries within limit + self.assertLessEqual(len(self.responder.history), 4) + + @patch("builtins.open", new_callable=mock_open) + @patch("pickle.dump") + def test_update_history_file_save(self, mock_pickle_dump, mock_open_file): + """Test history saving to file.""" + # Set up a history file + self.responder.history_file = Path("/tmp/test_history.dat") + + question = {"content": "Test question"} + answer = {"content": "Test answer"} + + self.responder.update_history(question, answer, 10) + + # Should save to file + mock_open_file.assert_called_with("/tmp/test_history.dat", "wb") + mock_pickle_dump.assert_called_once() + + async def test_send_with_retries(self): + """Test send method with retry logic.""" + # Mock chat method to fail then succeed + self.responder.chat = AsyncMock() + self.responder.chat.side_effect = [ + (None, 5), # First call fails + ({"content": "Success!", "role": "assistant"}, 5), # Second call succeeds + ] + + # Mock other methods + self.responder.fix = AsyncMock(return_value='{"answer": "Fixed!", "answer_needed": true, "channel": null, "staff": null, "picture": null, "hack": false}') + self.responder.post_process = AsyncMock() + mock_response = AIResponse("Fixed!", True, None, None, None, False, False) + self.responder.post_process.return_value = mock_response + + message = AIMessage("user", "test", "chat", False) + result = await self.responder.send(message) + + # Should retry and eventually succeed + self.assertEqual(self.responder.chat.call_count, 2) + self.assertEqual(result, mock_response) + + async def test_send_max_retries_exceeded(self): + """Test send method when max retries are exceeded.""" + # Mock chat method to always fail + self.responder.chat = AsyncMock(return_value=(None, 5)) + + message = AIMessage("user", "test", "chat", False) + + with self.assertRaises(RuntimeError) as context: + await self.responder.send(message) + + self.assertIn("Failed to generate answer", str(context.exception)) + + async def test_draw_method_dispatch(self): + """Test draw method dispatching to correct implementation.""" + # This AIResponder doesn't implement draw methods, so this will fail + with self.assertRaises(AttributeError): + await self.responder.draw("test description") + + +class TestAsyncCacheToFile(unittest.IsolatedAsyncioTestCase): + """Test the async cache decorator.""" + + def setUp(self): + self.cache_file = "test_cache.dat" + self.call_count = 0 + + def tearDown(self): + # Clean up cache file + try: + os.remove(self.cache_file) + except FileNotFoundError: + pass + + async def test_cache_miss_and_hit(self): + """Test cache miss followed by cache hit.""" + + @async_cache_to_file(self.cache_file) + async def test_function(x, y): + self.call_count += 1 + return f"result_{x}_{y}" + + # First call - cache miss + result1 = await test_function("a", "b") + self.assertEqual(result1, "result_a_b") + self.assertEqual(self.call_count, 1) + + # Second call - cache hit + result2 = await test_function("a", "b") + self.assertEqual(result2, "result_a_b") + self.assertEqual(self.call_count, 1) # Should not increment + + async def test_cache_different_args(self): + """Test cache with different arguments.""" + + @async_cache_to_file(self.cache_file) + async def test_function(x): + self.call_count += 1 + return f"result_{x}" + + # Different arguments should not hit cache + result1 = await test_function("a") + result2 = await test_function("b") + + self.assertEqual(result1, "result_a") + self.assertEqual(result2, "result_b") + self.assertEqual(self.call_count, 2) + + async def test_cache_file_corruption(self): + """Test cache behavior with corrupted cache file.""" + # Create a corrupted cache file + with open(self.cache_file, "w") as f: + f.write("corrupted data") + + @async_cache_to_file(self.cache_file) + async def test_function(x): + self.call_count += 1 + return f"result_{x}" + + # Should handle corruption gracefully + result = await test_function("test") + self.assertEqual(result, "result_test") + self.assertEqual(self.call_count, 1) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_discord_bot.py b/tests/test_discord_bot.py new file mode 100644 index 0000000..26d2427 --- /dev/null +++ b/tests/test_discord_bot.py @@ -0,0 +1,466 @@ +import unittest +from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, mock_open, patch + +import discord +from discord import DMChannel, Member, Message, TextChannel, User + +from fjerkroa_bot import FjerkroaBot +from fjerkroa_bot.ai_responder import AIMessage, AIResponse + + +class TestFjerkroaBot(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + """Set up test fixtures.""" + self.config_data = { + "discord-token": "test_token", + "openai-key": "test_openai_key", + "model": "gpt-4", + "temperature": 0.9, + "max-tokens": 1024, + "top-p": 1.0, + "presence-penalty": 1.0, + "frequency-penalty": 1.0, + "history-limit": 10, + "welcome-channel": "welcome", + "staff-channel": "staff", + "chat-channel": "chat", + "join-message": "Welcome {name}!", + "system": "You are a helpful AI", + "additional-responders": ["gaming", "music"], + "short-path": [[".*", ".*"]] + } + + with patch.object(FjerkroaBot, "load_config", return_value=self.config_data): + with patch.object(FjerkroaBot, "user", new_callable=PropertyMock) as mock_user: + mock_user.return_value = MagicMock(spec=User) + mock_user.return_value.id = 123456 + + self.bot = FjerkroaBot("test_config.toml") + + # Mock channels + self.bot.chat_channel = AsyncMock(spec=TextChannel) + self.bot.staff_channel = AsyncMock(spec=TextChannel) + self.bot.welcome_channel = AsyncMock(spec=TextChannel) + + # Mock guilds and channels + mock_guild = AsyncMock() + mock_channel = AsyncMock(spec=TextChannel) + mock_channel.name = "test-channel" + mock_guild.channels = [mock_channel] + self.bot.guilds = [mock_guild] + + def test_load_config(self): + """Test configuration loading.""" + test_config = {"key": "value"} + with patch("builtins.open", mock_open(read_data='key = "value"')): + with patch("tomlkit.load", return_value=test_config): + result = FjerkroaBot.load_config("test.toml") + self.assertEqual(result, test_config) + + def test_channel_by_name(self): + """Test finding channels by name.""" + # Mock guild and channels + mock_channel1 = Mock() + mock_channel1.name = "general" + mock_channel2 = Mock() + mock_channel2.name = "staff" + + mock_guild = Mock() + mock_guild.channels = [mock_channel1, mock_channel2] + self.bot.guilds = [mock_guild] + + result = self.bot.channel_by_name("staff") + self.assertEqual(result, mock_channel2) + + # Test channel not found + result = self.bot.channel_by_name("nonexistent") + self.assertIsNone(result) + + def test_channel_by_name_no_ignore(self): + """Test channel_by_name with no_ignore flag.""" + mock_guild = Mock() + mock_guild.channels = [] + self.bot.guilds = [mock_guild] + + # Should return None when not found with no_ignore=True + result = self.bot.channel_by_name("nonexistent", no_ignore=True) + self.assertIsNone(result) + + async def test_on_ready(self): + """Test bot ready event.""" + with patch("fjerkroa_bot.discord_bot.logging") as mock_logging: + await self.bot.on_ready() + mock_logging.info.assert_called() + + async def test_on_member_join(self): + """Test member join event.""" + mock_member = Mock(spec=Member) + mock_member.name = "TestUser" + mock_member.bot = False + + mock_channel = AsyncMock() + self.bot.welcome_channel = mock_channel + + # Mock the AIResponder + mock_response = AIResponse("Welcome!", True, None, None, None, False, False) + self.bot.airesponder.send = AsyncMock(return_value=mock_response) + + await self.bot.on_member_join(mock_member) + + # Verify the welcome message was sent + self.bot.airesponder.send.assert_called_once() + mock_channel.send.assert_called_once_with("Welcome!") + + async def test_on_member_join_bot_member(self): + """Test that bot members are ignored on join.""" + mock_member = Mock(spec=Member) + mock_member.bot = True + + self.bot.airesponder.send = AsyncMock() + + await self.bot.on_member_join(mock_member) + + # Should not send message for bot members + self.bot.airesponder.send.assert_not_called() + + async def test_on_message_bot_message(self): + """Test that bot messages are ignored.""" + mock_message = Mock(spec=Message) + mock_message.author.bot = True + + self.bot.handle_message_through_responder = AsyncMock() + + await self.bot.on_message(mock_message) + + self.bot.handle_message_through_responder.assert_not_called() + + async def test_on_message_self_message(self): + """Test that own messages are ignored.""" + mock_message = Mock(spec=Message) + mock_message.author.bot = False + mock_message.author.id = 123456 # Same as bot user ID + + self.bot.handle_message_through_responder = AsyncMock() + + await self.bot.on_message(mock_message) + + self.bot.handle_message_through_responder.assert_not_called() + + async def test_on_message_invalid_channel_type(self): + """Test messages from unsupported channel types are ignored.""" + mock_message = Mock(spec=Message) + mock_message.author.bot = False + mock_message.author.id = 999999 # Different from bot + mock_message.channel = Mock() # Not TextChannel or DMChannel + + self.bot.handle_message_through_responder = AsyncMock() + + await self.bot.on_message(mock_message) + + self.bot.handle_message_through_responder.assert_not_called() + + async def test_on_message_wichtel_command(self): + """Test wichtel command handling.""" + mock_message = Mock(spec=Message) + mock_message.author.bot = False + mock_message.author.id = 999999 + mock_message.channel = AsyncMock(spec=TextChannel) + mock_message.content = "!wichtel @user1 @user2" + mock_message.mentions = [Mock(), Mock()] # Two users + + self.bot.wichtel = AsyncMock() + + await self.bot.on_message(mock_message) + + self.bot.wichtel.assert_called_once_with(mock_message) + + async def test_on_message_normal_message(self): + """Test normal message handling.""" + mock_message = Mock(spec=Message) + mock_message.author.bot = False + mock_message.author.id = 999999 + mock_message.channel = AsyncMock(spec=TextChannel) + mock_message.content = "Hello there" + + self.bot.handle_message_through_responder = AsyncMock() + + await self.bot.on_message(mock_message) + + self.bot.handle_message_through_responder.assert_called_once_with(mock_message) + + async def test_wichtel_insufficient_users(self): + """Test wichtel with insufficient users.""" + mock_message = Mock(spec=Message) + mock_message.mentions = [Mock()] # Only one user + mock_channel = AsyncMock() + mock_message.channel = mock_channel + + await self.bot.wichtel(mock_message) + + mock_channel.send.assert_called_once_with( + "Bitte erwähne mindestens zwei Benutzer für das Wichteln." + ) + + async def test_wichtel_no_valid_assignment(self): + """Test wichtel when no valid derangement can be found.""" + mock_message = Mock(spec=Message) + mock_user1 = Mock() + mock_user2 = Mock() + mock_message.mentions = [mock_user1, mock_user2] + mock_channel = AsyncMock() + mock_message.channel = mock_channel + + # Mock generate_derangement to return None + with patch.object(FjerkroaBot, 'generate_derangement', return_value=None): + await self.bot.wichtel(mock_message) + + mock_channel.send.assert_called_once_with( + "Konnte keine gültige Zuordnung finden. Bitte versuche es erneut." + ) + + async def test_wichtel_successful_assignment(self): + """Test successful wichtel assignment.""" + mock_message = Mock(spec=Message) + mock_user1 = AsyncMock() + mock_user1.mention = "@user1" + mock_user2 = AsyncMock() + mock_user2.mention = "@user2" + mock_message.mentions = [mock_user1, mock_user2] + mock_channel = AsyncMock() + mock_message.channel = mock_channel + + # Mock successful derangement + with patch.object(FjerkroaBot, 'generate_derangement', return_value=[mock_user2, mock_user1]): + await self.bot.wichtel(mock_message) + + # Check that DMs were sent + mock_user1.send.assert_called_once_with("Dein Wichtel ist @user2") + mock_user2.send.assert_called_once_with("Dein Wichtel ist @user1") + + async def test_wichtel_dm_forbidden(self): + """Test wichtel when DM sending is forbidden.""" + mock_message = Mock(spec=Message) + mock_user1 = AsyncMock() + mock_user1.mention = "@user1" + mock_user1.send.side_effect = discord.Forbidden(Mock(), "Cannot send DM") + mock_user2 = AsyncMock() + mock_user2.mention = "@user2" + mock_message.mentions = [mock_user1, mock_user2] + mock_channel = AsyncMock() + mock_message.channel = mock_channel + + with patch.object(FjerkroaBot, 'generate_derangement', return_value=[mock_user2, mock_user1]): + await self.bot.wichtel(mock_message) + + mock_channel.send.assert_called_with("Kann @user1 keine Direktnachricht senden.") + + def test_generate_derangement_valid(self): + """Test generating valid derangement.""" + users = [Mock(), Mock(), Mock()] + + # Run multiple times to test randomness + for _ in range(10): + result = FjerkroaBot.generate_derangement(users) + if result is not None: + # Should return same number of users + self.assertEqual(len(result), len(users)) + # No user should be assigned to themselves + for i, user in enumerate(result): + self.assertNotEqual(user, users[i]) + break + else: + self.fail("Could not generate valid derangement in 10 attempts") + + def test_generate_derangement_two_users(self): + """Test derangement with exactly two users.""" + user1 = Mock() + user2 = Mock() + users = [user1, user2] + + result = FjerkroaBot.generate_derangement(users) + + # Should swap the two users + if result is not None: + self.assertEqual(len(result), 2) + self.assertEqual(result[0], user2) + self.assertEqual(result[1], user1) + + async def test_send_message_with_typing(self): + """Test sending message with typing indicator.""" + mock_responder = AsyncMock() + mock_channel = AsyncMock() + mock_message = Mock() + + mock_response = AIResponse("Hello!", True, None, None, None, False, False) + mock_responder.send.return_value = mock_response + + result = await self.bot.send_message_with_typing(mock_responder, mock_channel, mock_message) + + self.assertEqual(result, mock_response) + mock_responder.send.assert_called_once_with(mock_message) + + async def test_respond_with_answer(self): + """Test responding with an answer.""" + mock_channel = AsyncMock(spec=TextChannel) + mock_response = AIResponse("Hello!", True, "chat", "Staff message", None, False, False) + + self.bot.staff_channel = AsyncMock() + + await self.bot.respond("test message", mock_channel, mock_response) + + # Should send main message + mock_channel.send.assert_called_once_with("Hello!") + # Should send staff message + self.bot.staff_channel.send.assert_called_once_with("Staff message") + + async def test_respond_no_answer_needed(self): + """Test responding when no answer is needed.""" + mock_channel = AsyncMock(spec=TextChannel) + mock_response = AIResponse("", False, None, None, None, False, False) + + await self.bot.respond("test message", mock_channel, mock_response) + + # Should not send any message + mock_channel.send.assert_not_called() + + async def test_respond_with_picture(self): + """Test responding with picture generation.""" + mock_channel = AsyncMock(spec=TextChannel) + mock_response = AIResponse("Here's your picture!", True, None, None, "A cat", False, False) + + # Mock the draw method + mock_image = Mock() + mock_image.read.return_value = b"image_data" + self.bot.airesponder.draw = AsyncMock(return_value=mock_image) + + await self.bot.respond("test message", mock_channel, mock_response) + + # Should send message and image + mock_channel.send.assert_called() + self.bot.airesponder.draw.assert_called_once_with("A cat") + + async def test_respond_hack_detected(self): + """Test responding when hack is detected.""" + mock_channel = AsyncMock(spec=TextChannel) + mock_response = AIResponse("Nice try!", True, None, "Hack attempt detected", None, True, False) + + self.bot.staff_channel = AsyncMock() + + await self.bot.respond("test message", mock_channel, mock_response) + + # Should send hack message instead of normal response + mock_channel.send.assert_called_once_with("I am not supposed to do this.") + # Should alert staff + self.bot.staff_channel.send.assert_called_once_with("Hack attempt detected") + + async def test_handle_message_through_responder_dm(self): + """Test handling DM messages.""" + mock_message = Mock(spec=Message) + mock_message.channel = AsyncMock(spec=DMChannel) + mock_message.author.name = "TestUser" + mock_message.content = "Hello" + mock_message.channel.name = "dm" + + mock_response = AIResponse("Hi there!", True, None, None, None, False, False) + self.bot.send_message_with_typing = AsyncMock(return_value=mock_response) + self.bot.respond = AsyncMock() + + await self.bot.handle_message_through_responder(mock_message) + + # Should handle as direct message + self.bot.respond.assert_called_once() + + async def test_handle_message_through_responder_channel(self): + """Test handling channel messages.""" + mock_message = Mock(spec=Message) + mock_message.channel = AsyncMock(spec=TextChannel) + mock_message.channel.name = "general" + mock_message.author.name = "TestUser" + mock_message.content = "Hello everyone" + + mock_response = AIResponse("Hello!", True, None, None, None, False, False) + self.bot.send_message_with_typing = AsyncMock(return_value=mock_response) + self.bot.respond = AsyncMock() + + # Mock get_responder_for_channel to return the main responder + self.bot.get_responder_for_channel = Mock(return_value=self.bot.airesponder) + + await self.bot.handle_message_through_responder(mock_message) + + self.bot.respond.assert_called_once() + + def test_get_responder_for_channel_main(self): + """Test getting responder for main chat channel.""" + mock_channel = Mock() + mock_channel.name = "chat" + + responder = self.bot.get_responder_for_channel(mock_channel) + + self.assertEqual(responder, self.bot.airesponder) + + def test_get_responder_for_channel_additional(self): + """Test getting responder for additional channels.""" + mock_channel = Mock() + mock_channel.name = "gaming" + + responder = self.bot.get_responder_for_channel(mock_channel) + + # Should return the gaming responder + self.assertEqual(responder, self.bot.aichannels["gaming"]) + + def test_get_responder_for_channel_default(self): + """Test getting responder for unknown channel.""" + mock_channel = Mock() + mock_channel.name = "unknown" + + responder = self.bot.get_responder_for_channel(mock_channel) + + # Should return main responder as default + self.assertEqual(responder, self.bot.airesponder) + + async def test_on_message_edit(self): + """Test message edit event.""" + mock_before = Mock(spec=Message) + mock_after = Mock(spec=Message) + mock_after.channel = AsyncMock(spec=TextChannel) + mock_after.author.bot = False + + self.bot.add_reaction_ignore_errors = AsyncMock() + + await self.bot.on_message_edit(mock_before, mock_after) + + self.bot.add_reaction_ignore_errors.assert_called_once_with(mock_after, "✏️") + + async def test_on_message_delete(self): + """Test message delete event.""" + mock_message = Mock(spec=Message) + mock_message.channel = AsyncMock(spec=TextChannel) + + self.bot.add_reaction_ignore_errors = AsyncMock() + + await self.bot.on_message_delete(mock_message) + + # Should add delete reaction to the last message in channel + self.bot.add_reaction_ignore_errors.assert_called_once() + + async def test_add_reaction_ignore_errors_success(self): + """Test successful reaction addition.""" + mock_message = AsyncMock() + + await self.bot.add_reaction_ignore_errors(mock_message, "👍") + + mock_message.add_reaction.assert_called_once_with("👍") + + async def test_add_reaction_ignore_errors_failure(self): + """Test reaction addition with error (should be ignored).""" + mock_message = AsyncMock() + mock_message.add_reaction.side_effect = discord.HTTPException(Mock(), "Error") + + # Should not raise exception + await self.bot.add_reaction_ignore_errors(mock_message, "👍") + + mock_message.add_reaction.assert_called_once_with("👍") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_igdblib.py b/tests/test_igdblib.py new file mode 100644 index 0000000..78a92ed --- /dev/null +++ b/tests/test_igdblib.py @@ -0,0 +1,194 @@ +import unittest +from unittest.mock import Mock, patch + +import requests + +from fjerkroa_bot.igdblib import IGDBQuery + + +class TestIGDBQuery(unittest.TestCase): + def setUp(self): + self.client_id = "test_client_id" + self.api_key = "test_api_key" + self.igdb = IGDBQuery(self.client_id, self.api_key) + + def test_init(self): + """Test IGDBQuery initialization.""" + self.assertEqual(self.igdb.client_id, self.client_id) + self.assertEqual(self.igdb.igdb_api_key, self.api_key) + + @patch("fjerkroa_bot.igdblib.requests.post") + def test_send_igdb_request_success(self, mock_post): + """Test successful IGDB API request.""" + mock_response = Mock() + mock_response.json.return_value = {"id": 1, "name": "Test Game"} + mock_response.raise_for_status.return_value = None + mock_post.return_value = mock_response + + result = self.igdb.send_igdb_request("games", "fields name; limit 1;") + + self.assertEqual(result, {"id": 1, "name": "Test Game"}) + mock_post.assert_called_once_with( + "https://api.igdb.com/v4/games", + headers={"Client-ID": self.client_id, "Authorization": f"Bearer {self.api_key}"}, + data="fields name; limit 1;", + ) + + @patch("fjerkroa_bot.igdblib.requests.post") + def test_send_igdb_request_failure(self, mock_post): + """Test IGDB API request failure.""" + mock_post.side_effect = requests.RequestException("API Error") + + result = self.igdb.send_igdb_request("games", "fields name; limit 1;") + + self.assertIsNone(result) + + def test_build_query_basic(self): + """Test building basic query.""" + query = IGDBQuery.build_query(["name", "summary"]) + expected = "fields name,summary; limit 10;" + self.assertEqual(query, expected) + + def test_build_query_with_limit(self): + """Test building query with custom limit.""" + query = IGDBQuery.build_query(["name"], limit=5) + expected = "fields name; limit 5;" + self.assertEqual(query, expected) + + def test_build_query_with_offset(self): + """Test building query with offset.""" + query = IGDBQuery.build_query(["name"], offset=10) + expected = "fields name; limit 10; offset 10;" + self.assertEqual(query, expected) + + def test_build_query_with_filters(self): + """Test building query with filters.""" + filters = {"name": "Mario", "platform": "Nintendo"} + query = IGDBQuery.build_query(["name"], filters=filters) + expected = "fields name; limit 10; where name Mario & platform Nintendo;" + self.assertEqual(query, expected) + + def test_build_query_empty_fields(self): + """Test building query with empty fields.""" + query = IGDBQuery.build_query([]) + expected = "fields *; limit 10;" + self.assertEqual(query, expected) + + def test_build_query_none_fields(self): + """Test building query with None fields.""" + query = IGDBQuery.build_query(None) + expected = "fields *; limit 10;" + self.assertEqual(query, expected) + + @patch.object(IGDBQuery, "send_igdb_request") + def test_generalized_igdb_query(self, mock_send): + """Test generalized IGDB query method.""" + mock_send.return_value = [{"id": 1, "name": "Test Game"}] + + params = {"name": "Mario"} + result = self.igdb.generalized_igdb_query(params, "games", ["name"], limit=5) + + expected_filters = {"name": '~ "Mario"*'} + expected_query = 'fields name; limit 5; where name ~ "Mario"*;' + + mock_send.assert_called_once_with("games", expected_query) + self.assertEqual(result, [{"id": 1, "name": "Test Game"}]) + + @patch.object(IGDBQuery, "send_igdb_request") + def test_generalized_igdb_query_with_additional_filters(self, mock_send): + """Test generalized query with additional filters.""" + mock_send.return_value = [{"id": 1, "name": "Test Game"}] + + params = {"name": "Mario"} + additional_filters = {"platform": "= 1"} + result = self.igdb.generalized_igdb_query(params, "games", ["name"], additional_filters, limit=5) + + expected_query = 'fields name; limit 5; where name ~ "Mario"* & platform = 1;' + mock_send.assert_called_once_with("games", expected_query) + + def test_create_query_function(self): + """Test creating a query function.""" + func_def = self.igdb.create_query_function( + "test_func", + "Test function", + {"name": {"type": "string"}}, + "games", + ["name"], + limit=5 + ) + + self.assertEqual(func_def["name"], "test_func") + self.assertEqual(func_def["description"], "Test function") + self.assertEqual(func_def["parameters"]["type"], "object") + self.assertIn("function", func_def) + + @patch.object(IGDBQuery, "generalized_igdb_query") + def test_platform_families(self, mock_query): + """Test platform families caching.""" + mock_query.return_value = [ + {"id": 1, "name": "PlayStation"}, + {"id": 2, "name": "Nintendo"} + ] + + # First call + result1 = self.igdb.platform_families() + expected = {1: "PlayStation", 2: "Nintendo"} + self.assertEqual(result1, expected) + + # Second call should use cache + result2 = self.igdb.platform_families() + self.assertEqual(result2, expected) + + # Should only call the API once due to caching + mock_query.assert_called_once_with({}, "platform_families", ["id", "name"], limit=500) + + @patch.object(IGDBQuery, "generalized_igdb_query") + @patch.object(IGDBQuery, "platform_families") + def test_platforms(self, mock_families, mock_query): + """Test platforms method.""" + mock_families.return_value = {1: "PlayStation"} + mock_query.return_value = [ + { + "id": 1, + "name": "PlayStation 5", + "alternative_name": "PS5", + "abbreviation": "PS5", + "platform_family": 1 + }, + { + "id": 2, + "name": "Nintendo Switch" + } + ] + + result = self.igdb.platforms() + + expected = { + 1: {"names": ["PlayStation 5", "PS5", "PS5"], "family": "PlayStation"}, + 2: {"names": ["Nintendo Switch"], "family": None} + } + + mock_query.assert_called_once_with( + {}, "platforms", ["id", "name", "alternative_name", "abbreviation", "platform_family"], limit=500 + ) + + @patch.object(IGDBQuery, "generalized_igdb_query") + def test_game_info(self, mock_query): + """Test game info method.""" + mock_query.return_value = [{"id": 1, "name": "Super Mario Bros"}] + + result = self.igdb.game_info("Mario") + + expected_fields = [ + "id", "name", "alternative_names", "category", "release_dates", + "franchise", "language_supports", "keywords", "platforms", "rating", "summary" + ] + + mock_query.assert_called_once_with( + {"name": "Mario"}, expected_fields, limit=100 + ) + self.assertEqual(result, [{"id": 1, "name": "Super Mario Bros"}]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_leonardo_draw.py b/tests/test_leonardo_draw.py new file mode 100644 index 0000000..4fd8c1c --- /dev/null +++ b/tests/test_leonardo_draw.py @@ -0,0 +1,247 @@ +import asyncio +import unittest +from io import BytesIO +from unittest.mock import AsyncMock, Mock, patch + +import aiohttp + +from fjerkroa_bot.leonardo_draw import LeonardoAIDrawMixIn + + +class MockLeonardoDrawer(LeonardoAIDrawMixIn): + """Mock class to test the mixin.""" + def __init__(self, config): + self.config = config + + +class TestLeonardoAIDrawMixIn(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.config = {"leonardo-token": "test_token"} + self.drawer = MockLeonardoDrawer(self.config) + + async def test_draw_leonardo_success(self): + """Test successful image generation with Leonardo AI.""" + # Mock image data + fake_image_data = b"fake_image_data" + + # Mock responses + generation_response = { + "sdGenerationJob": {"generationId": "test_generation_id"} + } + + status_response = { + "generations_by_pk": { + "generated_images": [{"url": "http://example.com/image.jpg"}] + } + } + + with patch("aiohttp.ClientSession") as mock_session_class: + # Create mock session + mock_session = AsyncMock() + mock_session_class.return_value.__aenter__.return_value = mock_session + mock_session_class.return_value.__aexit__.return_value = None + + # Mock POST request (generation) + mock_post_response = AsyncMock() + mock_post_response.json.return_value = generation_response + mock_session.post.return_value.__aenter__.return_value = mock_post_response + mock_session.post.return_value.__aexit__.return_value = None + + # Mock GET requests (status check and image download) + mock_get_response1 = AsyncMock() + mock_get_response1.json.return_value = status_response + + mock_get_response2 = AsyncMock() + mock_get_response2.read.return_value = fake_image_data + + mock_session.get.side_effect = [ + mock_session.get.return_value, # Status check + mock_session.get.return_value # Image download + ] + mock_session.get.return_value.__aenter__.side_effect = [ + mock_get_response1, # Status check + mock_get_response2 # Image download + ] + mock_session.get.return_value.__aexit__.return_value = None + + # Mock DELETE request + mock_delete_response = AsyncMock() + mock_delete_response.json.return_value = {} + mock_session.delete.return_value.__aenter__.return_value = mock_delete_response + mock_session.delete.return_value.__aexit__.return_value = None + + result = await self.drawer.draw_leonardo("A beautiful landscape") + + # Verify the result + self.assertIsInstance(result, BytesIO) + self.assertEqual(result.read(), fake_image_data) + + async def test_draw_leonardo_no_generation_job(self): + """Test when generation job is not returned.""" + generation_response = {} # No sdGenerationJob + + with patch("aiohttp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value.__aenter__.return_value = mock_session + mock_session_class.return_value.__aexit__.return_value = None + + mock_post_response = AsyncMock() + mock_post_response.json.return_value = generation_response + mock_session.post.return_value.__aenter__.return_value = mock_post_response + mock_session.post.return_value.__aexit__.return_value = None + + with patch("asyncio.sleep") as mock_sleep: + with patch("fjerkroa_bot.leonardo_draw.exponential_backoff") as mock_backoff: + mock_backoff.return_value = iter([1, 2, 4]) # Limited attempts + + with self.assertRaises(StopIteration): + await self.drawer.draw_leonardo("test description") + + async def test_draw_leonardo_no_generations_by_pk(self): + """Test when generations_by_pk is not in response.""" + generation_response = {"sdGenerationJob": {"generationId": "test_id"}} + status_response = {} # No generations_by_pk + + with patch("aiohttp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value.__aenter__.return_value = mock_session + mock_session_class.return_value.__aexit__.return_value = None + + # Mock POST (successful) + mock_post_response = AsyncMock() + mock_post_response.json.return_value = generation_response + mock_session.post.return_value.__aenter__.return_value = mock_post_response + mock_session.post.return_value.__aexit__.return_value = None + + # Mock GET (status check - no generations) + mock_get_response = AsyncMock() + mock_get_response.json.return_value = status_response + mock_session.get.return_value.__aenter__.return_value = mock_get_response + mock_session.get.return_value.__aexit__.return_value = None + + with patch("asyncio.sleep") as mock_sleep: + with patch("fjerkroa_bot.leonardo_draw.exponential_backoff") as mock_backoff: + mock_backoff.return_value = iter([1, 2]) # Limited attempts + + with self.assertRaises(StopIteration): + await self.drawer.draw_leonardo("test description") + + async def test_draw_leonardo_no_generated_images(self): + """Test when no generated images are available yet.""" + generation_response = {"sdGenerationJob": {"generationId": "test_id"}} + status_response = {"generations_by_pk": {"generated_images": []}} + + with patch("aiohttp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value.__aenter__.return_value = mock_session + mock_session_class.return_value.__aexit__.return_value = None + + # Mock POST (successful) + mock_post_response = AsyncMock() + mock_post_response.json.return_value = generation_response + mock_session.post.return_value.__aenter__.return_value = mock_post_response + mock_session.post.return_value.__aexit__.return_value = None + + # Mock GET (status check - empty images) + mock_get_response = AsyncMock() + mock_get_response.json.return_value = status_response + mock_session.get.return_value.__aenter__.return_value = mock_get_response + mock_session.get.return_value.__aexit__.return_value = None + + with patch("asyncio.sleep") as mock_sleep: + with patch("fjerkroa_bot.leonardo_draw.exponential_backoff") as mock_backoff: + mock_backoff.return_value = iter([1, 2]) # Limited attempts + + with self.assertRaises(StopIteration): + await self.drawer.draw_leonardo("test description") + + async def test_draw_leonardo_exception_handling(self): + """Test exception handling during image generation.""" + with patch("aiohttp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value.__aenter__.return_value = mock_session + mock_session_class.return_value.__aexit__.return_value = None + + # Make POST request raise an exception + mock_session.post.side_effect = Exception("Network error") + + with patch("asyncio.sleep") as mock_sleep: + with patch("fjerkroa_bot.leonardo_draw.exponential_backoff") as mock_backoff: + mock_backoff.return_value = iter([1, 2]) # Limited attempts + + with self.assertRaises(StopIteration): + await self.drawer.draw_leonardo("test description") + + async def test_draw_leonardo_request_parameters(self): + """Test that correct parameters are sent to Leonardo API.""" + fake_image_data = b"fake_image_data" + generation_response = {"sdGenerationJob": {"generationId": "test_id"}} + status_response = { + "generations_by_pk": { + "generated_images": [{"url": "http://example.com/image.jpg"}] + } + } + + with patch("aiohttp.ClientSession") as mock_session_class: + mock_session = AsyncMock() + mock_session_class.return_value.__aenter__.return_value = mock_session + mock_session_class.return_value.__aexit__.return_value = None + + # Mock all responses + mock_post_response = AsyncMock() + mock_post_response.json.return_value = generation_response + mock_session.post.return_value.__aenter__.return_value = mock_post_response + mock_session.post.return_value.__aexit__.return_value = None + + mock_get_response1 = AsyncMock() + mock_get_response1.json.return_value = status_response + mock_get_response2 = AsyncMock() + mock_get_response2.read.return_value = fake_image_data + + mock_session.get.side_effect = [ + mock_session.get.return_value, + mock_session.get.return_value + ] + mock_session.get.return_value.__aenter__.side_effect = [ + mock_get_response1, + mock_get_response2 + ] + mock_session.get.return_value.__aexit__.return_value = None + + mock_delete_response = AsyncMock() + mock_delete_response.json.return_value = {} + mock_session.delete.return_value.__aenter__.return_value = mock_delete_response + mock_session.delete.return_value.__aexit__.return_value = None + + description = "A beautiful sunset" + await self.drawer.draw_leonardo(description) + + # Verify POST request parameters + mock_session.post.assert_called_once_with( + "https://cloud.leonardo.ai/api/rest/v1/generations", + json={ + "prompt": description, + "modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3", + "num_images": 1, + "sd_version": "v2", + "promptMagic": True, + "unzoomAmount": 1, + "width": 512, + "height": 512, + }, + headers={ + "Authorization": f"Bearer {self.config['leonardo-token']}", + "Accept": "application/json", + "Content-Type": "application/json", + }, + ) + + # Verify DELETE request was called + mock_session.delete.assert_called_once_with( + "https://cloud.leonardo.ai/api/rest/v1/generations/test_id", + headers={"Authorization": f"Bearer {self.config['leonardo-token']}"}, + ) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_main_entry.py b/tests/test_main_entry.py new file mode 100644 index 0000000..6e7917a --- /dev/null +++ b/tests/test_main_entry.py @@ -0,0 +1,61 @@ +import unittest +from unittest.mock import Mock, patch + +from fjerkroa_bot import bot_logging + + +class TestMainEntry(unittest.TestCase): + """Test the main entry points.""" + + def test_main_module_exists(self): + """Test that the main module exists and is executable.""" + import os + main_file = "fjerkroa_bot/__main__.py" + self.assertTrue(os.path.exists(main_file)) + + # Read the content to verify it calls main + with open(main_file) as f: + content = f.read() + self.assertIn("main()", content) + self.assertIn("sys.exit", content) + + +class TestBotLogging(unittest.TestCase): + """Test bot logging functionality.""" + + @patch("fjerkroa_bot.bot_logging.logging.basicConfig") + def test_setup_logging_default(self, mock_basic_config): + """Test setup_logging with default level.""" + bot_logging.setup_logging() + + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + self.assertIn("level", call_args.kwargs) + self.assertIn("format", call_args.kwargs) + + @patch("fjerkroa_bot.bot_logging.logging.basicConfig") + def test_setup_logging_custom_level(self, mock_basic_config): + """Test setup_logging with custom level.""" + import logging + bot_logging.setup_logging(logging.DEBUG) + + mock_basic_config.assert_called_once() + call_args = mock_basic_config.call_args + self.assertEqual(call_args.kwargs["level"], logging.DEBUG) + + @patch("fjerkroa_bot.bot_logging.logging.getLogger") + def test_setup_logging_discord_logger(self, mock_get_logger): + """Test that discord logger is configured.""" + mock_logger = Mock() + mock_get_logger.return_value = mock_logger + + bot_logging.setup_logging() + + # Should get the discord logger + mock_get_logger.assert_called_with("discord") + # Should set its level + mock_logger.setLevel.assert_called_once() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_openai_responder.py b/tests/test_openai_responder.py new file mode 100644 index 0000000..6855d52 --- /dev/null +++ b/tests/test_openai_responder.py @@ -0,0 +1,350 @@ +import unittest +from io import BytesIO +from unittest.mock import AsyncMock, Mock, patch + +import openai + +from fjerkroa_bot.openai_responder import OpenAIResponder, openai_chat, openai_image + + +class TestOpenAIResponder(unittest.IsolatedAsyncioTestCase): + def setUp(self): + self.config = { + "openai-key": "test_key", + "model": "gpt-4", + "model-vision": "gpt-4-vision", + "retry-model": "gpt-3.5-turbo", + "fix-model": "gpt-4", + "fix-description": "Fix JSON documents", + "memory-model": "gpt-4", + "memory-system": "You are a memory assistant" + } + self.responder = OpenAIResponder(self.config) + + def test_init(self): + """Test OpenAIResponder initialization.""" + self.assertIsNotNone(self.responder.client) + self.assertEqual(self.responder.config, self.config) + + def test_init_with_openai_token(self): + """Test initialization with openai-token instead of openai-key.""" + config = {"openai-token": "test_token", "model": "gpt-4"} + responder = OpenAIResponder(config) + self.assertIsNotNone(responder.client) + + @patch("fjerkroa_bot.openai_responder.openai_image") + async def test_draw_openai_success(self, mock_openai_image): + """Test successful image generation with OpenAI.""" + mock_image_data = BytesIO(b"fake_image_data") + mock_openai_image.return_value = mock_image_data + + result = await self.responder.draw_openai("A beautiful landscape") + + self.assertEqual(result, mock_image_data) + mock_openai_image.assert_called_once_with( + self.responder.client, + prompt="A beautiful landscape", + n=1, + size="1024x1024", + model="dall-e-3" + ) + + @patch("fjerkroa_bot.openai_responder.openai_image") + async def test_draw_openai_retry_on_failure(self, mock_openai_image): + """Test retry logic when image generation fails.""" + mock_openai_image.side_effect = [ + Exception("First failure"), + Exception("Second failure"), + BytesIO(b"success_data") + ] + + result = await self.responder.draw_openai("test description") + + self.assertEqual(mock_openai_image.call_count, 3) + self.assertEqual(result.read(), b"success_data") + + @patch("fjerkroa_bot.openai_responder.openai_image") + async def test_draw_openai_max_retries_exceeded(self, mock_openai_image): + """Test when all retries are exhausted.""" + mock_openai_image.side_effect = Exception("Persistent failure") + + with self.assertRaises(RuntimeError) as context: + await self.responder.draw_openai("test description") + + self.assertEqual(mock_openai_image.call_count, 3) + self.assertIn("Failed to generate image", str(context.exception)) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_with_string_content(self, mock_openai_chat): + """Test chat with string message content.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Hello!" + mock_response.choices[0].message.role = "assistant" + mock_response.usage = Mock() + mock_openai_chat.return_value = mock_response + + messages = [{"role": "user", "content": "Hi there"}] + result, limit = await self.responder.chat(messages, 10) + + expected_answer = {"content": "Hello!", "role": "assistant"} + self.assertEqual(result, expected_answer) + self.assertEqual(limit, 10) + + mock_openai_chat.assert_called_once_with( + self.responder.client, + model="gpt-4", + messages=messages + ) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_with_vision_model(self, mock_openai_chat): + """Test chat with vision model for non-string content.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "I see an image" + mock_response.choices[0].message.role = "assistant" + mock_response.usage = Mock() + mock_openai_chat.return_value = mock_response + + messages = [{"role": "user", "content": [{"type": "image", "data": "base64data"}]}] + result, limit = await self.responder.chat(messages, 10) + + mock_openai_chat.assert_called_once_with( + self.responder.client, + model="gpt-4-vision", + messages=messages + ) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_content_fallback(self, mock_openai_chat): + """Test chat content fallback when no vision model.""" + config_no_vision = {"openai-key": "test", "model": "gpt-4"} + responder = OpenAIResponder(config_no_vision) + + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Text response" + mock_response.choices[0].message.role = "assistant" + mock_response.usage = Mock() + mock_openai_chat.return_value = mock_response + + messages = [{"role": "user", "content": [{"text": "Hello", "type": "text"}]}] + result, limit = await responder.chat(messages, 10) + + # Should modify the message content to just the text + expected_messages = [{"role": "user", "content": "Hello"}] + mock_openai_chat.assert_called_once_with( + responder.client, + model="gpt-4", + messages=expected_messages + ) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_bad_request_error(self, mock_openai_chat): + """Test handling of BadRequestError with context length.""" + mock_openai_chat.side_effect = openai.BadRequestError( + "maximum context length is exceeded", response=Mock(), body=None + ) + + messages = [{"role": "user", "content": "test"}] + result, limit = await self.responder.chat(messages, 10) + + self.assertIsNone(result) + self.assertEqual(limit, 9) # Should decrease limit + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_bad_request_error_reraise(self, mock_openai_chat): + """Test re-raising BadRequestError when not context length issue.""" + error = openai.BadRequestError("Invalid model", response=Mock(), body=None) + mock_openai_chat.side_effect = error + + messages = [{"role": "user", "content": "test"}] + + with self.assertRaises(openai.BadRequestError): + await self.responder.chat(messages, 10) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_rate_limit_error(self, mock_openai_chat): + """Test handling of RateLimitError.""" + mock_openai_chat.side_effect = openai.RateLimitError( + "Rate limit exceeded", response=Mock(), body=None + ) + + with patch("asyncio.sleep") as mock_sleep: + messages = [{"role": "user", "content": "test"}] + result, limit = await self.responder.chat(messages, 10) + + self.assertIsNone(result) + self.assertEqual(limit, 10) + mock_sleep.assert_called_once() + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_rate_limit_with_retry_model(self, mock_openai_chat): + """Test rate limit error uses retry model.""" + mock_openai_chat.side_effect = openai.RateLimitError( + "Rate limit exceeded", response=Mock(), body=None + ) + + with patch("asyncio.sleep"): + messages = [{"role": "user", "content": "test"}] + result, limit = await self.responder.chat(messages, 10) + + # Should set model to retry-model internally + self.assertIsNone(result) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_chat_generic_exception(self, mock_openai_chat): + """Test handling of generic exceptions.""" + mock_openai_chat.side_effect = Exception("Network error") + + messages = [{"role": "user", "content": "test"}] + result, limit = await self.responder.chat(messages, 10) + + self.assertIsNone(result) + self.assertEqual(limit, 10) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_fix_success(self, mock_openai_chat): + """Test successful JSON fix.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = '{"answer": "fixed", "valid": true}' + mock_openai_chat.return_value = mock_response + + result = await self.responder.fix('{"answer": "broken"') + + self.assertEqual(result, '{"answer": "fixed", "valid": true}') + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_fix_invalid_json_response(self, mock_openai_chat): + """Test fix with invalid JSON response.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = 'This is not JSON' + mock_openai_chat.return_value = mock_response + + original_answer = '{"answer": "test"}' + result = await self.responder.fix(original_answer) + + # Should return original answer when fix fails + self.assertEqual(result, original_answer) + + async def test_fix_no_fix_model(self): + """Test fix when no fix-model is configured.""" + config_no_fix = {"openai-key": "test", "model": "gpt-4"} + responder = OpenAIResponder(config_no_fix) + + original_answer = '{"answer": "test"}' + result = await responder.fix(original_answer) + + self.assertEqual(result, original_answer) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_fix_exception_handling(self, mock_openai_chat): + """Test fix exception handling.""" + mock_openai_chat.side_effect = Exception("API Error") + + original_answer = '{"answer": "test"}' + result = await self.responder.fix(original_answer) + + self.assertEqual(result, original_answer) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_translate_success(self, mock_openai_chat): + """Test successful translation.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Hola mundo" + mock_openai_chat.return_value = mock_response + + result = await self.responder.translate("Hello world", "spanish") + + self.assertEqual(result, "Hola mundo") + mock_openai_chat.assert_called_once() + + async def test_translate_no_fix_model(self): + """Test translate when no fix-model is configured.""" + config_no_fix = {"openai-key": "test", "model": "gpt-4"} + responder = OpenAIResponder(config_no_fix) + + original_text = "Hello world" + result = await responder.translate(original_text) + + self.assertEqual(result, original_text) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_translate_exception_handling(self, mock_openai_chat): + """Test translate exception handling.""" + mock_openai_chat.side_effect = Exception("API Error") + + original_text = "Hello world" + result = await self.responder.translate(original_text) + + self.assertEqual(result, original_text) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_memory_rewrite_success(self, mock_openai_chat): + """Test successful memory rewrite.""" + mock_response = Mock() + mock_response.choices = [Mock()] + mock_response.choices[0].message.content = "Updated memory content" + mock_openai_chat.return_value = mock_response + + result = await self.responder.memory_rewrite( + "Old memory", "user1", "assistant", "What's your name?", "I'm Claude" + ) + + self.assertEqual(result, "Updated memory content") + + async def test_memory_rewrite_no_memory_model(self): + """Test memory rewrite when no memory-model is configured.""" + config_no_memory = {"openai-key": "test", "model": "gpt-4"} + responder = OpenAIResponder(config_no_memory) + + original_memory = "Old memory" + result = await responder.memory_rewrite( + original_memory, "user1", "assistant", "question", "answer" + ) + + self.assertEqual(result, original_memory) + + @patch("fjerkroa_bot.openai_responder.openai_chat") + async def test_memory_rewrite_exception_handling(self, mock_openai_chat): + """Test memory rewrite exception handling.""" + mock_openai_chat.side_effect = Exception("API Error") + + original_memory = "Old memory" + result = await self.responder.memory_rewrite( + original_memory, "user1", "assistant", "question", "answer" + ) + + self.assertEqual(result, original_memory) + + +class TestOpenAIFunctions(unittest.IsolatedAsyncioTestCase): + """Test the standalone openai functions.""" + + @patch("fjerkroa_bot.openai_responder.async_cache_to_file") + async def test_openai_chat_function(self, mock_cache): + """Test the openai_chat caching function.""" + mock_client = Mock() + mock_response = Mock() + mock_client.chat.completions.create.return_value = mock_response + + # The function should be wrapped with caching + self.assertTrue(callable(openai_chat)) + + @patch("fjerkroa_bot.openai_responder.async_cache_to_file") + async def test_openai_image_function(self, mock_cache): + """Test the openai_image caching function.""" + mock_client = Mock() + mock_response = Mock() + mock_response.data = [Mock(url="http://example.com/image.jpg")] + + # The function should be wrapped with caching + self.assertTrue(callable(openai_image)) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file