From c010603178e89ce329ca3e8a516e5921889a4c9f Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 2 Dec 2023 22:36:00 +0100 Subject: [PATCH] Fix and remove some tests, make openai calls cachable in a file. --- fjerkroa_bot/ai_responder.py | 78 +++++++++++++++++++------ openai_chat.dat | Bin 0 -> 13318 bytes tests/test_ai.py | 4 +- tests/test_main.py | 110 ++++------------------------------- 4 files changed, 74 insertions(+), 118 deletions(-) create mode 100644 openai_chat.dat diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 2002833..a0260e0 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -1,3 +1,4 @@ +import sys import os import json import asyncio @@ -12,7 +13,7 @@ import pickle from pathlib import Path from io import BytesIO from pprint import pformat -from functools import lru_cache +from functools import lru_cache, wraps from typing import Optional, List, Dict, Any, Tuple @@ -57,6 +58,46 @@ def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts raise RuntimeError("Max attempts reached in exponential backoff.") +def async_cache_to_file(filename): + cache_file = Path(filename) + cache = None + if cache_file.exists(): + try: + with cache_file.open('rb') as fd: + cache = pickle.load(fd) + except Exception: + cache = {} + + def decorator(func): + @wraps(func) + async def wrapper(*args, **kwargs): + if cache is None: + sys.stderr.write(f'@@@ forward {func.__name__}({repr(args)}, {repr(kwargs)}') + return await func(*args, **kwargs) + key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True) + if key in cache: + sys.stderr.write(f'@@@ cache {func.__name__}({repr(args)}, {repr(kwargs)} -> {cache[key]}') + return cache[key] + sys.stderr.write(f'@@@ execute {func.__name__}({repr(args)}, {repr(kwargs)}') + result = await func(*args, **kwargs) + cache[key] = result + with cache_file.open('wb') as fd: + pickle.dump(cache, fd) + return result + return wrapper + return decorator + + +@async_cache_to_file('openai_chat.dat') +async def openai_chat(client, *args, **kwargs): + return await client.chat.completions.create(*args, **kwargs) + + +@async_cache_to_file('openai_chat.dat') +async def openai_image(client, *args, **kwargs): + return await client.images.generate(*args, **kwargs) + + def parse_maybe_json(json_string): if json_string is None: return None @@ -156,7 +197,7 @@ class AIResponder(object): async def _draw_openai(self, description: str) -> BytesIO: for _ in range(3): try: - response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3") + response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3") async with aiohttp.ClientSession() as session: async with session.get(response.data[0].url) as image: logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}') @@ -273,13 +314,14 @@ class AIResponder(object): async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]: model = self.config["model"] try: - result = await self.client.chat.completions.create(model=model, - messages=messages, - temperature=self.config["temperature"], - max_tokens=self.config["max-tokens"], - top_p=self.config["top-p"], - presence_penalty=self.config["presence-penalty"], - frequency_penalty=self.config["frequency-penalty"]) + result = await openai_chat(self.client, + model=model, + messages=messages, + temperature=self.config["temperature"], + max_tokens=self.config["max-tokens"], + top_p=self.config["top-p"], + presence_penalty=self.config["presence-penalty"], + frequency_penalty=self.config["frequency-penalty"]) answer_obj = result.choices[0].message answer = {'content': answer_obj.content, 'role': answer_obj.role} self.rate_limit_backoff = exponential_backoff() @@ -307,10 +349,11 @@ class AIResponder(object): messages = [{"role": "system", "content": self.config["fix-description"]}, {"role": "user", "content": answer}] try: - result = await self.client.chat.completions.create(model=self.config["fix-model"], - messages=messages, - temperature=0.2, - max_tokens=2048) + result = await openai_chat(self.client, + model=self.config["fix-model"], + messages=messages, + temperature=0.2, + max_tokens=2048) logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}") response = result.choices[0].message.content start, end = response.find("{"), response.rfind("}") @@ -331,10 +374,11 @@ class AIResponder(object): f" if it is not already in {language}, otherwise you just copy it."}, {"role": "user", "content": text}] try: - result = await self.client.chat.completions.create(model=self.config["fix-model"], - messages=message, - temperature=0.2, - max_tokens=2048) + result = await openai_chat(self.client, + model=self.config["fix-model"], + messages=message, + temperature=0.2, + max_tokens=2048) response = result.choices[0].message.content logging.info(f"got this translated message:\n{pp(response)}") return response diff --git a/openai_chat.dat b/openai_chat.dat new file mode 100644 index 0000000000000000000000000000000000000000..e82bf6fb756b98dd27ea49eb3fb4e8912fc1557a GIT binary patch literal 13318 zcmeHO&2JmW6}JsLj*})q4F~NZ4aP)~#xBLmk}WCjA&^zsmSxGVB|EZBgBmS|oE3L> zHv2(}7KA3(V1dT)7D&_dkOXOZNb*nQU+JYlkRIAwFa5olCAllgA9b4+N$7*6ozFLK ze(!7E(_06B_s_|F`rl6a>e^qP-nZ{|N%_L#cFmGJDwWyoJ7uqMx9R7 zEuCejD&u8Va`{#*QZ3SBAeEWLn0yc1#g-C1h*tp;6}iR7a>-QY%tpNc2!)ddR1?p)@K?2StC$x zwvru!r*+$kY~^tb0lk-2t}N5>+}@Fh)eYr1s$+Xi*06J<_nh%Y#;3b-auQO}_?+gNp>YUf977 zy9mW+;T-z9A40fUrAWhI>VBA^+!OGD|R=qIg2ZD!#zTWm zDt79F^}N_|`UDm_iGReN5K>0yrR5n&rM5i%AmB6z2qWHTV0+#_7s#fu1^Hxp7=vV( z6Rh|5_kd6W58!6Iz6v4^aljHG@6r!V32lMx&R{zqfQ$n{*P~cH;2mt7wbAMkZ;`I> zrHWLU)kSCp_6B4vVww{;1_vntZAjb~5w{(18o&E04DAg^lxa)Df!?|u+i@6HX+oI< zoNkN9k%AD|7Iz$IkaZ6Ns=4W@J~bIr^Yd_qDV#xatAzt(oX}7y6IQsEJP^||N`%hd zumeB_C*xk1`GIKLDh9C-Q(rgZERbmGgkq$HD{P^`z|`y<*(L@)*uUygKk0%ERqtZq zy4D(`8|WH92Z1t#F7yG+M(pTyjA&<+(+~sjG7+ZHC7^jhE(xE}4^SG=tkQ-GvNZDO z7W(%~Dw0mO2SzfwJAn-qfqTo)sC7#ZZ51|)8_8&ILP|&n6J|1?RK1%dOGYb1&|Mjx zvUJ8^W$P|a&FcNNQhpcB<(O#fgOq({E3jP z5p|Q)?BuB<8cI8GJxZ>s3x^19`q9{A<;)lq#)hJ?xU~>2i5PN8a4w%j8=3_Z$1AU4 zaHRa2k1seaJ9pCLxN&kSQP+o+CW`fcH5h4FUm`0_o}WYW+@RrpdgARAZc_4etwxr# zR?}Yx?0TBKLQ7bfJ2rFe+R6n>p5Ag6h1-sug{ZZ9{?^R={M^i&r_$uGC6#T7Fujw? z1Ibf!`YY`%Ef3TB&oyk%4rMKX>p+AY1k3z1Ib?fvv6U|FOAcpNG)+!C@*K7P*cXfE zSu#%AuDuUuSV)t{efjG_<7a(dsBt5aQSA#tViSWD#u84BnGjLj&0%G-%B+ z_{;=}&Oxf3RwCVsZw<>GaG0bbO^%pdVY5{NQ`t3R|va!5-=nj9?%#Mt=MsmU~XN^NXH)2GP`1jtGr&PR>J z*8r}j2(U9wxttt|$x~00pA>=AM|~~pe_1pqc3-0;N^V3V`Qu!?AKA)d^G*evi*y29m`N#Rb zmy&}-a4KKU3!P-k^7wM@?nz*@6y)0SvLR1g`WiAJPwpuVvOFvXva$>;P2_2rTHG(E z7k@5i7R$M99v_VRk30yBqo{clHSdb$Mp1J<>SUv+nL@=;)JzX7qo`TYa|XiFQPiv- zTfbXT^PGumEChccoAv3DcPS2Si_KnLFfX-N@p_=VZ)r$f@Eemof?^q;Ur`s;+LhUJ z!PMpJt;yLlcNex+SKVSzoOdDrjYzoQULs^NIfr!Z9wLR|K$adhj8AoZMUt zy%Rm;aw=@qJU!z1qn812?2Uk%hr2c&AoY_`gy;)SKAzIE5Eh<(gf|H*%O2=$t}c{W z7~@S150>kGcZ?G###{mglp$IO4uw{Fplesx(`TJW$jwmNU^gRt#6$ELSYekD;NYdH<>7EJUKI1_ z2X0q^DY+NJ;|~hipD5rS1jjAu4h02u*^~{_nh%%{fACEO%!Lpl*(>~A8w!8_J{0~g z=EL8&QU#P_)q10vIm(Sw()a;>*2oth z)y(Od$WhJQ_pN5`f9#77*Y%m_5>uo17|v=nSCIb%9;*p73Qsh>NJgaZ~+Le>-rRIfrZIX|lySq8JveLbvq8q79c31P<>s9mIM>S7cNVmxMl7HkQ zp(_ol5)MMt=3Ul_mW7|8z!&WFv%mcSW0CZr8d+o-sPKBgZC&<6$`NUW3Y+&f@mC)B zHPDe1aw`3@!#%SLS(#}MU40z&D+0Gmy3QCK+q%xRbXFI+5{OWt?BcU`@1>ocau0oj zb&7R^^8G^f(g(}(kNN7QM+Hanqf2KI6v>bGu3q|NkLsmQml1vxs+T@nyeL0kx_Ll; MvGm~+MKE#YUmB^2)&Kwi literal 0 HcmV?d00001 diff --git a/tests/test_ai.py b/tests/test_ai.py index a01fae6..a25a722 100644 --- a/tests/test_ai.py +++ b/tests/test_ai.py @@ -11,7 +11,7 @@ class TestAIResponder(TestBotBase): async def asyncSetUp(self): await super().asyncSetUp() self.system = r""" -You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}. +You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Every message from users is a dictionary in JSON format with the following fields: 1. `user`: name of the user who wrote the message. @@ -32,7 +32,7 @@ You always try to say something positive about the current day and the Fjærkroa self.config_data["system"] = self.system def assertAIResponse(self, resp1, resp2, - acmp=lambda a, b: type(a) == str or len(a) > 10, + acmp=lambda a, b: type(a) == str and len(a) > 10, scmp=lambda a, b: a == b, pcmp=lambda a, b: a == b): self.assertEqual(acmp(resp1.answer, resp2.answer), True) diff --git a/tests/test_main.py b/tests/test_main.py index 9d0a319..7dfe901 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,12 +1,7 @@ import os import unittest -import aiohttp -import json import toml -import openai -import logging -import pytest -from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY +from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open from fjerkroa_bot import FjerkroaBot from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage from discord import User, Message, TextChannel @@ -90,75 +85,27 @@ class TestFunctionality(TestBotBase): async def test_message_lings(self) -> None: request = AIMessage('Lala', 'Hello there!', 'chat', False,) message = {'answer': 'Test [Link](https://www.example.com/test)', - 'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} - expected = AIResponse('Test https://www.example.com/test', True, None, None, None, False) + 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False} + expected = AIResponse('Test https://www.example.com/test', True, 'chat', None, None, False) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) message = {'answer': 'Test @[Link](https://www.example.com/test)', - 'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} - expected = AIResponse('Test Link', True, None, None, None, False) + 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False} + expected = AIResponse('Test Link', True, 'chat', None, None, False) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala', - 'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} - expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, None, False) + 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False} + expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, 'chat', None, None, False) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) - async def test_on_message_event(self) -> None: - async def acreate(*a, **kw): - answer = {'answer': 'Hello!', - 'answer_needed': True, - 'staff': None, - 'picture': None, - 'hack': False} - return {'choices': [{'message': {'content': json.dumps(answer)}}]} - message = self.create_message("Hello there! How are you?") - with patch.object(openai.ChatCompletion, 'acreate', new=acreate): - await self.bot.on_message(message) - message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore - async def test_on_message_stort_path(self) -> None: - async def acreate(*a, **kw): - answer = {'answer': 'Hello!', - 'answer_needed': True, - 'channel': None, - 'staff': None, - 'picture': None, - 'hack': False} - return {'choices': [{'message': {'content': json.dumps(answer)}}]} message = self.create_message("Hello there! How are you?") message.author.name = 'madeup_name' message.channel.name = 'some_channel' # type: ignore self.bot.config['short-path'] = [[r'some.*', r'madeup.*']] - with patch.object(openai.ChatCompletion, 'acreate', new=acreate): - await self.bot.on_message(message) - self.assertEqual(self.bot.airesponder.history[-1]["content"], - '{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}') - message.author.name = 'different_name' - await self.bot.on_message(message) - self.assertEqual(self.bot.airesponder.history[-2]["content"], - '{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}') - message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore - - async def test_on_message_event2(self) -> None: - async def acreate(*a, **kw): - answer = {'answer': 'Hello!', - 'answer_needed': True, - 'channel': None, - 'staff': 'Hallo staff', - 'picture': None, - 'hack': False} - return {'choices': [{'message': {'content': json.dumps(answer)}}]} - message = self.create_message("Hello there! How are you?") - with patch.object(openai.ChatCompletion, 'acreate', new=acreate): - await self.bot.on_message(message) - message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore - - async def test_on_message_event3(self) -> None: - async def acreate(*a, **kw): - return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]} - message = self.create_message("Hello there! How are you?") - with patch.object(openai.ChatCompletion, 'acreate', new=acreate): - with pytest.raises(RuntimeError, match="Failed to generate answer after multiple retries"): - await self.bot.on_message(message) + await self.bot.on_message(message) + self.assertEqual(self.bot.airesponder.history[-1]["content"], + '{"user": "madeup_name", "message": "Hello, how are you?",' + ' "channel": "some_channel", "direct": false, "historise_question": true}') @patch("builtins.open", new_callable=mock_open) def test_update_history_with_file(self, mock_file): @@ -172,41 +119,6 @@ class TestFunctionality(TestBotBase): mock_file.assert_called_once_with("mock_file.pkl", "wb") mock_file().write.assert_called_once() - async def test_on_message_event4(self) -> None: - async def acreate(*a, **kw): - answer = {'answer': 'Hello!', - 'answer_needed': True, - 'staff': 'none', - 'picture': 'Some picture', - 'hack': False} - return {'choices': [{'message': {'content': json.dumps(answer)}}]} - - async def adraw(*a, **kw): - return {'data': [{'url': 'http:url'}]} - - def logging_warning(msg): - raise RuntimeError(msg) - - class image: - def __init__(self, *args, **kw): - pass - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - return False - - async def read(self): - return b'test bytes' - message = self.create_message("Hello there! How are you?") - with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ - patch.object(openai.Image, 'acreate', new=adraw), \ - patch.object(logging, 'warning', logging_warning), \ - patch.object(aiohttp.ClientSession, 'get', new=image): - await self.bot.on_message(message) - message.channel.send.assert_called_once_with("Hello!", files=[ANY], suppress_embeds=True) # type: ignore - if __name__ == "__mait__": unittest.main()