Compare commits

..

4 Commits

9 changed files with 346 additions and 252 deletions

View File

@ -9,3 +9,11 @@ repos:
rev: 6.0.0 rev: 6.0.0
hooks: hooks:
- id: flake8 - id: flake8
- repo: local
hooks:
- id: pytest
name: pytest
entry: pytest
language: system
pass_filenames: false

View File

@ -1,10 +1,7 @@
import os import os
import json import json
import asyncio
import random import random
import multiline import multiline
import openai
import aiohttp
import logging import logging
import time import time
import re import re
@ -12,7 +9,7 @@ import pickle
from pathlib import Path from pathlib import Path
from io import BytesIO from io import BytesIO
from pprint import pformat from pprint import pformat
from functools import lru_cache from functools import lru_cache, wraps
from typing import Optional, List, Dict, Any, Tuple from typing import Optional, List, Dict, Any, Tuple
@ -57,6 +54,33 @@ def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts
raise RuntimeError("Max attempts reached in exponential backoff.") raise RuntimeError("Max attempts reached in exponential backoff.")
def async_cache_to_file(filename):
    """Decorator factory that memoises an async function's results in a pickle file.

    Caching is active only when *filename* already exists when the module is
    loaded (a record/replay mechanism: commit the file to replay responses,
    remove it to call through).  When the file is absent the wrapper is a
    transparent passthrough and never creates the file.  The cache key is
    built from the function name, all positional arguments except the first
    (assumed to be a client/handle), and the keyword arguments.
    """
    cache_path = Path(filename)
    stored = None
    if cache_path.exists():
        try:
            with cache_path.open('rb') as fh:
                stored = pickle.load(fh)
        except Exception:
            # Unreadable/corrupt cache file: start over with an empty cache.
            stored = {}

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            if stored is None:
                # No cache file present at load time: plain passthrough.
                return await func(*args, **kwargs)
            cache_key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
            try:
                return stored[cache_key]
            except KeyError:
                pass
            value = await func(*args, **kwargs)
            stored[cache_key] = value
            # Persist the whole cache after every miss.
            with cache_path.open('wb') as fh:
                pickle.dump(stored, fh)
            return value
        return wrapper
    return decorator
def parse_maybe_json(json_string): def parse_maybe_json(json_string):
if json_string is None: if json_string is None:
return None return None
@ -115,12 +139,17 @@ class AIResponse(AIMessageBase):
self.hack = hack self.hack = hack
class AIResponder(object): class AIResponderBase(object):
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None: def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
super().__init__()
self.config = config self.config = config
self.history: List[Dict[str, Any]] = []
self.channel = channel if channel is not None else 'system' self.channel = channel if channel is not None else 'system'
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
class AIResponder(AIResponderBase):
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
super().__init__(config, channel)
self.history: List[Dict[str, Any]] = []
self.rate_limit_backoff = exponential_backoff() self.rate_limit_backoff = exponential_backoff()
self.history_file: Optional[Path] = None self.history_file: Optional[Path] = None
if 'history-directory' in self.config: if 'history-directory' in self.config:
@ -129,7 +158,7 @@ class AIResponder(object):
with open(self.history_file, 'rb') as fd: with open(self.history_file, 'rb') as fd:
self.history = pickle.load(fd) self.history = pickle.load(fd)
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]: def message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
messages = [] messages = []
system = self.config.get(self.channel, self.config['system']) system = self.config.get(self.channel, self.config['system'])
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\ system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
@ -150,79 +179,14 @@ class AIResponder(object):
async def draw(self, description: str) -> BytesIO: async def draw(self, description: str) -> BytesIO:
if self.config.get('leonardo-token') is not None: if self.config.get('leonardo-token') is not None:
return await self._draw_leonardo(description) return await self.draw_leonardo(description)
return await self._draw_openai(description) return await self.draw_openai(description)
async def _draw_openai(self, description: str) -> BytesIO: async def draw_leonardo(self, description: str) -> BytesIO:
for _ in range(3): raise NotImplementedError()
try:
response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3")
async with aiohttp.ClientSession() as session:
async with session.get(response.data[0].url) as image:
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
return BytesIO(await image.read())
except Exception as err:
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
async def _draw_leonardo(self, description: str) -> BytesIO: async def draw_openai(self, description: str) -> BytesIO:
error_backoff = exponential_backoff(max_attempts=12) raise NotImplementedError()
generation_id = None
image_url = None
image_bytes = None
while True:
error_sleep = next(error_backoff)
try:
async with aiohttp.ClientSession() as session:
if generation_id is None:
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
json={"prompt": description,
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
"num_images": 1,
"sd_version": "v2",
"promptMagic": True,
"unzoomAmount": 1,
"width": 512,
"height": 512},
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json",
"Content-Type": "application/json"},
) as response:
response = await response.json()
if "sdGenerationJob" not in response:
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
await asyncio.sleep(error_sleep)
continue
generation_id = response["sdGenerationJob"]["generationId"]
if image_url is None:
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json"},
) as response:
response = await response.json()
if "generations_by_pk" not in response:
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
await asyncio.sleep(error_sleep)
continue
if len(response["generations_by_pk"]["generated_images"]) == 0:
await asyncio.sleep(error_sleep)
continue
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
if image_bytes is None:
async with session.get(image_url) as response:
image_bytes = BytesIO(await response.read())
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
) as response:
await response.json()
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
return image_bytes
except Exception as err:
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
else:
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
await asyncio.sleep(error_sleep)
raise RuntimeError(f"Failed to generate image {repr(description)}")
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse: async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'): for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
@ -270,77 +234,14 @@ class AIResponder(object):
return True return True
return False return False
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]: async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
model = self.config["model"] raise NotImplementedError()
try:
result = await self.client.chat.completions.create(model=model,
messages=messages,
temperature=self.config["temperature"],
max_tokens=self.config["max-tokens"],
top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer_obj = result.choices[0].message
answer = {'content': answer_obj.content, 'role': answer_obj.role}
self.rate_limit_backoff = exponential_backoff()
logging.info(f"generated response {result.usage}: {repr(answer)}")
return answer, limit
except openai.BadRequestError as err:
if 'maximum context length is' in str(err) and limit > 4:
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
limit -= 1
return None, limit
raise err
except openai.RateLimitError as err:
rate_limit_sleep = next(self.rate_limit_backoff)
if "retry-model" in self.config:
model = self.config["retry-model"]
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
await asyncio.sleep(rate_limit_sleep)
except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}")
return None, limit
async def fix(self, answer: str) -> str: async def fix(self, answer: str) -> str:
if 'fix-model' not in self.config: raise NotImplementedError()
return answer
messages = [{"role": "system", "content": self.config["fix-description"]},
{"role": "user", "content": answer}]
try:
result = await self.client.chat.completions.create(model=self.config["fix-model"],
messages=messages,
temperature=0.2,
max_tokens=2048)
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
response = result.choices[0].message.content
start, end = response.find("{"), response.rfind("}")
if start == -1 or end == -1 or (start + 3) >= end:
return answer
response = response[start:end + 1]
logging.info(f"fixed answer:\n{pp(response)}")
return response
except Exception as err:
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
return answer
async def translate(self, text: str, language: str = "english") -> str: async def translate(self, text: str, language: str = "english") -> str:
if 'fix-model' not in self.config: raise NotImplementedError()
return text
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
f" you translate everything you get directly to {language}"
f" if it is not already in {language}, otherwise you just copy it."},
{"role": "user", "content": text}]
try:
result = await self.client.chat.completions.create(model=self.config["fix-model"],
messages=message,
temperature=0.2,
max_tokens=2048)
response = result.choices[0].message.content
logging.info(f"got this translated message:\n{pp(response)}")
return response
except Exception as err:
logging.warning(f"failed to translate the text: {repr(err)}")
return text
def shrink_history_by_one(self, index: int = 0) -> None: def shrink_history_by_one(self, index: int = 0) -> None:
if index >= len(self.history): if index >= len(self.history):
@ -380,11 +281,11 @@ class AIResponder(object):
while retries > 0: while retries > 0:
# Get the message queue # Get the message queue
messages = self._message(message, limit) messages = self.message(message, limit)
logging.info(f"try to send this messages:\n{pp(messages)}") logging.info(f"try to send this messages:\n{pp(messages)}")
# Attempt to send the message to the AI # Attempt to send the message to the AI
answer, limit = await self._acreate(messages, limit) answer, limit = await self.chat(messages, limit)
if answer is None: if answer is None:
continue continue

View File

@ -12,7 +12,8 @@ from discord import Message, TextChannel, DMChannel
from discord.ext import commands from discord.ext import commands
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler from watchdog.events import FileSystemEventHandler
from .ai_responder import AIResponder, AIMessage from .ai_responder import AIMessage
from .openai_responder import OpenAIResponder
from typing import Optional, Union from typing import Optional, Union
@ -45,8 +46,8 @@ class FjerkroaBot(commands.Bot):
self.observer.start() self.observer.start()
def init_aichannels(self): def init_aichannels(self):
self.airesponder = AIResponder(self.config) self.airesponder = OpenAIResponder(self.config)
self.aichannels = {chan_name: AIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']} self.aichannels = {chan_name: OpenAIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
def init_channels(self): def init_channels(self):
if 'chat-channel' in self.config: if 'chat-channel' in self.config:

80
fjerkroa_bot/igdblib.py Normal file
View File

@ -0,0 +1,80 @@
import requests
from functools import cache
class IGDBQuery(object):
    """Small client for the IGDB v4 REST API.

    Queries are written in IGDB's APIcalypse syntax and POSTed as the
    request body.
    """

    def __init__(self, client_id, igdb_api_key):
        # Twitch client id plus an OAuth bearer token, both required by IGDB.
        self.client_id = client_id
        self.igdb_api_key = igdb_api_key

    def send_igdb_request(self, endpoint, query_body):
        """POST *query_body* (APIcalypse text) to *endpoint*.

        Returns the decoded JSON response, or None on any request error.
        """
        igdb_url = f'https://api.igdb.com/v4/{endpoint}'
        headers = {
            'Client-ID': self.client_id,
            'Authorization': f'Bearer {self.igdb_api_key}'
        }
        try:
            response = requests.post(igdb_url, headers=headers, data=query_body)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            print(f"Error during IGDB API request: {e}")
            return None

    @staticmethod
    def build_query(fields, filters=None, limit=10, offset=None):
        """Build an APIcalypse query string.

        fields: list of field names (None/empty selects '*').
        filters: mapping of field -> raw condition fragment, joined with '&'.
        """
        query = f"fields {','.join(fields) if fields is not None and len(fields) > 0 else '*'}; limit {limit};"
        if offset is not None:
            query += f' offset {offset};'
        if filters:
            filter_statements = [f"{key} {value}" for key, value in filters.items()]
            query += " where " + " & ".join(filter_statements) + ";"
        return query

    def generalized_igdb_query(self, params, endpoint, fields, additional_filters=None, limit=10, offset=None):
        """Run a fuzzy-match query: each non-empty value in *params* becomes a
        case-insensitive prefix filter (~ "value"*)."""
        all_filters = {key: f'~ "{value}"*' for key, value in params.items() if value}
        if additional_filters:
            all_filters.update(additional_filters)
        query = self.build_query(fields, all_filters, limit, offset)
        data = self.send_igdb_request(endpoint, query)
        print(f'{endpoint}: {query} -> {data}')
        return data

    def create_query_function(self, name, description, parameters, endpoint, fields, additional_filters=None, limit=10):
        """Package a query as an OpenAI-style function-calling descriptor with
        a ready-to-call implementation under the 'function' key."""
        return {
            "name": name,
            "description": description,
            "parameters": {"type": "object", "properties": parameters},
            "function": lambda params: self.generalized_igdb_query(params, endpoint, fields, additional_filters, limit)
        }

    # NOTE(review): @cache on instance methods keys on self and keeps every
    # instance alive for the process lifetime (ruff B019) — acceptable only
    # while few long-lived IGDBQuery objects exist.
    @cache
    def platform_families(self):
        """Return {family_id: family_name}, fetched once and memoised."""
        families = self.generalized_igdb_query({}, 'platform_families', ['id', 'name'], limit=500)
        # Guard against a failed request (send_igdb_request returns None).
        return {v['id']: v['name'] for v in families or []}

    @cache
    def platforms(self):
        """Return {platform_id: {'names': [...], 'family': name-or-None}}."""
        platforms = self.generalized_igdb_query({}, 'platforms',
                                                ['id', 'name', 'alternative_name', 'abbreviation', 'platform_family'],
                                                limit=500)
        ret = {}
        for p in platforms or []:
            # Bug fix: 'names' was the bare name string, so the .append()
            # calls below raised AttributeError; collect names in a list.
            names = [p['name']]
            if 'alternative_name' in p:
                names.append(p['alternative_name'])
            if 'abbreviation' in p:
                names.append(p['abbreviation'])
            # Bug fix: the family table must be indexed by the platform's
            # 'platform_family' id, not by the platform's own id; .get()
            # tolerates an id missing from the family table.
            family = self.platform_families().get(p['platform_family']) if 'platform_family' in p else None
            ret[p['id']] = {'names': names, 'family': family}
        return ret

    def game_info(self, name):
        """Fuzzy-search games by *name* and return the raw IGDB result list."""
        game_info = self.generalized_igdb_query({'name': name},
                                                ['id', 'name', 'alternative_names', 'category',
                                                 'release_dates', 'franchise', 'language_supports',
                                                 'keywords', 'platforms', 'rating', 'summary'],
                                                limit=100)
        return game_info

View File

@ -0,0 +1,66 @@
import logging
import asyncio
import aiohttp
from .ai_responder import exponential_backoff, AIResponderBase
from io import BytesIO
class LeonardoAIDrawMixIn(AIResponderBase):
    """Mix-in adding image generation through the Leonardo.ai REST API.

    Expects ``self.config['leonardo-token']`` to hold the API token.
    """

    async def draw_leonardo(self, description: str) -> BytesIO:
        """Generate one 512x512 image for *description* and return it as BytesIO.

        Flow: POST a generation job, poll the job until an image URL appears,
        download the image, then DELETE the job.  Progress is kept in
        generation_id / image_url / image_bytes so a failed step is retried
        without repeating the steps that already succeeded.

        Raises RuntimeError (from exponential_backoff) once the retry budget
        of 12 attempts is exhausted.
        """
        error_backoff = exponential_backoff(max_attempts=12)
        generation_id = None   # job id from the initial POST, once known
        image_url = None       # URL of the first generated image, once known
        image_bytes = None     # downloaded image payload, once known
        while True:
            # Raises RuntimeError when the 12 attempts are used up.
            error_sleep = next(error_backoff)
            try:
                async with aiohttp.ClientSession() as session:
                    if generation_id is None:
                        # Step 1: submit the generation job.
                        async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
                                                json={"prompt": description,
                                                      "modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
                                                      "num_images": 1,
                                                      "sd_version": "v2",
                                                      "promptMagic": True,
                                                      "unzoomAmount": 1,
                                                      "width": 512,
                                                      "height": 512},
                                                headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
                                                         "Accept": "application/json",
                                                         "Content-Type": "application/json"},
                                                ) as response:
                            response = await response.json()
                            if "sdGenerationJob" not in response:
                                logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
                                await asyncio.sleep(error_sleep)
                                continue
                            generation_id = response["sdGenerationJob"]["generationId"]
                    if image_url is None:
                        # Step 2: poll the job until a generated image shows up.
                        async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
                                               headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
                                                        "Accept": "application/json"},
                                               ) as response:
                            response = await response.json()
                            if "generations_by_pk" not in response:
                                logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
                                await asyncio.sleep(error_sleep)
                                continue
                            if len(response["generations_by_pk"]["generated_images"]) == 0:
                                # Job accepted but not finished yet; wait and re-poll.
                                await asyncio.sleep(error_sleep)
                                continue
                            image_url = response["generations_by_pk"]["generated_images"][0]["url"]
                    if image_bytes is None:
                        # Step 3: download the image, then delete the job server-side.
                        async with session.get(image_url) as response:
                            image_bytes = BytesIO(await response.read())
                        async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
                                                  headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
                                                  ) as response:
                            await response.json()
                        logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
                        return image_bytes
            except Exception as err:
                logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
            else:
                # NOTE(review): this `else` only runs when the try body falls
                # through without return/continue/raise — i.e. when image_bytes
                # was already set by a previous iteration whose DELETE call
                # failed.  The loop then spins until the backoff raises.
                # Confirm this is the intended handling of that edge case.
                logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
            await asyncio.sleep(error_sleep)
        # NOTE(review): unreachable — the `while True` above only exits via
        # return or an exception from next(error_backoff).
        raise RuntimeError(f"Failed to generate image {repr(description)}")

View File

@ -0,0 +1,112 @@
import openai
import aiohttp
import logging
import asyncio
from .ai_responder import AIResponder, async_cache_to_file, exponential_backoff, pp
from .leonardo_draw import LeonardoAIDrawMixIn
from io import BytesIO
from typing import Dict, Any, Optional, List, Tuple
@async_cache_to_file('openai_chat.dat')
async def openai_chat(client, *args, **kwargs):
    """Call the OpenAI chat-completions endpoint through *client*.

    Results are replayed from 'openai_chat.dat' when that cache file exists
    (see async_cache_to_file); *client* is excluded from the cache key.
    """
    return await client.chat.completions.create(*args, **kwargs)
@async_cache_to_file('openai_chat.dat')
async def openai_image(client, *args, **kwargs):
    """Generate an image via the OpenAI images endpoint and download it.

    Returns the raw image bytes wrapped in a BytesIO.

    NOTE(review): this shares 'openai_chat.dat' with openai_chat above, but
    each decorator holds an independent in-memory copy of the cache and
    rewrites the whole file on every miss, so one function's writes can
    discard the other's entries — this probably wants its own cache file.
    NOTE(review): a BytesIO return value may not survive pickle.dump —
    confirm the cache-write path is ever exercised for this function.
    """
    response = await client.images.generate(*args, **kwargs)
    async with aiohttp.ClientSession() as session:
        async with session.get(response.data[0].url) as image:
            return BytesIO(await image.read())
class OpenAIResponder(AIResponder, LeonardoAIDrawMixIn):
    """AIResponder backed by the OpenAI API, with Leonardo.ai drawing mixed in.

    All OpenAI calls go through the module-level openai_chat/openai_image
    wrappers so they can be replayed from the pickle cache in tests.
    """

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        super().__init__(config, channel)
        # Async client authenticated with the token from the bot config.
        self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])

    async def draw_openai(self, description: str) -> BytesIO:
        """Render *description* with DALL-E 3; up to 3 attempts, then RuntimeError."""
        for _ in range(3):
            try:
                response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
                logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
                return response
            except Exception as err:
                # Best effort: log and retry; failure only surfaces after the loop.
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")

    async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Send *messages* to the configured chat model.

        Returns (answer, limit): *answer* is {'content', 'role'} on success or
        None when the caller should retry; *limit* is the (possibly reduced)
        history limit the caller should use for that retry.
        """
        model = self.config["model"]
        try:
            result = await openai_chat(self.client,
                                       model=model,
                                       messages=messages,
                                       temperature=self.config["temperature"],
                                       max_tokens=self.config["max-tokens"],
                                       top_p=self.config["top-p"],
                                       presence_penalty=self.config["presence-penalty"],
                                       frequency_penalty=self.config["frequency-penalty"])
            answer_obj = result.choices[0].message
            answer = {'content': answer_obj.content, 'role': answer_obj.role}
            # Success: reset the rate-limit backoff for the next 429.
            self.rate_limit_backoff = exponential_backoff()
            logging.info(f"generated response {result.usage}: {repr(answer)}")
            return answer, limit
        except openai.BadRequestError as err:
            if 'maximum context length is' in str(err) and limit > 4:
                # Context overflow: shrink the history window, let caller retry.
                logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
                limit -= 1
                return None, limit
            raise err
        except openai.RateLimitError as err:
            rate_limit_sleep = next(self.rate_limit_backoff)
            if "retry-model" in self.config:
                # NOTE(review): dead store — chat() returns right after the
                # sleep below and the caller re-reads config["model"] on the
                # next call, so 'retry-model' is never used; confirm intent.
                model = self.config["retry-model"]
            logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
            await asyncio.sleep(rate_limit_sleep)
        except Exception as err:
            logging.warning(f"failed to generate response: {repr(err)}")
        return None, limit

    async def fix(self, answer: str) -> str:
        """Ask the configured 'fix-model' to repair *answer* into valid JSON.

        Returns the substring between the outermost braces of the model's
        reply, or *answer* unchanged when no fix-model is configured, no
        plausible JSON object is found, or the call fails.
        """
        if 'fix-model' not in self.config:
            return answer
        messages = [{"role": "system", "content": self.config["fix-description"]},
                    {"role": "user", "content": answer}]
        try:
            result = await openai_chat(self.client,
                                       model=self.config["fix-model"],
                                       messages=messages,
                                       temperature=0.2,
                                       max_tokens=2048)
            logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
            response = result.choices[0].message.content
            # Keep only the outermost {...} span; shorter than a minimal
            # object means the model produced nothing usable.
            start, end = response.find("{"), response.rfind("}")
            if start == -1 or end == -1 or (start + 3) >= end:
                return answer
            response = response[start:end + 1]
            logging.info(f"fixed answer:\n{pp(response)}")
            return response
        except Exception as err:
            logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
            return answer

    async def translate(self, text: str, language: str = "english") -> str:
        """Translate *text* to *language* via the 'fix-model'; returns *text*
        unchanged when no fix-model is configured or the call fails."""
        if 'fix-model' not in self.config:
            return text
        message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
                                                 f" you translate everything you get directly to {language}"
                                                 f" if it is not already in {language}, otherwise you just copy it."},
                   {"role": "user", "content": text}]
        try:
            result = await openai_chat(self.client,
                                       model=self.config["fix-model"],
                                       messages=message,
                                       temperature=0.2,
                                       max_tokens=2048)
            response = result.choices[0].message.content
            logging.info(f"got this translated message:\n{pp(response)}")
            return response
        except Exception as err:
            logging.warning(f"failed to translate the text: {repr(err)}")
            return text

BIN
openai_chat.dat Normal file

Binary file not shown.

View File

@ -11,7 +11,7 @@ class TestAIResponder(TestBotBase):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.system = r""" self.system = r"""
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}. You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
Every message from users is a dictionary in JSON format with the following fields: Every message from users is a dictionary in JSON format with the following fields:
1. `user`: name of the user who wrote the message. 1. `user`: name of the user who wrote the message.
@ -32,7 +32,7 @@ You always try to say something positive about the current day and the Fjærkroa
self.config_data["system"] = self.system self.config_data["system"] = self.system
def assertAIResponse(self, resp1, resp2, def assertAIResponse(self, resp1, resp2,
acmp=lambda a, b: type(a) == str or len(a) > 10, acmp=lambda a, b: type(a) == str and len(a) > 10,
scmp=lambda a, b: a == b, scmp=lambda a, b: a == b,
pcmp=lambda a, b: a == b): pcmp=lambda a, b: a == b):
self.assertEqual(acmp(resp1.answer, resp2.answer), True) self.assertEqual(acmp(resp1.answer, resp2.answer), True)
@ -45,6 +45,20 @@ You always try to say something positive about the current day and the Fjærkroa
print(f"\n{response}") print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, None, None, None, False)) self.assertAIResponse(response, AIResponse('test', True, None, None, None, False))
async def test_picture1(self) -> None:
response = await self.bot.airesponder.send(AIMessage("lala", "draw me a picture of you."))
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', False, None, None, "I am an anime girl with long pink hair, wearing a cute cafe uniform and holding a tray with a cup of coffee on it. I have a warm and friendly smile on my face.", False))
image = await self.bot.airesponder.draw(response.picture)
self.assertEqual(image.read()[:len(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')], b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')
async def test_translate1(self) -> None:
self.bot.airesponder.config['fix-model'] = 'gpt-3.5-turbo'
response = await self.bot.airesponder.translate('Das ist ein komischer Text.')
self.assertEqual(response, 'This is a strange text.')
response = await self.bot.airesponder.translate('This is a strange text.', language='german')
self.assertEqual(response, 'Dies ist ein seltsamer Text.')
async def test_fix1(self) -> None: async def test_fix1(self) -> None:
old_config = self.bot.airesponder.config old_config = self.bot.airesponder.config
config = {k: v for k, v in old_config.items()} config = {k: v for k, v in old_config.items()}

View File

@ -1,12 +1,7 @@
import os import os
import unittest import unittest
import aiohttp
import json
import toml import toml
import openai from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open
import logging
import pytest
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
from fjerkroa_bot import FjerkroaBot from fjerkroa_bot import FjerkroaBot
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
from discord import User, Message, TextChannel from discord import User, Message, TextChannel
@ -90,75 +85,27 @@ class TestFunctionality(TestBotBase):
async def test_message_lings(self) -> None: async def test_message_lings(self) -> None:
request = AIMessage('Lala', 'Hello there!', 'chat', False,) request = AIMessage('Lala', 'Hello there!', 'chat', False,)
message = {'answer': 'Test [Link](https://www.example.com/test)', message = {'answer': 'Test [Link](https://www.example.com/test)',
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test https://www.example.com/test', True, None, None, None, False) expected = AIResponse('Test https://www.example.com/test', True, 'chat', None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
message = {'answer': 'Test @[Link](https://www.example.com/test)', message = {'answer': 'Test @[Link](https://www.example.com/test)',
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test Link', True, None, None, None, False) expected = AIResponse('Test Link', True, 'chat', None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala', message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False} 'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, None, False) expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, 'chat', None, None, False)
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected)) self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
async def test_on_message_event(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': None,
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
async def test_on_message_stort_path(self) -> None: async def test_on_message_stort_path(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'channel': None,
'staff': None,
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?") message = self.create_message("Hello there! How are you?")
message.author.name = 'madeup_name' message.author.name = 'madeup_name'
message.channel.name = 'some_channel' # type: ignore message.channel.name = 'some_channel' # type: ignore
self.bot.config['short-path'] = [[r'some.*', r'madeup.*']] self.bot.config['short-path'] = [[r'some.*', r'madeup.*']]
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message) await self.bot.on_message(message)
self.assertEqual(self.bot.airesponder.history[-1]["content"], self.assertEqual(self.bot.airesponder.history[-1]["content"],
'{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}') '{"user": "madeup_name", "message": "Hello, how are you?",'
message.author.name = 'different_name' ' "channel": "some_channel", "direct": false, "historise_question": true}')
await self.bot.on_message(message)
self.assertEqual(self.bot.airesponder.history[-2]["content"],
'{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
async def test_on_message_event2(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'channel': None,
'staff': 'Hallo staff',
'picture': None,
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
async def test_on_message_event3(self) -> None:
async def acreate(*a, **kw):
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
with pytest.raises(RuntimeError, match="Failed to generate answer after multiple retries"):
await self.bot.on_message(message)
@patch("builtins.open", new_callable=mock_open) @patch("builtins.open", new_callable=mock_open)
def test_update_history_with_file(self, mock_file): def test_update_history_with_file(self, mock_file):
@ -172,41 +119,6 @@ class TestFunctionality(TestBotBase):
mock_file.assert_called_once_with("mock_file.pkl", "wb") mock_file.assert_called_once_with("mock_file.pkl", "wb")
mock_file().write.assert_called_once() mock_file().write.assert_called_once()
async def test_on_message_event4(self) -> None:
async def acreate(*a, **kw):
answer = {'answer': 'Hello!',
'answer_needed': True,
'staff': 'none',
'picture': 'Some picture',
'hack': False}
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
async def adraw(*a, **kw):
return {'data': [{'url': 'http:url'}]}
def logging_warning(msg):
raise RuntimeError(msg)
class image:
def __init__(self, *args, **kw):
pass
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_value, traceback):
return False
async def read(self):
return b'test bytes'
message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(openai.Image, 'acreate', new=adraw), \
patch.object(logging, 'warning', logging_warning), \
patch.object(aiohttp.ClientSession, 'get', new=image):
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!", files=[ANY], suppress_embeds=True) # type: ignore
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()