213 lines
10 KiB
Python
213 lines
10 KiB
Python
import os
|
|
import unittest
|
|
import aiohttp
|
|
import json
|
|
import toml
|
|
import openai
|
|
import logging
|
|
import pytest
|
|
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
|
from fjerkroa_bot import FjerkroaBot
|
|
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
|
|
from discord import User, Message, TextChannel
|
|
|
|
|
|
class TestBotBase(unittest.IsolatedAsyncioTestCase):
    """Shared fixture for the bot tests.

    Builds ``self.bot`` (a ``FjerkroaBot``) from the in-memory ``self.config_data``
    instead of a real config file, and swaps its staff/welcome channels for
    async mocks so nothing reaches Discord.
    """

    async def asyncSetUp(self):
        """Create the bot under test with ``load_config`` and ``user`` patched out."""
        # Canned completion object; not read by any test visible in this module.
        self.mock_response = Mock()
        self.mock_response.choices = [Mock(text="Nice day today!")]

        # A dummy token is used when OPENAI_TOKEN is unset so the suite runs offline.
        self.config_data = {
            "openai-token": os.environ.get('OPENAI_TOKEN', 'test'),
            "model": "gpt-3.5-turbo",
            "max-tokens": 1024,
            "temperature": 0.9,
            "top-p": 1.0,
            "presence-penalty": 1.0,
            "frequency-penalty": 1.0,
            "history-limit": 10,
            "system": "You are an smart AI",
            "additional-responders": [],
        }
        self.history_data = []

        def fake_load_config(_bot, _path):
            # Hand back the very dict object this fixture holds (looked up at call time).
            return self.config_data

        with patch.object(FjerkroaBot, 'load_config', new=fake_load_config), \
                patch.object(FjerkroaBot, 'user', new_callable=PropertyMock) as user_property:
            bot_user = MagicMock(spec=User)
            bot_user.id = 12
            user_property.return_value = bot_user
            self.bot = FjerkroaBot('config.toml')

        # Replace both notification channels with awaitable mocks.
        for channel_attr in ('staff_channel', 'welcome_channel'):
            fake_channel = AsyncMock(spec=TextChannel)
            fake_channel.send = AsyncMock()
            setattr(self.bot, channel_attr, fake_channel)

        self.bot.airesponder.config = self.config_data

    def create_message(self, message: str) -> Message:
        """Build a mocked Discord message from user 'Lala' (id 123, not a bot).

        The message lives in a mocked ``TextChannel`` whose ``send`` is an ``AsyncMock``.

        NOTE(review): the ``message`` argument is currently ignored. The content is
        always "Hello, how are you?", and test_on_message_stort_path asserts on
        that exact text, so changing this needs a coordinated update.
        """
        author = AsyncMock(spec=User)
        author.name = 'Lala'
        author.id = 123
        author.bot = False

        channel = AsyncMock(spec=TextChannel)
        channel.send = AsyncMock()

        fake_message = MagicMock(spec=Message)
        fake_message.content = "Hello, how are you?"
        fake_message.author = author
        fake_message.channel = channel
        return fake_message
|
|
|
|
|
|
class TestFunctionality(TestBotBase):
    """Tests for config loading, JSON parsing, answer post-processing and ``on_message``."""

    @staticmethod
    def _fake_chat_completion(content: str):
        """Return a stand-in for ``openai.ChatCompletion.acreate``.

        The fake accepts any arguments and always returns one choice whose
        message content is *content*. *content* is passed through raw and need
        not be valid JSON.
        """
        async def acreate(*_args, **_kwargs):
            return {'choices': [{'message': {'content': content}}]}

        return acreate

    def test_load_config(self) -> None:
        """``load_config`` parses TOML back into the dict that was serialised."""
        with patch('builtins.open', mock_open(read_data=toml.dumps(self.config_data))):
            result = FjerkroaBot.load_config('config.toml')
        self.assertEqual(result, self.config_data)

    def test_json_strings(self) -> None:
        """``parse_maybe_json`` reduces JSON-ish input to newline-joined plain text."""
        cases = [
            ('{"key1": "value1", "key2": "value2"}', "value1\nvalue2"),
            ("This is not a JSON string.", "This is not a JSON string."),
            ('["value1", "value2", "value3"]', "value1\nvalue2\nvalue3"),
            ('"value1"', 'value1'),
            # Bracketed input, valid JSON or not, is reduced to its inner text.
            ('{"This is a string."}', 'This is a string.'),
            ('["This is a string."]', 'This is a string.'),
            ('{This is a string.}', 'This is a string.'),
            ('[This is a string.]', 'This is a string.'),
        ]
        for raw, expected in cases:
            with self.subTest(raw=raw):
                self.assertEqual(parse_maybe_json(raw), expected)

    async def test_message_lings(self) -> None:
        """``post_process`` rewrites Markdown links.

        A plain ``[label](url)`` becomes the bare URL; an ``@``-prefixed link
        becomes its label.
        """
        request = AIMessage('Lala', 'Hello there!', 'chat', False,)
        cases = [
            ('Test [Link](https://www.example.com/test)',
             'Test https://www.example.com/test'),
            ('Test @[Link](https://www.example.com/test)',
             'Test Link'),
            ('Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
             'Test https://www.example.com/test and https://xxx lala'),
        ]
        for answer, expected_text in cases:
            with self.subTest(answer=answer):
                # A fresh dict per case, in case post_process mutates its input.
                message = {'answer': answer,
                           'answer_needed': True, 'channel': None, 'staff': None,
                           'picture': None, 'hack': False}
                expected = AIResponse(expected_text, True, None, None, None, False)
                result = await self.bot.airesponder.post_process(request, message)
                self.assertEqual(str(result), str(expected))

    async def test_on_message_event(self) -> None:
        """A plain answer is sent once to the originating channel."""
        acreate = self._fake_chat_completion(json.dumps({'answer': 'Hello!',
                                                         'answer_needed': True,
                                                         'staff': None,
                                                         'picture': None,
                                                         'hack': False}))
        message = self.create_message("Hello there! How are you?")
        with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
            await self.bot.on_message(message)
        message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True)

    async def test_on_message_stort_path(self) -> None:
        """Check short-path handling for a matching and a non-matching author.

        After the first message (author matching the short-path pattern), it is
        the last history entry. After a second message from a non-matching
        author, that message is second-to-last and exactly one reply was sent.
        """
        acreate = self._fake_chat_completion(json.dumps({'answer': 'Hello!',
                                                         'answer_needed': True,
                                                         'channel': None,
                                                         'staff': None,
                                                         'picture': None,
                                                         'hack': False}))
        message = self.create_message("Hello there! How are you?")
        message.author.name = 'madeup_name'
        message.channel.name = 'some_channel'
        self.bot.config['short-path'] = [[r'some.*', r'madeup.*']]
        with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
            await self.bot.on_message(message)
            self.assertEqual(self.bot.airesponder.history[-1]["content"],
                             '{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
            message.author.name = 'different_name'
            await self.bot.on_message(message)
            self.assertEqual(self.bot.airesponder.history[-2]["content"],
                             '{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
        message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True)

    async def test_on_message_event2(self) -> None:
        """An answer that also carries a staff note still replies once in the channel."""
        acreate = self._fake_chat_completion(json.dumps({'answer': 'Hello!',
                                                         'answer_needed': True,
                                                         'channel': None,
                                                         'staff': 'Hallo staff',
                                                         'picture': None,
                                                         'hack': False}))
        message = self.create_message("Hello there! How are you?")
        with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
            await self.bot.on_message(message)
        message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True)

    async def test_on_message_event3(self) -> None:
        """Persistently malformed completions end in a RuntimeError after retrying."""
        acreate = self._fake_chat_completion('{ "test": 3 ]')
        message = self.create_message("Hello there! How are you?")
        with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
            # Native unittest equivalent of pytest.raises(match=...): both use re.search.
            with self.assertRaisesRegex(RuntimeError, "Failed to generate answer after multiple retries"):
                await self.bot.on_message(message)

    @patch("builtins.open", new_callable=mock_open)
    def test_update_history_with_file(self, mock_file):
        """History is trimmed to the given limit and persisted once a history file is set."""
        responder = self.bot.airesponder
        responder.update_history({'content': '{"q": "What\'s your name?"}'}, {'content': '{"a": "AI"}'}, 10)
        self.assertEqual(len(responder.history), 2)
        responder.update_history({'content': '{"q1": "Q1"}'}, {'content': '{"a1": "A1"}'}, 2)
        responder.update_history({'content': '{"q2": "Q2"}'}, {'content': '{"a2": "A2"}'}, 2)
        self.assertEqual(len(responder.history), 2)
        responder.history_file = "mock_file.pkl"
        responder.update_history({'content': '{"q": "What\'s your favorite color?"}'}, {'content': '{"a": "Blue"}'}, 10)
        mock_file.assert_called_once_with("mock_file.pkl", "wb")
        mock_file().write.assert_called_once()

    async def test_on_message_event4(self) -> None:
        """A picture request downloads the generated image and attaches it to the reply."""
        acreate = self._fake_chat_completion(json.dumps({'answer': 'Hello!',
                                                         'answer_needed': True,
                                                         'staff': 'none',
                                                         'picture': 'Some picture',
                                                         'hack': False}))

        async def adraw(*_args, **_kwargs):
            return {'data': [{'url': 'http:url'}]}

        def logging_warning(msg, *args, **_kwargs):
            # Turn any warning into a test failure. Accept logging.warning's full
            # (msg, *args, **kwargs) signature so a %-style call surfaces as this
            # RuntimeError rather than as a TypeError from the stub itself.
            raise RuntimeError(msg, *args)

        class FakeImageResponse:
            """Stand-in for ``ClientSession.get(...)`` used as an async context manager."""

            def __init__(self, *args, **kw):
                pass

            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc_value, traceback):
                return False

            async def read(self):
                return b'test bytes'

        message = self.create_message("Hello there! How are you?")
        with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
                patch.object(openai.Image, 'acreate', new=adraw), \
                patch.object(logging, 'warning', logging_warning), \
                patch.object(aiohttp.ClientSession, 'get', new=FakeImageResponse):
            await self.bot.on_message(message)
        message.channel.send.assert_called_once_with("Hello!", files=[ANY], suppress_embeds=True)
|
|
|
|
|
|
# Run the suite when this module is executed directly (e.g. `python <this file>`).
if __name__ == "__main__":
    unittest.main()
|