Fix and remove some tests, make openai calls cachable in a file.

This commit is contained in:
OK 2023-12-02 22:36:00 +01:00
parent 7bcadecb17
commit c010603178
4 changed files with 74 additions and 118 deletions

View File

@ -1,3 +1,4 @@
import sys
import os import os
import json import json
import asyncio import asyncio
@ -12,7 +13,7 @@ import pickle
from pathlib import Path from pathlib import Path
from io import BytesIO from io import BytesIO
from pprint import pformat from pprint import pformat
from functools import lru_cache from functools import lru_cache, wraps
from typing import Optional, List, Dict, Any, Tuple from typing import Optional, List, Dict, Any, Tuple
@ -57,6 +58,46 @@ def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts
raise RuntimeError("Max attempts reached in exponential backoff.") raise RuntimeError("Max attempts reached in exponential backoff.")
def async_cache_to_file(filename):
cache_file = Path(filename)
cache = None
if cache_file.exists():
try:
with cache_file.open('rb') as fd:
cache = pickle.load(fd)
except Exception:
cache = {}
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
if cache is None:
sys.stderr.write(f'@@@ forward {func.__name__}({repr(args)}, {repr(kwargs)}')
return await func(*args, **kwargs)
key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
if key in cache:
sys.stderr.write(f'@@@ cache {func.__name__}({repr(args)}, {repr(kwargs)} -> {cache[key]}')
return cache[key]
sys.stderr.write(f'@@@ execute {func.__name__}({repr(args)}, {repr(kwargs)}')
result = await func(*args, **kwargs)
cache[key] = result
with cache_file.open('wb') as fd:
pickle.dump(cache, fd)
return result
return wrapper
return decorator
@async_cache_to_file('openai_chat.dat')
async def openai_chat(client, *args, **kwargs):
return await client.chat.completions.create(*args, **kwargs)
@async_cache_to_file('openai_chat.dat')
async def openai_image(client, *args, **kwargs):
return await client.images.generate(*args, **kwargs)
def parse_maybe_json(json_string): def parse_maybe_json(json_string):
if json_string is None: if json_string is None:
return None return None
@ -156,7 +197,7 @@ class AIResponder(object):
async def _draw_openai(self, description: str) -> BytesIO: async def _draw_openai(self, description: str) -> BytesIO:
for _ in range(3): for _ in range(3):
try: try:
response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3") response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response.data[0].url) as image: async with session.get(response.data[0].url) as image:
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}') logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
@ -273,13 +314,14 @@ class AIResponder(object):
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]: async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
model = self.config["model"] model = self.config["model"]
try: try:
result = await self.client.chat.completions.create(model=model, result = await openai_chat(self.client,
messages=messages, model=model,
temperature=self.config["temperature"], messages=messages,
max_tokens=self.config["max-tokens"], temperature=self.config["temperature"],
top_p=self.config["top-p"], max_tokens=self.config["max-tokens"],
presence_penalty=self.config["presence-penalty"], top_p=self.config["top-p"],
frequency_penalty=self.config["frequency-penalty"]) presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer_obj = result.choices[0].message answer_obj = result.choices[0].message
answer = {'content': answer_obj.content, 'role': answer_obj.role} answer = {'content': answer_obj.content, 'role': answer_obj.role}
self.rate_limit_backoff = exponential_backoff() self.rate_limit_backoff = exponential_backoff()
@ -307,10 +349,11 @@ class AIResponder(object):
messages = [{"role": "system", "content": self.config["fix-description"]}, messages = [{"role": "system", "content": self.config["fix-description"]},
{"role": "user", "content": answer}] {"role": "user", "content": answer}]
try: try:
result = await self.client.chat.completions.create(model=self.config["fix-model"], result = await openai_chat(self.client,
messages=messages, model=self.config["fix-model"],
temperature=0.2, messages=messages,
max_tokens=2048) temperature=0.2,
max_tokens=2048)
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}") logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
response = result.choices[0].message.content response = result.choices[0].message.content
start, end = response.find("{"), response.rfind("}") start, end = response.find("{"), response.rfind("}")
@ -331,10 +374,11 @@ class AIResponder(object):
f" if it is not already in {language}, otherwise you just copy it."}, f" if it is not already in {language}, otherwise you just copy it."},
{"role": "user", "content": text}] {"role": "user", "content": text}]
try: try:
result = await self.client.chat.completions.create(model=self.config["fix-model"], result = await openai_chat(self.client,
messages=message, model=self.config["fix-model"],
temperature=0.2, messages=message,
max_tokens=2048) temperature=0.2,
max_tokens=2048)
response = result.choices[0].message.content response = result.choices[0].message.content
logging.info(f"got this translated message:\n{pp(response)}") logging.info(f"got this translated message:\n{pp(response)}")
return response return response

BIN
openai_chat.dat Normal file

Binary file not shown.

View File

@ -11,7 +11,7 @@ class TestAIResponder(TestBotBase):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.system = r""" self.system = r"""
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}. You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
Every message from users is a dictionary in JSON format with the following fields: Every message from users is a dictionary in JSON format with the following fields:
1. `user`: name of the user who wrote the message. 1. `user`: name of the user who wrote the message.
@ -32,7 +32,7 @@ You always try to say something positive about the current day and the Fjærkroa
self.config_data["system"] = self.system self.config_data["system"] = self.system
def assertAIResponse(self, resp1, resp2, def assertAIResponse(self, resp1, resp2,
acmp=lambda a, b: type(a) == str or len(a) > 10, acmp=lambda a, b: type(a) == str and len(a) > 10,
scmp=lambda a, b: a == b, scmp=lambda a, b: a == b,
pcmp=lambda a, b: a == b): pcmp=lambda a, b: a == b):
self.assertEqual(acmp(resp1.answer, resp2.answer), True) self.assertEqual(acmp(resp1.answer, resp2.answer), True)

View File

@ -1,12 +1,7 @@
import os import os
import unittest import unittest
import aiohttp
import json
import toml import toml
import openai from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open
import logging
import pytest
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
from fjerkroa_bot import FjerkroaBot from fjerkroa_bot import FjerkroaBot
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
from discord import User, Message, TextChannel from discord import User, Message, TextChannel
@ -90,75 +85,27 @@ class TestFunctionality(TestBotBase):
async def test_message_lings(self) -> None: async def test_message_lings(self) -> None:
request = AIMessage('Lala', 'Hello there!', 'chat', False,) request = AIMessage('Lala', 'Hello there!', 'chat', False,)
message = {'answer': 'Test [Link](https://www.example.com/test)', message = {'answer': 'Test [Link](https://www.example.com/test)',
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test https://www.example.com/test', True, None, None, None, False) expected = AIResponse('Test https://www.example.com/test', True, 'chat', None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
message = {'answer': 'Test @[Link](https://www.example.com/test)', message = {'answer': 'Test @[Link](https://www.example.com/test)',
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test Link', True, None, None, None, False) expected = AIResponse('Test Link', True, 'chat', None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala', message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, None, False) expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, 'chat', None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
async def test_on_message_event(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': None,
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
async def test_on_message_stort_path(self) -> None: async def test_on_message_stort_path(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'channel': None,
'staff': None,
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?") message = self.create_message("Hello there! How are you?")
message.author.name = 'madeup_name' message.author.name = 'madeup_name'
message.channel.name = 'some_channel' # type: ignore message.channel.name = 'some_channel' # type: ignore
self.bot.config['short-path'] = [[r'some.*', r'madeup.*']] self.bot.config['short-path'] = [[r'some.*', r'madeup.*']]
with patch.object(openai.ChatCompletion, 'acreate', new=acreate): await self.bot.on_message(message)
await self.bot.on_message(message) self.assertEqual(self.bot.airesponder.history[-1]["content"],
self.assertEqual(self.bot.airesponder.history[-1]["content"], '{"user": "madeup_name", "message": "Hello, how are you?",'
'{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}') ' "channel": "some_channel", "direct": false, "historise_question": true}')
message.author.name = 'different_name'
await self.bot.on_message(message)
self.assertEqual(self.bot.airesponder.history[-2]["content"],
'{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
async def test_on_message_event2(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'channel': None,
'staff': 'Hallo staff',
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
async def test_on_message_event3(self) -> None:
async def acreate(*a, **kw):
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
with pytest.raises(RuntimeError, match="Failed to generate answer after multiple retries"):
await self.bot.on_message(message)
@patch("builtins.open", new_callable=mock_open) @patch("builtins.open", new_callable=mock_open)
def test_update_history_with_file(self, mock_file): def test_update_history_with_file(self, mock_file):
@ -172,41 +119,6 @@ class TestFunctionality(TestBotBase):
mock_file.assert_called_once_with("mock_file.pkl", "wb") mock_file.assert_called_once_with("mock_file.pkl", "wb")
mock_file().write.assert_called_once() mock_file().write.assert_called_once()
async def test_on_message_event4(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': 'none',
'picture': 'Some picture',
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
async def adraw(*a, **kw):
return {'data': [{'url': 'http:url'}]}
def logging_warning(msg):
raise RuntimeError(msg)
class image:
def __init__(self, *args, **kw):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
return False
async def read(self):
return b'test bytes'
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(openai.Image, 'acreate', new=adraw), \
patch.object(logging, 'warning', logging_warning), \
patch.object(aiohttp.ClientSession, 'get', new=image):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", files=[ANY], suppress_embeds=True) # type: ignore
if __name__ == "__mait__": if __name__ == "__mait__":
unittest.main() unittest.main()