Support drawing.

This commit is contained in:
Fjerkroa Auto 2023-03-22 10:46:04 +01:00
parent c0b51f947c
commit ef933883c0
4 changed files with 116 additions and 10 deletions

View File

@ -1,5 +1,8 @@
import json import json
import openai import openai
import aiohttp
import logging
from io import BytesIO
from typing import Optional, List, Dict, Any, Tuple from typing import Optional, List, Dict, Any, Tuple
@ -43,6 +46,16 @@ class AIResponder(object):
messages.append(msg) messages.append(msg)
return messages, history return messages, history
async def draw(self, description: str) -> BytesIO:
while True:
try:
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
async with aiohttp.ClientSession() as session:
async with session.get(response['data'][0]['url']) as image:
return BytesIO(await image.read())
except Exception as err:
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
async def send(self, message: AIMessage) -> AIResponse: async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"] limit = self.config["history-limit"]
while True: while True:
@ -54,13 +67,20 @@ class AIResponder(object):
top_p=self.config["top-p"], top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"], presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"]) frequency_penalty=self.config["frequency-penalty"])
answer = result['choices'][0]['message']
response = json.loads(answer['content'])
except openai.error.InvalidRequestError as err: except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 2: if 'maximum context length is' in str(err) and limit > 4:
limit -= 1 limit -= 1
continue continue
raise err raise err
answer = result['choices'][0]['message'] except Exception as err:
response = json.loads(answer['content']) logging.warning(f"failed to generate response: {repr(err)}")
continue
history.append(answer) history.append(answer)
self.history = history self.history = history
return response return AIResponse(response['answer'],
response['answer_needed'],
response['staff'],
response['picture'],
response['hack'])

View File

@ -59,11 +59,15 @@ class FjerkroaBot(commands.Bot):
response = await self.airesponder.send(msg) response = await self.airesponder.send(msg)
if response.staff is not None: if response.staff is not None:
async with self.staff_channel.typing(): async with self.staff_channel.typing():
self.staff_channel.send(response.staff) await self.staff_channel.send(response.staff)
if not response.answer_needed: if not response.answer_needed:
return return
async with message.channel.typing(): async with message.channel.typing():
message.channel.send(response.answer) if response.picture is not None:
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
await message.channel.send(response.answer, files=images)
else:
await message.channel.send(response.answer)
async def close(self): async def close(self):
self.observer.stop() self.observer.stop()

View File

@ -12,6 +12,7 @@ ignore = [
"E266", "E266",
"E501", "E501",
"W503", "W503",
"E306",
] ]
exclude = [ exclude = [
".git", ".git",

View File

@ -1,6 +1,10 @@
import unittest import unittest
import pytest
import aiohttp
import json import json
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open import openai
import logging
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
from fjerkroa_bot import FjerkroaBot from fjerkroa_bot import FjerkroaBot
from discord import User, Message, TextChannel from discord import User, Message, TextChannel
@ -14,10 +18,14 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
] ]
self.config_data = { self.config_data = {
"openai_key": "OPENAIKEY", "openai_key": "OPENAIKEY",
"engine": "gpt-4", "model": "gpt-4",
"max_tokens": 1024, "max_tokens": 1024,
"n": 1,
"temperature": 0.9, "temperature": 0.9,
"top-p": 1.0,
"presence-penalty": 1.0,
"frequency-penalty": 1.0,
"history-limit": 10,
"system": "You are an smart AI",
} }
self.history_data = [] self.history_data = []
with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \ with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \
@ -25,11 +33,16 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
mock_user.return_value = MagicMock(spec=User) mock_user.return_value = MagicMock(spec=User)
mock_user.return_value.id = 12 mock_user.return_value.id = 12
self.bot = FjerkroaBot('config.json') self.bot = FjerkroaBot('config.json')
self.bot.staff_channel = AsyncMock(spec=TextChannel)
self.bot.staff_channel.send = AsyncMock()
self.bot.welcome_channel = AsyncMock(spec=TextChannel)
self.bot.welcome_channel.send = AsyncMock()
def create_message(self, message: str) -> Message: def create_message(self, message: str) -> Message:
message = MagicMock(spec=Message) message = MagicMock(spec=Message)
message.content = "Hello, how are you?" message.content = "Hello, how are you?"
message.author = AsyncMock(spec=User) message.author = AsyncMock(spec=User)
message.author.name = 'Lala'
message.author.id = 123 message.author.id = 123
message.author.bot = False message.author.bot = False
message.channel = AsyncMock(spec=TextChannel) message.channel = AsyncMock(spec=TextChannel)
@ -45,10 +58,78 @@ class TestFunctionality(TestBotBase):
self.assertEqual(result, self.config_data) self.assertEqual(result, self.config_data)
async def test_on_message_event(self) -> None: async def test_on_message_event(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': None,
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?") message = self.create_message("Hello there! How are you?")
await self.bot.on_message(message) with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!") message.channel.send.assert_called_once_with("Hello!")
async def test_on_message_event2(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': 'Hallo staff',
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!")
async def test_on_message_event3(self) -> None:
async def acreate(*a, **kw):
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
def logging_warning(msg):
raise RuntimeError(msg)
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(logging, 'warning', logging_warning):
with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'):
await self.bot.on_message(message)
async def test_on_message_event4(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': 'none',
'picture': 'Some picture',
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
async def adraw(*a, **kw):
return {'data': [{'url': 'http:url'}]}
def logging_warning(msg):
raise RuntimeError(msg)
class image:
def __init__(self, *args, **kw):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
return False
async def read(self):
return b'test bytes'
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(openai.Image, 'acreate', new=adraw), \
patch.object(logging, 'warning', logging_warning), \
patch.object(aiohttp.ClientSession, 'get', new=image):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", files=[ANY])
if __name__ == "__mait__": if __name__ == "__mait__":
unittest.main() unittest.main()