diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index c911960..a7aab7a 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -1,5 +1,8 @@ import json import openai +import aiohttp +import logging +from io import BytesIO from typing import Optional, List, Dict, Any, Tuple @@ -43,6 +46,16 @@ class AIResponder(object): messages.append(msg) return messages, history + async def draw(self, description: str) -> BytesIO: + while True: + try: + response = await openai.Image.acreate(prompt=description, n=1, size="512x512") + async with aiohttp.ClientSession() as session: + async with session.get(response['data'][0]['url']) as image: + return BytesIO(await image.read()) + except Exception as err: + logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}") + async def send(self, message: AIMessage) -> AIResponse: limit = self.config["history-limit"] while True: @@ -54,13 +67,20 @@ class AIResponder(object): top_p=self.config["top-p"], presence_penalty=self.config["presence-penalty"], frequency_penalty=self.config["frequency-penalty"]) + answer = result['choices'][0]['message'] + response = json.loads(answer['content']) except openai.error.InvalidRequestError as err: - if 'maximum context length is' in str(err) and limit > 2: + if 'maximum context length is' in str(err) and limit > 4: limit -= 1 continue raise err - answer = result['choices'][0]['message'] - response = json.loads(answer['content']) + except Exception as err: + logging.warning(f"failed to generate response: {repr(err)}") + continue history.append(answer) self.history = history - return response + return AIResponse(response['answer'], + response['answer_needed'], + response['staff'], + response['picture'], + response['hack']) diff --git a/fjerkroa_bot/discord_bot.py b/fjerkroa_bot/discord_bot.py index bd1f332..00548c2 100644 --- a/fjerkroa_bot/discord_bot.py +++ b/fjerkroa_bot/discord_bot.py @@ -59,11 +59,15 @@ class FjerkroaBot(commands.Bot): response = await self.airesponder.send(msg) if response.staff is not None: async with self.staff_channel.typing(): - self.staff_channel.send(response.staff) + await self.staff_channel.send(response.staff) if not response.answer_needed: return async with message.channel.typing(): - message.channel.send(response.answer) + if response.picture is not None: + images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")] + await message.channel.send(response.answer, files=images) + else: + await message.channel.send(response.answer) async def close(self): self.observer.stop() diff --git a/pyproject.toml b/pyproject.toml index 004930f..02c8575 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ ignore = [ "E266", "E501", "W503", + "E306", ] exclude = [ ".git", diff --git a/tests/test_main.py b/tests/test_main.py index f8b6e97..ca0df63 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,10 @@ import unittest +import pytest +import aiohttp import json -from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open +import openai +import logging +from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY from fjerkroa_bot import FjerkroaBot from discord import User, Message, TextChannel @@ -14,10 +18,14 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase): ] self.config_data = { "openai_key": "OPENAIKEY", - "engine": "gpt-4", + "model": "gpt-4", "max_tokens": 1024, - "n": 1, "temperature": 0.9, + "top-p": 1.0, + "presence-penalty": 1.0, + "frequency-penalty": 1.0, + "history-limit": 10, + "system": "You are an smart AI", } self.history_data = [] with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \ @@ -25,11 +33,16 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase): mock_user.return_value = MagicMock(spec=User) mock_user.return_value.id = 12 self.bot = FjerkroaBot('config.json') + self.bot.staff_channel = AsyncMock(spec=TextChannel) + self.bot.staff_channel.send = AsyncMock() + self.bot.welcome_channel = AsyncMock(spec=TextChannel) + self.bot.welcome_channel.send = AsyncMock() def create_message(self, message: str) -> Message: message = MagicMock(spec=Message) message.content = "Hello, how are you?" message.author = AsyncMock(spec=User) + message.author.name = 'Lala' message.author.id = 123 message.author.bot = False message.channel = AsyncMock(spec=TextChannel) @@ -45,10 +58,78 @@ class TestFunctionality(TestBotBase): self.assertEqual(result, self.config_data) async def test_on_message_event(self) -> None: + async def acreate(*a, **kw): + answer = {'answer': 'Hello!', + 'answer_needed': True, + 'staff': None, + 'picture': None, + 'hack': False} + return {'choices': [{'message': {'content': json.dumps(answer)}}]} message = self.create_message("Hello there! How are you?") - await self.bot.on_message(message) + with patch.object(openai.ChatCompletion, 'acreate', new=acreate): + await self.bot.on_message(message) message.channel.send.assert_called_once_with("Hello!") + async def test_on_message_event2(self) -> None: + async def acreate(*a, **kw): + answer = {'answer': 'Hello!', + 'answer_needed': True, + 'staff': 'Hallo staff', + 'picture': None, + 'hack': False} + return {'choices': [{'message': {'content': json.dumps(answer)}}]} + message = self.create_message("Hello there! How are you?") + with patch.object(openai.ChatCompletion, 'acreate', new=acreate): + await self.bot.on_message(message) + message.channel.send.assert_called_once_with("Hello!") + + async def test_on_message_event3(self) -> None: + async def acreate(*a, **kw): + return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]} + + def logging_warning(msg): + raise RuntimeError(msg) + message = self.create_message("Hello there! How are you?") + with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ + patch.object(logging, 'warning', logging_warning): + with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'): + await self.bot.on_message(message) + + async def test_on_message_event4(self) -> None: + async def acreate(*a, **kw): + answer = {'answer': 'Hello!', + 'answer_needed': True, + 'staff': 'none', + 'picture': 'Some picture', + 'hack': False} + return {'choices': [{'message': {'content': json.dumps(answer)}}]} + + async def adraw(*a, **kw): + return {'data': [{'url': 'http:url'}]} + + def logging_warning(msg): + raise RuntimeError(msg) + + class image: + def __init__(self, *args, **kw): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + return False + + async def read(self): + return b'test bytes' + message = self.create_message("Hello there! How are you?") + with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ + patch.object(openai.Image, 'acreate', new=adraw), \ + patch.object(logging, 'warning', logging_warning), \ + patch.object(aiohttp.ClientSession, 'get', new=image): + await self.bot.on_message(message) + message.channel.send.assert_called_once_with("Hello!", files=[ANY]) + if __name__ == "__mait__": unittest.main()