Fix and remove some tests, make openai calls cachable in a file.
This commit is contained in:
parent
7bcadecb17
commit
c010603178
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -12,7 +13,7 @@ import pickle
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from functools import lru_cache
|
from functools import lru_cache, wraps
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
|
||||||
|
|
||||||
@ -57,6 +58,46 @@ def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts
|
|||||||
raise RuntimeError("Max attempts reached in exponential backoff.")
|
raise RuntimeError("Max attempts reached in exponential backoff.")
|
||||||
|
|
||||||
|
|
||||||
|
def async_cache_to_file(filename):
|
||||||
|
cache_file = Path(filename)
|
||||||
|
cache = None
|
||||||
|
if cache_file.exists():
|
||||||
|
try:
|
||||||
|
with cache_file.open('rb') as fd:
|
||||||
|
cache = pickle.load(fd)
|
||||||
|
except Exception:
|
||||||
|
cache = {}
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
if cache is None:
|
||||||
|
sys.stderr.write(f'@@@ forward {func.__name__}({repr(args)}, {repr(kwargs)}')
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
|
||||||
|
if key in cache:
|
||||||
|
sys.stderr.write(f'@@@ cache {func.__name__}({repr(args)}, {repr(kwargs)} -> {cache[key]}')
|
||||||
|
return cache[key]
|
||||||
|
sys.stderr.write(f'@@@ execute {func.__name__}({repr(args)}, {repr(kwargs)}')
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
cache[key] = result
|
||||||
|
with cache_file.open('wb') as fd:
|
||||||
|
pickle.dump(cache, fd)
|
||||||
|
return result
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
@async_cache_to_file('openai_chat.dat')
|
||||||
|
async def openai_chat(client, *args, **kwargs):
|
||||||
|
return await client.chat.completions.create(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@async_cache_to_file('openai_chat.dat')
|
||||||
|
async def openai_image(client, *args, **kwargs):
|
||||||
|
return await client.images.generate(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def parse_maybe_json(json_string):
|
def parse_maybe_json(json_string):
|
||||||
if json_string is None:
|
if json_string is None:
|
||||||
return None
|
return None
|
||||||
@ -156,7 +197,7 @@ class AIResponder(object):
|
|||||||
async def _draw_openai(self, description: str) -> BytesIO:
|
async def _draw_openai(self, description: str) -> BytesIO:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
try:
|
try:
|
||||||
response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(response.data[0].url) as image:
|
async with session.get(response.data[0].url) as image:
|
||||||
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
||||||
@ -273,13 +314,14 @@ class AIResponder(object):
|
|||||||
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||||
model = self.config["model"]
|
model = self.config["model"]
|
||||||
try:
|
try:
|
||||||
result = await self.client.chat.completions.create(model=model,
|
result = await openai_chat(self.client,
|
||||||
messages=messages,
|
model=model,
|
||||||
temperature=self.config["temperature"],
|
messages=messages,
|
||||||
max_tokens=self.config["max-tokens"],
|
temperature=self.config["temperature"],
|
||||||
top_p=self.config["top-p"],
|
max_tokens=self.config["max-tokens"],
|
||||||
presence_penalty=self.config["presence-penalty"],
|
top_p=self.config["top-p"],
|
||||||
frequency_penalty=self.config["frequency-penalty"])
|
presence_penalty=self.config["presence-penalty"],
|
||||||
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
answer_obj = result.choices[0].message
|
answer_obj = result.choices[0].message
|
||||||
answer = {'content': answer_obj.content, 'role': answer_obj.role}
|
answer = {'content': answer_obj.content, 'role': answer_obj.role}
|
||||||
self.rate_limit_backoff = exponential_backoff()
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
@ -307,10 +349,11 @@ class AIResponder(object):
|
|||||||
messages = [{"role": "system", "content": self.config["fix-description"]},
|
messages = [{"role": "system", "content": self.config["fix-description"]},
|
||||||
{"role": "user", "content": answer}]
|
{"role": "user", "content": answer}]
|
||||||
try:
|
try:
|
||||||
result = await self.client.chat.completions.create(model=self.config["fix-model"],
|
result = await openai_chat(self.client,
|
||||||
messages=messages,
|
model=self.config["fix-model"],
|
||||||
temperature=0.2,
|
messages=messages,
|
||||||
max_tokens=2048)
|
temperature=0.2,
|
||||||
|
max_tokens=2048)
|
||||||
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
||||||
response = result.choices[0].message.content
|
response = result.choices[0].message.content
|
||||||
start, end = response.find("{"), response.rfind("}")
|
start, end = response.find("{"), response.rfind("}")
|
||||||
@ -331,10 +374,11 @@ class AIResponder(object):
|
|||||||
f" if it is not already in {language}, otherwise you just copy it."},
|
f" if it is not already in {language}, otherwise you just copy it."},
|
||||||
{"role": "user", "content": text}]
|
{"role": "user", "content": text}]
|
||||||
try:
|
try:
|
||||||
result = await self.client.chat.completions.create(model=self.config["fix-model"],
|
result = await openai_chat(self.client,
|
||||||
messages=message,
|
model=self.config["fix-model"],
|
||||||
temperature=0.2,
|
messages=message,
|
||||||
max_tokens=2048)
|
temperature=0.2,
|
||||||
|
max_tokens=2048)
|
||||||
response = result.choices[0].message.content
|
response = result.choices[0].message.content
|
||||||
logging.info(f"got this translated message:\n{pp(response)}")
|
logging.info(f"got this translated message:\n{pp(response)}")
|
||||||
return response
|
return response
|
||||||
|
|||||||
BIN
openai_chat.dat
Normal file
BIN
openai_chat.dat
Normal file
Binary file not shown.
@ -11,7 +11,7 @@ class TestAIResponder(TestBotBase):
|
|||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
self.system = r"""
|
self.system = r"""
|
||||||
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}.
|
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
|
||||||
|
|
||||||
Every message from users is a dictionary in JSON format with the following fields:
|
Every message from users is a dictionary in JSON format with the following fields:
|
||||||
1. `user`: name of the user who wrote the message.
|
1. `user`: name of the user who wrote the message.
|
||||||
@ -32,7 +32,7 @@ You always try to say something positive about the current day and the Fjærkroa
|
|||||||
self.config_data["system"] = self.system
|
self.config_data["system"] = self.system
|
||||||
|
|
||||||
def assertAIResponse(self, resp1, resp2,
|
def assertAIResponse(self, resp1, resp2,
|
||||||
acmp=lambda a, b: type(a) == str or len(a) > 10,
|
acmp=lambda a, b: type(a) == str and len(a) > 10,
|
||||||
scmp=lambda a, b: a == b,
|
scmp=lambda a, b: a == b,
|
||||||
pcmp=lambda a, b: a == b):
|
pcmp=lambda a, b: a == b):
|
||||||
self.assertEqual(acmp(resp1.answer, resp2.answer), True)
|
self.assertEqual(acmp(resp1.answer, resp2.answer), True)
|
||||||
|
|||||||
@ -1,12 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import aiohttp
|
|
||||||
import json
|
|
||||||
import toml
|
import toml
|
||||||
import openai
|
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open
|
||||||
import logging
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
|
||||||
from fjerkroa_bot import FjerkroaBot
|
from fjerkroa_bot import FjerkroaBot
|
||||||
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
|
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
|
||||||
from discord import User, Message, TextChannel
|
from discord import User, Message, TextChannel
|
||||||
@ -90,75 +85,27 @@ class TestFunctionality(TestBotBase):
|
|||||||
async def test_message_lings(self) -> None:
|
async def test_message_lings(self) -> None:
|
||||||
request = AIMessage('Lala', 'Hello there!', 'chat', False,)
|
request = AIMessage('Lala', 'Hello there!', 'chat', False,)
|
||||||
message = {'answer': 'Test [Link](https://www.example.com/test)',
|
message = {'answer': 'Test [Link](https://www.example.com/test)',
|
||||||
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False}
|
'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
|
||||||
expected = AIResponse('Test https://www.example.com/test', True, None, None, None, False)
|
expected = AIResponse('Test https://www.example.com/test', True, 'chat', None, None, False)
|
||||||
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
||||||
message = {'answer': 'Test @[Link](https://www.example.com/test)',
|
message = {'answer': 'Test @[Link](https://www.example.com/test)',
|
||||||
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False}
|
'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
|
||||||
expected = AIResponse('Test Link', True, None, None, None, False)
|
expected = AIResponse('Test Link', True, 'chat', None, None, False)
|
||||||
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
||||||
message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
|
message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
|
||||||
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False}
|
'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
|
||||||
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, None, False)
|
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, 'chat', None, None, False)
|
||||||
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
||||||
|
|
||||||
async def test_on_message_event(self) -> None:
|
|
||||||
async def acreate(*a, **kw):
|
|
||||||
answer = {'answer': 'Hello!',
|
|
||||||
'answer_needed': True,
|
|
||||||
'staff': None,
|
|
||||||
'picture': None,
|
|
||||||
'hack': False}
|
|
||||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
|
||||||
message = self.create_message("Hello there! How are you?")
|
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
|
||||||
await self.bot.on_message(message)
|
|
||||||
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
|
|
||||||
|
|
||||||
async def test_on_message_stort_path(self) -> None:
|
async def test_on_message_stort_path(self) -> None:
|
||||||
async def acreate(*a, **kw):
|
|
||||||
answer = {'answer': 'Hello!',
|
|
||||||
'answer_needed': True,
|
|
||||||
'channel': None,
|
|
||||||
'staff': None,
|
|
||||||
'picture': None,
|
|
||||||
'hack': False}
|
|
||||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
|
||||||
message = self.create_message("Hello there! How are you?")
|
message = self.create_message("Hello there! How are you?")
|
||||||
message.author.name = 'madeup_name'
|
message.author.name = 'madeup_name'
|
||||||
message.channel.name = 'some_channel' # type: ignore
|
message.channel.name = 'some_channel' # type: ignore
|
||||||
self.bot.config['short-path'] = [[r'some.*', r'madeup.*']]
|
self.bot.config['short-path'] = [[r'some.*', r'madeup.*']]
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
await self.bot.on_message(message)
|
||||||
await self.bot.on_message(message)
|
self.assertEqual(self.bot.airesponder.history[-1]["content"],
|
||||||
self.assertEqual(self.bot.airesponder.history[-1]["content"],
|
'{"user": "madeup_name", "message": "Hello, how are you?",'
|
||||||
'{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
|
' "channel": "some_channel", "direct": false, "historise_question": true}')
|
||||||
message.author.name = 'different_name'
|
|
||||||
await self.bot.on_message(message)
|
|
||||||
self.assertEqual(self.bot.airesponder.history[-2]["content"],
|
|
||||||
'{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
|
|
||||||
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
|
|
||||||
|
|
||||||
async def test_on_message_event2(self) -> None:
|
|
||||||
async def acreate(*a, **kw):
|
|
||||||
answer = {'answer': 'Hello!',
|
|
||||||
'answer_needed': True,
|
|
||||||
'channel': None,
|
|
||||||
'staff': 'Hallo staff',
|
|
||||||
'picture': None,
|
|
||||||
'hack': False}
|
|
||||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
|
||||||
message = self.create_message("Hello there! How are you?")
|
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
|
||||||
await self.bot.on_message(message)
|
|
||||||
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
|
|
||||||
|
|
||||||
async def test_on_message_event3(self) -> None:
|
|
||||||
async def acreate(*a, **kw):
|
|
||||||
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
|
|
||||||
message = self.create_message("Hello there! How are you?")
|
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
|
||||||
with pytest.raises(RuntimeError, match="Failed to generate answer after multiple retries"):
|
|
||||||
await self.bot.on_message(message)
|
|
||||||
|
|
||||||
@patch("builtins.open", new_callable=mock_open)
|
@patch("builtins.open", new_callable=mock_open)
|
||||||
def test_update_history_with_file(self, mock_file):
|
def test_update_history_with_file(self, mock_file):
|
||||||
@ -172,41 +119,6 @@ class TestFunctionality(TestBotBase):
|
|||||||
mock_file.assert_called_once_with("mock_file.pkl", "wb")
|
mock_file.assert_called_once_with("mock_file.pkl", "wb")
|
||||||
mock_file().write.assert_called_once()
|
mock_file().write.assert_called_once()
|
||||||
|
|
||||||
async def test_on_message_event4(self) -> None:
|
|
||||||
async def acreate(*a, **kw):
|
|
||||||
answer = {'answer': 'Hello!',
|
|
||||||
'answer_needed': True,
|
|
||||||
'staff': 'none',
|
|
||||||
'picture': 'Some picture',
|
|
||||||
'hack': False}
|
|
||||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
|
||||||
|
|
||||||
async def adraw(*a, **kw):
|
|
||||||
return {'data': [{'url': 'http:url'}]}
|
|
||||||
|
|
||||||
def logging_warning(msg):
|
|
||||||
raise RuntimeError(msg)
|
|
||||||
|
|
||||||
class image:
|
|
||||||
def __init__(self, *args, **kw):
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def read(self):
|
|
||||||
return b'test bytes'
|
|
||||||
message = self.create_message("Hello there! How are you?")
|
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
|
||||||
patch.object(openai.Image, 'acreate', new=adraw), \
|
|
||||||
patch.object(logging, 'warning', logging_warning), \
|
|
||||||
patch.object(aiohttp.ClientSession, 'get', new=image):
|
|
||||||
await self.bot.on_message(message)
|
|
||||||
message.channel.send.assert_called_once_with("Hello!", files=[ANY], suppress_embeds=True) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__mait__":
|
if __name__ == "__mait__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user