Support drawing.
This commit is contained in:
parent
c0b51f947c
commit
ef933883c0
@ -1,5 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import openai
|
import openai
|
||||||
|
import aiohttp
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
|
||||||
|
|
||||||
@ -43,6 +46,16 @@ class AIResponder(object):
|
|||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
return messages, history
|
return messages, history
|
||||||
|
|
||||||
|
async def draw(self, description: str) -> BytesIO:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(response['data'][0]['url']) as image:
|
||||||
|
return BytesIO(await image.read())
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||||
|
|
||||||
async def send(self, message: AIMessage) -> AIResponse:
|
async def send(self, message: AIMessage) -> AIResponse:
|
||||||
limit = self.config["history-limit"]
|
limit = self.config["history-limit"]
|
||||||
while True:
|
while True:
|
||||||
@ -54,13 +67,20 @@ class AIResponder(object):
|
|||||||
top_p=self.config["top-p"],
|
top_p=self.config["top-p"],
|
||||||
presence_penalty=self.config["presence-penalty"],
|
presence_penalty=self.config["presence-penalty"],
|
||||||
frequency_penalty=self.config["frequency-penalty"])
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
|
answer = result['choices'][0]['message']
|
||||||
|
response = json.loads(answer['content'])
|
||||||
except openai.error.InvalidRequestError as err:
|
except openai.error.InvalidRequestError as err:
|
||||||
if 'maximum context length is' in str(err) and limit > 2:
|
if 'maximum context length is' in str(err) and limit > 4:
|
||||||
limit -= 1
|
limit -= 1
|
||||||
continue
|
continue
|
||||||
raise err
|
raise err
|
||||||
answer = result['choices'][0]['message']
|
except Exception as err:
|
||||||
response = json.loads(answer['content'])
|
logging.warning(f"failed to generate response: {repr(err)}")
|
||||||
|
continue
|
||||||
history.append(answer)
|
history.append(answer)
|
||||||
self.history = history
|
self.history = history
|
||||||
return response
|
return AIResponse(response['answer'],
|
||||||
|
response['answer_needed'],
|
||||||
|
response['staff'],
|
||||||
|
response['picture'],
|
||||||
|
response['hack'])
|
||||||
|
|||||||
@ -59,11 +59,15 @@ class FjerkroaBot(commands.Bot):
|
|||||||
response = await self.airesponder.send(msg)
|
response = await self.airesponder.send(msg)
|
||||||
if response.staff is not None:
|
if response.staff is not None:
|
||||||
async with self.staff_channel.typing():
|
async with self.staff_channel.typing():
|
||||||
self.staff_channel.send(response.staff)
|
await self.staff_channel.send(response.staff)
|
||||||
if not response.answer_needed:
|
if not response.answer_needed:
|
||||||
return
|
return
|
||||||
async with message.channel.typing():
|
async with message.channel.typing():
|
||||||
message.channel.send(response.answer)
|
if response.picture is not None:
|
||||||
|
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
|
||||||
|
await message.channel.send(response.answer, files=images)
|
||||||
|
else:
|
||||||
|
await message.channel.send(response.answer)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self.observer.stop()
|
self.observer.stop()
|
||||||
|
|||||||
@ -12,6 +12,7 @@ ignore = [
|
|||||||
"E266",
|
"E266",
|
||||||
"E501",
|
"E501",
|
||||||
"W503",
|
"W503",
|
||||||
|
"E306",
|
||||||
]
|
]
|
||||||
exclude = [
|
exclude = [
|
||||||
".git",
|
".git",
|
||||||
|
|||||||
@ -1,6 +1,10 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
import pytest
|
||||||
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open
|
import openai
|
||||||
|
import logging
|
||||||
|
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
||||||
from fjerkroa_bot import FjerkroaBot
|
from fjerkroa_bot import FjerkroaBot
|
||||||
from discord import User, Message, TextChannel
|
from discord import User, Message, TextChannel
|
||||||
|
|
||||||
@ -14,10 +18,14 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
]
|
]
|
||||||
self.config_data = {
|
self.config_data = {
|
||||||
"openai_key": "OPENAIKEY",
|
"openai_key": "OPENAIKEY",
|
||||||
"engine": "gpt-4",
|
"model": "gpt-4",
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"n": 1,
|
|
||||||
"temperature": 0.9,
|
"temperature": 0.9,
|
||||||
|
"top-p": 1.0,
|
||||||
|
"presence-penalty": 1.0,
|
||||||
|
"frequency-penalty": 1.0,
|
||||||
|
"history-limit": 10,
|
||||||
|
"system": "You are an smart AI",
|
||||||
}
|
}
|
||||||
self.history_data = []
|
self.history_data = []
|
||||||
with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \
|
with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \
|
||||||
@ -25,11 +33,16 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
mock_user.return_value = MagicMock(spec=User)
|
mock_user.return_value = MagicMock(spec=User)
|
||||||
mock_user.return_value.id = 12
|
mock_user.return_value.id = 12
|
||||||
self.bot = FjerkroaBot('config.json')
|
self.bot = FjerkroaBot('config.json')
|
||||||
|
self.bot.staff_channel = AsyncMock(spec=TextChannel)
|
||||||
|
self.bot.staff_channel.send = AsyncMock()
|
||||||
|
self.bot.welcome_channel = AsyncMock(spec=TextChannel)
|
||||||
|
self.bot.welcome_channel.send = AsyncMock()
|
||||||
|
|
||||||
def create_message(self, message: str) -> Message:
|
def create_message(self, message: str) -> Message:
|
||||||
message = MagicMock(spec=Message)
|
message = MagicMock(spec=Message)
|
||||||
message.content = "Hello, how are you?"
|
message.content = "Hello, how are you?"
|
||||||
message.author = AsyncMock(spec=User)
|
message.author = AsyncMock(spec=User)
|
||||||
|
message.author.name = 'Lala'
|
||||||
message.author.id = 123
|
message.author.id = 123
|
||||||
message.author.bot = False
|
message.author.bot = False
|
||||||
message.channel = AsyncMock(spec=TextChannel)
|
message.channel = AsyncMock(spec=TextChannel)
|
||||||
@ -45,10 +58,78 @@ class TestFunctionality(TestBotBase):
|
|||||||
self.assertEqual(result, self.config_data)
|
self.assertEqual(result, self.config_data)
|
||||||
|
|
||||||
async def test_on_message_event(self) -> None:
|
async def test_on_message_event(self) -> None:
|
||||||
|
async def acreate(*a, **kw):
|
||||||
|
answer = {'answer': 'Hello!',
|
||||||
|
'answer_needed': True,
|
||||||
|
'staff': None,
|
||||||
|
'picture': None,
|
||||||
|
'hack': False}
|
||||||
|
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||||
message = self.create_message("Hello there! How are you?")
|
message = self.create_message("Hello there! How are you?")
|
||||||
await self.bot.on_message(message)
|
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||||
|
await self.bot.on_message(message)
|
||||||
message.channel.send.assert_called_once_with("Hello!")
|
message.channel.send.assert_called_once_with("Hello!")
|
||||||
|
|
||||||
|
async def test_on_message_event2(self) -> None:
|
||||||
|
async def acreate(*a, **kw):
|
||||||
|
answer = {'answer': 'Hello!',
|
||||||
|
'answer_needed': True,
|
||||||
|
'staff': 'Hallo staff',
|
||||||
|
'picture': None,
|
||||||
|
'hack': False}
|
||||||
|
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||||
|
message = self.create_message("Hello there! How are you?")
|
||||||
|
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||||
|
await self.bot.on_message(message)
|
||||||
|
message.channel.send.assert_called_once_with("Hello!")
|
||||||
|
|
||||||
|
async def test_on_message_event3(self) -> None:
|
||||||
|
async def acreate(*a, **kw):
|
||||||
|
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
|
||||||
|
|
||||||
|
def logging_warning(msg):
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
message = self.create_message("Hello there! How are you?")
|
||||||
|
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
||||||
|
patch.object(logging, 'warning', logging_warning):
|
||||||
|
with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'):
|
||||||
|
await self.bot.on_message(message)
|
||||||
|
|
||||||
|
async def test_on_message_event4(self) -> None:
|
||||||
|
async def acreate(*a, **kw):
|
||||||
|
answer = {'answer': 'Hello!',
|
||||||
|
'answer_needed': True,
|
||||||
|
'staff': 'none',
|
||||||
|
'picture': 'Some picture',
|
||||||
|
'hack': False}
|
||||||
|
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||||
|
|
||||||
|
async def adraw(*a, **kw):
|
||||||
|
return {'data': [{'url': 'http:url'}]}
|
||||||
|
|
||||||
|
def logging_warning(msg):
|
||||||
|
raise RuntimeError(msg)
|
||||||
|
|
||||||
|
class image:
|
||||||
|
def __init__(self, *args, **kw):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def read(self):
|
||||||
|
return b'test bytes'
|
||||||
|
message = self.create_message("Hello there! How are you?")
|
||||||
|
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
||||||
|
patch.object(openai.Image, 'acreate', new=adraw), \
|
||||||
|
patch.object(logging, 'warning', logging_warning), \
|
||||||
|
patch.object(aiohttp.ClientSession, 'get', new=image):
|
||||||
|
await self.bot.on_message(message)
|
||||||
|
message.channel.send.assert_called_once_with("Hello!", files=[ANY])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__mait__":
|
if __name__ == "__mait__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user