import asyncio
import unittest
from contextlib import contextmanager
from io import BytesIO
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import aiohttp

from fjerkroa_bot.leonardo_draw import LeonardoAIDrawMixIn

GENERATIONS_URL = "https://cloud.leonardo.ai/api/rest/v1/generations"
IMAGE_URL = "http://example.com/image.jpg"


def _async_cm(value):
    """Return a mock usable as ``async with <call>(...) as r`` that yields *value*.

    A ``MagicMock`` (not an ``AsyncMock``) is used so that a call such as
    ``session.post(...)`` returns this object synchronously; ``MagicMock``
    supplies ``__aenter__``/``__aexit__`` as ``AsyncMock`` (Python 3.8+).
    """
    cm = MagicMock()
    cm.__aenter__.return_value = value
    cm.__aexit__.return_value = None  # falsy: do not swallow exceptions
    return cm


def _json_response(payload):
    """Return a mock HTTP response whose awaited ``json()`` yields *payload*."""
    response = AsyncMock()
    response.json.return_value = payload
    return response


def _bytes_response(data):
    """Return a mock HTTP response whose awaited ``read()`` yields *data*."""
    response = AsyncMock()
    response.read.return_value = data
    return response


@contextmanager
def _bounded_backoff(attempts):
    """Make retries instant and finite for the duration of the block.

    ``asyncio.sleep`` is mocked and ``exponential_backoff`` (as looked up by
    ``fjerkroa_bot.leonardo_draw``) returns an iterator over *attempts*.
    Yields the ``asyncio.sleep`` mock.
    """
    with patch("asyncio.sleep") as mock_sleep, patch(
        "fjerkroa_bot.leonardo_draw.exponential_backoff"
    ) as mock_backoff:
        mock_backoff.return_value = iter(attempts)
        yield mock_sleep


class MockLeonardoDrawer(LeonardoAIDrawMixIn):
    """Mock class to test the mixin: only provides the ``config`` attribute."""

    def __init__(self, config):
        self.config = config


class TestLeonardoAIDrawMixIn(unittest.IsolatedAsyncioTestCase):
    """Tests for ``LeonardoAIDrawMixIn.draw_leonardo`` with aiohttp fully mocked."""

    def setUp(self):
        self.config = {"leonardo-token": "test_token"}
        self.drawer = MockLeonardoDrawer(self.config)

    @staticmethod
    def _install_session(mock_session_class):
        """Make the patched ``aiohttp.ClientSession()`` yield a fresh session mock.

        The session must be a ``MagicMock``. With an ``AsyncMock`` session,
        ``session.post(...)``/``session.get(...)`` return coroutines, which
        cannot be used in ``async with`` (as the wiring here assumes the code
        under test does); every attempt would then fail with a TypeError and
        the failure-path tests would pass for the wrong reason.
        """
        session = MagicMock()
        mock_session_class.return_value = _async_cm(session)
        return session

    @staticmethod
    def _install_success_responses(session, image_data):
        """Wire *session* for one full generate / poll / download / delete cycle."""
        session.post.return_value = _async_cm(
            _json_response({"sdGenerationJob": {"generationId": "test_id"}})
        )
        # First GET polls the generation status, second GET downloads the image.
        session.get.side_effect = [
            _async_cm(
                _json_response(
                    {"generations_by_pk": {"generated_images": [{"url": IMAGE_URL}]}}
                )
            ),
            _async_cm(_bytes_response(image_data)),
        ]
        session.delete.return_value = _async_cm(_json_response({}))

    async def _assert_gives_up(self, description, attempts):
        """Assert ``draw_leonardo`` gives up once the backoff iterator runs out.

        Assumes ``draw_leonardo`` calls ``next()`` on the backoff iterator until
        it is exhausted (which is what these tests expect). A StopIteration
        raised inside a native coroutine is converted by Python into
        ``RuntimeError("coroutine raised StopIteration")`` (PEP 479 / PEP 492),
        so the caller sees RuntimeError chained from StopIteration, never a
        bare StopIteration.
        """
        with _bounded_backoff(attempts):
            with self.assertRaises(RuntimeError) as ctx:
                await self.drawer.draw_leonardo(description)
        self.assertIsInstance(ctx.exception.__cause__, StopIteration)

    async def test_draw_leonardo_success(self):
        """Test successful image generation with Leonardo AI."""
        fake_image_data = b"fake_image_data"

        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = self._install_session(mock_session_class)
            self._install_success_responses(mock_session, fake_image_data)
            # Bound the retry loop so a regression fails fast instead of
            # sleeping through the real backoff schedule.
            with _bounded_backoff([1, 2]):
                result = await self.drawer.draw_leonardo("A beautiful landscape")

        self.assertIsInstance(result, BytesIO)
        # getvalue() does not depend on the buffer's current read position.
        self.assertEqual(result.getvalue(), fake_image_data)

    async def test_draw_leonardo_no_generation_job(self):
        """Test when generation job is not returned."""
        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = self._install_session(mock_session_class)
            post_response = _json_response({})  # No sdGenerationJob
            mock_session.post.return_value = _async_cm(post_response)

            await self._assert_gives_up("test description", [1, 2, 4])

        # The POST body must actually have been read, i.e. the failure came
        # from the missing key and not from broken mock wiring.
        post_response.json.assert_awaited()

    async def test_draw_leonardo_no_generations_by_pk(self):
        """Test when generations_by_pk is not in response."""
        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = self._install_session(mock_session_class)
            mock_session.post.return_value = _async_cm(
                _json_response({"sdGenerationJob": {"generationId": "test_id"}})
            )
            status_response = _json_response({})  # No generations_by_pk
            mock_session.get.return_value = _async_cm(status_response)

            await self._assert_gives_up("test description", [1, 2])

        status_response.json.assert_awaited()

    async def test_draw_leonardo_no_generated_images(self):
        """Test when no generated images are available yet."""
        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = self._install_session(mock_session_class)
            mock_session.post.return_value = _async_cm(
                _json_response({"sdGenerationJob": {"generationId": "test_id"}})
            )
            status_response = _json_response(
                {"generations_by_pk": {"generated_images": []}}
            )
            mock_session.get.return_value = _async_cm(status_response)

            await self._assert_gives_up("test description", [1, 2])

        status_response.json.assert_awaited()

    async def test_draw_leonardo_exception_handling(self):
        """Test exception handling during image generation."""
        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = self._install_session(mock_session_class)
            # Make the POST request raise an exception on every attempt.
            mock_session.post.side_effect = Exception("Network error")

            await self._assert_gives_up("test description", [1, 2])

        mock_session.post.assert_called()

    async def test_draw_leonardo_request_parameters(self):
        """Test that correct parameters are sent to Leonardo API."""
        description = "A beautiful sunset"

        with patch("aiohttp.ClientSession") as mock_session_class:
            mock_session = self._install_session(mock_session_class)
            self._install_success_responses(mock_session, b"fake_image_data")
            with _bounded_backoff([1, 2]):
                await self.drawer.draw_leonardo(description)

        # Verify POST request parameters
        mock_session.post.assert_called_once_with(
            GENERATIONS_URL,
            json={
                "prompt": description,
                "modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
                "num_images": 1,
                "sd_version": "v2",
                "promptMagic": True,
                "unzoomAmount": 1,
                "width": 512,
                "height": 512,
            },
            headers={
                "Authorization": f"Bearer {self.config['leonardo-token']}",
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
        )

        # Verify the generation is deleted afterwards.
        mock_session.delete.assert_called_once_with(
            f"{GENERATIONS_URL}/test_id",
            headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
        )


if __name__ == "__main__":
    unittest.main()