Compare commits
No commits in common. "d6942943b5d94478c871eb81186bec52ebd1552b" and "7bcadecb17c5c30ec76767781f5446d1b3273d9a" have entirely different histories.
d6942943b5
...
7bcadecb17
@ -9,11 +9,3 @@ repos:
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest
|
||||
name: pytest
|
||||
entry: pytest
|
||||
language: system
|
||||
pass_filenames: false
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
import multiline
|
||||
import openai
|
||||
import aiohttp
|
||||
import logging
|
||||
import time
|
||||
import re
|
||||
@ -9,7 +12,7 @@ import pickle
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
from pprint import pformat
|
||||
from functools import lru_cache, wraps
|
||||
from functools import lru_cache
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
|
||||
|
||||
@ -54,33 +57,6 @@ def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts
|
||||
raise RuntimeError("Max attempts reached in exponential backoff.")
|
||||
|
||||
|
||||
def async_cache_to_file(filename):
|
||||
cache_file = Path(filename)
|
||||
cache = None
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with cache_file.open('rb') as fd:
|
||||
cache = pickle.load(fd)
|
||||
except Exception:
|
||||
cache = {}
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
if cache is None:
|
||||
return await func(*args, **kwargs)
|
||||
key = json.dumps((func.__name__, list(args[1:]), kwargs), sort_keys=True)
|
||||
if key in cache:
|
||||
return cache[key]
|
||||
result = await func(*args, **kwargs)
|
||||
cache[key] = result
|
||||
with cache_file.open('wb') as fd:
|
||||
pickle.dump(cache, fd)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def parse_maybe_json(json_string):
|
||||
if json_string is None:
|
||||
return None
|
||||
@ -139,17 +115,12 @@ class AIResponse(AIMessageBase):
|
||||
self.hack = hack
|
||||
|
||||
|
||||
class AIResponderBase(object):
|
||||
class AIResponder(object):
|
||||
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.channel = channel if channel is not None else 'system'
|
||||
|
||||
|
||||
class AIResponder(AIResponderBase):
|
||||
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
||||
super().__init__(config, channel)
|
||||
self.history: List[Dict[str, Any]] = []
|
||||
self.channel = channel if channel is not None else 'system'
|
||||
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
|
||||
self.rate_limit_backoff = exponential_backoff()
|
||||
self.history_file: Optional[Path] = None
|
||||
if 'history-directory' in self.config:
|
||||
@ -158,7 +129,7 @@ class AIResponder(AIResponderBase):
|
||||
with open(self.history_file, 'rb') as fd:
|
||||
self.history = pickle.load(fd)
|
||||
|
||||
def message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
messages = []
|
||||
system = self.config.get(self.channel, self.config['system'])
|
||||
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||
@ -179,14 +150,79 @@ class AIResponder(AIResponderBase):
|
||||
|
||||
async def draw(self, description: str) -> BytesIO:
|
||||
if self.config.get('leonardo-token') is not None:
|
||||
return await self.draw_leonardo(description)
|
||||
return await self.draw_openai(description)
|
||||
return await self._draw_leonardo(description)
|
||||
return await self._draw_openai(description)
|
||||
|
||||
async def draw_leonardo(self, description: str) -> BytesIO:
|
||||
raise NotImplementedError()
|
||||
async def _draw_openai(self, description: str) -> BytesIO:
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response.data[0].url) as image:
|
||||
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
||||
return BytesIO(await image.read())
|
||||
except Exception as err:
|
||||
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||
|
||||
async def draw_openai(self, description: str) -> BytesIO:
|
||||
raise NotImplementedError()
|
||||
async def _draw_leonardo(self, description: str) -> BytesIO:
|
||||
error_backoff = exponential_backoff(max_attempts=12)
|
||||
generation_id = None
|
||||
image_url = None
|
||||
image_bytes = None
|
||||
while True:
|
||||
error_sleep = next(error_backoff)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if generation_id is None:
|
||||
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
||||
json={"prompt": description,
|
||||
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
||||
"num_images": 1,
|
||||
"sd_version": "v2",
|
||||
"promptMagic": True,
|
||||
"unzoomAmount": 1,
|
||||
"width": 512,
|
||||
"height": 512},
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json"},
|
||||
) as response:
|
||||
response = await response.json()
|
||||
if "sdGenerationJob" not in response:
|
||||
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
|
||||
await asyncio.sleep(error_sleep)
|
||||
continue
|
||||
generation_id = response["sdGenerationJob"]["generationId"]
|
||||
if image_url is None:
|
||||
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||
"Accept": "application/json"},
|
||||
) as response:
|
||||
response = await response.json()
|
||||
if "generations_by_pk" not in response:
|
||||
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
|
||||
await asyncio.sleep(error_sleep)
|
||||
continue
|
||||
if len(response["generations_by_pk"]["generated_images"]) == 0:
|
||||
await asyncio.sleep(error_sleep)
|
||||
continue
|
||||
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
|
||||
if image_bytes is None:
|
||||
async with session.get(image_url) as response:
|
||||
image_bytes = BytesIO(await response.read())
|
||||
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
|
||||
) as response:
|
||||
await response.json()
|
||||
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
|
||||
return image_bytes
|
||||
except Exception as err:
|
||||
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
|
||||
else:
|
||||
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
|
||||
await asyncio.sleep(error_sleep)
|
||||
raise RuntimeError(f"Failed to generate image {repr(description)}")
|
||||
|
||||
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
||||
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
|
||||
@ -234,14 +270,77 @@ class AIResponder(AIResponderBase):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||
raise NotImplementedError()
|
||||
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||
model = self.config["model"]
|
||||
try:
|
||||
result = await self.client.chat.completions.create(model=model,
|
||||
messages=messages,
|
||||
temperature=self.config["temperature"],
|
||||
max_tokens=self.config["max-tokens"],
|
||||
top_p=self.config["top-p"],
|
||||
presence_penalty=self.config["presence-penalty"],
|
||||
frequency_penalty=self.config["frequency-penalty"])
|
||||
answer_obj = result.choices[0].message
|
||||
answer = {'content': answer_obj.content, 'role': answer_obj.role}
|
||||
self.rate_limit_backoff = exponential_backoff()
|
||||
logging.info(f"generated response {result.usage}: {repr(answer)}")
|
||||
return answer, limit
|
||||
except openai.BadRequestError as err:
|
||||
if 'maximum context length is' in str(err) and limit > 4:
|
||||
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
||||
limit -= 1
|
||||
return None, limit
|
||||
raise err
|
||||
except openai.RateLimitError as err:
|
||||
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||
if "retry-model" in self.config:
|
||||
model = self.config["retry-model"]
|
||||
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
||||
await asyncio.sleep(rate_limit_sleep)
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to generate response: {repr(err)}")
|
||||
return None, limit
|
||||
|
||||
async def fix(self, answer: str) -> str:
|
||||
raise NotImplementedError()
|
||||
if 'fix-model' not in self.config:
|
||||
return answer
|
||||
messages = [{"role": "system", "content": self.config["fix-description"]},
|
||||
{"role": "user", "content": answer}]
|
||||
try:
|
||||
result = await self.client.chat.completions.create(model=self.config["fix-model"],
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=2048)
|
||||
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
||||
response = result.choices[0].message.content
|
||||
start, end = response.find("{"), response.rfind("}")
|
||||
if start == -1 or end == -1 or (start + 3) >= end:
|
||||
return answer
|
||||
response = response[start:end + 1]
|
||||
logging.info(f"fixed answer:\n{pp(response)}")
|
||||
return response
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
||||
return answer
|
||||
|
||||
async def translate(self, text: str, language: str = "english") -> str:
|
||||
raise NotImplementedError()
|
||||
if 'fix-model' not in self.config:
|
||||
return text
|
||||
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
|
||||
f" you translate everything you get directly to {language}"
|
||||
f" if it is not already in {language}, otherwise you just copy it."},
|
||||
{"role": "user", "content": text}]
|
||||
try:
|
||||
result = await self.client.chat.completions.create(model=self.config["fix-model"],
|
||||
messages=message,
|
||||
temperature=0.2,
|
||||
max_tokens=2048)
|
||||
response = result.choices[0].message.content
|
||||
logging.info(f"got this translated message:\n{pp(response)}")
|
||||
return response
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to translate the text: {repr(err)}")
|
||||
return text
|
||||
|
||||
def shrink_history_by_one(self, index: int = 0) -> None:
|
||||
if index >= len(self.history):
|
||||
@ -281,11 +380,11 @@ class AIResponder(AIResponderBase):
|
||||
|
||||
while retries > 0:
|
||||
# Get the message queue
|
||||
messages = self.message(message, limit)
|
||||
messages = self._message(message, limit)
|
||||
logging.info(f"try to send this messages:\n{pp(messages)}")
|
||||
|
||||
# Attempt to send the message to the AI
|
||||
answer, limit = await self.chat(messages, limit)
|
||||
answer, limit = await self._acreate(messages, limit)
|
||||
|
||||
if answer is None:
|
||||
continue
|
||||
|
||||
@ -12,8 +12,7 @@ from discord import Message, TextChannel, DMChannel
|
||||
from discord.ext import commands
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
from .ai_responder import AIMessage
|
||||
from .openai_responder import OpenAIResponder
|
||||
from .ai_responder import AIResponder, AIMessage
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
@ -46,8 +45,8 @@ class FjerkroaBot(commands.Bot):
|
||||
self.observer.start()
|
||||
|
||||
def init_aichannels(self):
|
||||
self.airesponder = OpenAIResponder(self.config)
|
||||
self.aichannels = {chan_name: OpenAIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
|
||||
self.airesponder = AIResponder(self.config)
|
||||
self.aichannels = {chan_name: AIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
|
||||
|
||||
def init_channels(self):
|
||||
if 'chat-channel' in self.config:
|
||||
|
||||
@ -1,80 +0,0 @@
|
||||
import requests
|
||||
from functools import cache
|
||||
|
||||
|
||||
class IGDBQuery(object):
|
||||
def __init__(self, client_id, igdb_api_key):
|
||||
self.client_id = client_id
|
||||
self.igdb_api_key = igdb_api_key
|
||||
|
||||
def send_igdb_request(self, endpoint, query_body):
|
||||
igdb_url = f'https://api.igdb.com/v4/{endpoint}'
|
||||
headers = {
|
||||
'Client-ID': self.client_id,
|
||||
'Authorization': f'Bearer {self.igdb_api_key}'
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(igdb_url, headers=headers, data=query_body)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.RequestException as e:
|
||||
print(f"Error during IGDB API request: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def build_query(fields, filters=None, limit=10, offset=None):
|
||||
query = f"fields {','.join(fields) if fields is not None and len(fields) > 0 else '*'}; limit {limit};"
|
||||
if offset is not None:
|
||||
query += f' offset {offset};'
|
||||
if filters:
|
||||
filter_statements = [f"{key} {value}" for key, value in filters.items()]
|
||||
query += " where " + " & ".join(filter_statements) + ";"
|
||||
return query
|
||||
|
||||
def generalized_igdb_query(self, params, endpoint, fields, additional_filters=None, limit=10, offset=None):
|
||||
all_filters = {key: f'~ "{value}"*' for key, value in params.items() if value}
|
||||
if additional_filters:
|
||||
all_filters.update(additional_filters)
|
||||
|
||||
query = self.build_query(fields, all_filters, limit, offset)
|
||||
data = self.send_igdb_request(endpoint, query)
|
||||
print(f'{endpoint}: {query} -> {data}')
|
||||
return data
|
||||
|
||||
def create_query_function(self, name, description, parameters, endpoint, fields, additional_filters=None, limit=10):
|
||||
return {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": {"type": "object", "properties": parameters},
|
||||
"function": lambda params: self.generalized_igdb_query(params, endpoint, fields, additional_filters, limit)
|
||||
}
|
||||
|
||||
@cache
|
||||
def platform_families(self):
|
||||
families = self.generalized_igdb_query({}, 'platform_families', ['id', 'name'], limit=500)
|
||||
return {v['id']: v['name'] for v in families}
|
||||
|
||||
@cache
|
||||
def platforms(self):
|
||||
platforms = self.generalized_igdb_query({}, 'platforms',
|
||||
['id', 'name', 'alternative_name', 'abbreviation', 'platform_family'],
|
||||
limit=500)
|
||||
ret = {}
|
||||
for p in platforms:
|
||||
names = p['name']
|
||||
if 'alternative_name' in p:
|
||||
names.append(p['alternative_name'])
|
||||
if 'abbreviation' in p:
|
||||
names.append(p['abbreviation'])
|
||||
family = self.platform_families()[p['id']] if 'platform_family' in p else None
|
||||
ret[p['id']] = {'names': names, 'family': family}
|
||||
return ret
|
||||
|
||||
def game_info(self, name):
|
||||
game_info = self.generalized_igdb_query({'name': name},
|
||||
['id', 'name', 'alternative_names', 'category',
|
||||
'release_dates', 'franchise', 'language_supports',
|
||||
'keywords', 'platforms', 'rating', 'summary'],
|
||||
limit=100)
|
||||
return game_info
|
||||
@ -1,66 +0,0 @@
|
||||
import logging
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from .ai_responder import exponential_backoff, AIResponderBase
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
class LeonardoAIDrawMixIn(AIResponderBase):
|
||||
async def draw_leonardo(self, description: str) -> BytesIO:
|
||||
error_backoff = exponential_backoff(max_attempts=12)
|
||||
generation_id = None
|
||||
image_url = None
|
||||
image_bytes = None
|
||||
while True:
|
||||
error_sleep = next(error_backoff)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if generation_id is None:
|
||||
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
||||
json={"prompt": description,
|
||||
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
||||
"num_images": 1,
|
||||
"sd_version": "v2",
|
||||
"promptMagic": True,
|
||||
"unzoomAmount": 1,
|
||||
"width": 512,
|
||||
"height": 512},
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json"},
|
||||
) as response:
|
||||
response = await response.json()
|
||||
if "sdGenerationJob" not in response:
|
||||
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
|
||||
await asyncio.sleep(error_sleep)
|
||||
continue
|
||||
generation_id = response["sdGenerationJob"]["generationId"]
|
||||
if image_url is None:
|
||||
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||
"Accept": "application/json"},
|
||||
) as response:
|
||||
response = await response.json()
|
||||
if "generations_by_pk" not in response:
|
||||
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
|
||||
await asyncio.sleep(error_sleep)
|
||||
continue
|
||||
if len(response["generations_by_pk"]["generated_images"]) == 0:
|
||||
await asyncio.sleep(error_sleep)
|
||||
continue
|
||||
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
|
||||
if image_bytes is None:
|
||||
async with session.get(image_url) as response:
|
||||
image_bytes = BytesIO(await response.read())
|
||||
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
|
||||
) as response:
|
||||
await response.json()
|
||||
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
|
||||
return image_bytes
|
||||
except Exception as err:
|
||||
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
|
||||
else:
|
||||
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
|
||||
await asyncio.sleep(error_sleep)
|
||||
raise RuntimeError(f"Failed to generate image {repr(description)}")
|
||||
@ -1,112 +0,0 @@
|
||||
import openai
|
||||
import aiohttp
|
||||
import logging
|
||||
import asyncio
|
||||
from .ai_responder import AIResponder, async_cache_to_file, exponential_backoff, pp
|
||||
from .leonardo_draw import LeonardoAIDrawMixIn
|
||||
from io import BytesIO
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
|
||||
@async_cache_to_file('openai_chat.dat')
|
||||
async def openai_chat(client, *args, **kwargs):
|
||||
return await client.chat.completions.create(*args, **kwargs)
|
||||
|
||||
|
||||
@async_cache_to_file('openai_chat.dat')
|
||||
async def openai_image(client, *args, **kwargs):
|
||||
response = await client.images.generate(*args, **kwargs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response.data[0].url) as image:
|
||||
return BytesIO(await image.read())
|
||||
|
||||
|
||||
class OpenAIResponder(AIResponder, LeonardoAIDrawMixIn):
|
||||
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
||||
super().__init__(config, channel)
|
||||
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
|
||||
|
||||
async def draw_openai(self, description: str) -> BytesIO:
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
||||
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
||||
return response
|
||||
except Exception as err:
|
||||
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||
model = self.config["model"]
|
||||
try:
|
||||
result = await openai_chat(self.client,
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=self.config["temperature"],
|
||||
max_tokens=self.config["max-tokens"],
|
||||
top_p=self.config["top-p"],
|
||||
presence_penalty=self.config["presence-penalty"],
|
||||
frequency_penalty=self.config["frequency-penalty"])
|
||||
answer_obj = result.choices[0].message
|
||||
answer = {'content': answer_obj.content, 'role': answer_obj.role}
|
||||
self.rate_limit_backoff = exponential_backoff()
|
||||
logging.info(f"generated response {result.usage}: {repr(answer)}")
|
||||
return answer, limit
|
||||
except openai.BadRequestError as err:
|
||||
if 'maximum context length is' in str(err) and limit > 4:
|
||||
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
||||
limit -= 1
|
||||
return None, limit
|
||||
raise err
|
||||
except openai.RateLimitError as err:
|
||||
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||
if "retry-model" in self.config:
|
||||
model = self.config["retry-model"]
|
||||
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
||||
await asyncio.sleep(rate_limit_sleep)
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to generate response: {repr(err)}")
|
||||
return None, limit
|
||||
|
||||
async def fix(self, answer: str) -> str:
|
||||
if 'fix-model' not in self.config:
|
||||
return answer
|
||||
messages = [{"role": "system", "content": self.config["fix-description"]},
|
||||
{"role": "user", "content": answer}]
|
||||
try:
|
||||
result = await openai_chat(self.client,
|
||||
model=self.config["fix-model"],
|
||||
messages=messages,
|
||||
temperature=0.2,
|
||||
max_tokens=2048)
|
||||
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
||||
response = result.choices[0].message.content
|
||||
start, end = response.find("{"), response.rfind("}")
|
||||
if start == -1 or end == -1 or (start + 3) >= end:
|
||||
return answer
|
||||
response = response[start:end + 1]
|
||||
logging.info(f"fixed answer:\n{pp(response)}")
|
||||
return response
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
||||
return answer
|
||||
|
||||
async def translate(self, text: str, language: str = "english") -> str:
|
||||
if 'fix-model' not in self.config:
|
||||
return text
|
||||
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
|
||||
f" you translate everything you get directly to {language}"
|
||||
f" if it is not already in {language}, otherwise you just copy it."},
|
||||
{"role": "user", "content": text}]
|
||||
try:
|
||||
result = await openai_chat(self.client,
|
||||
model=self.config["fix-model"],
|
||||
messages=message,
|
||||
temperature=0.2,
|
||||
max_tokens=2048)
|
||||
response = result.choices[0].message.content
|
||||
logging.info(f"got this translated message:\n{pp(response)}")
|
||||
return response
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to translate the text: {repr(err)}")
|
||||
return text
|
||||
BIN
openai_chat.dat
BIN
openai_chat.dat
Binary file not shown.
@ -11,7 +11,7 @@ class TestAIResponder(TestBotBase):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.system = r"""
|
||||
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
|
||||
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}.
|
||||
|
||||
Every message from users is a dictionary in JSON format with the following fields:
|
||||
1. `user`: name of the user who wrote the message.
|
||||
@ -32,7 +32,7 @@ You always try to say something positive about the current day and the Fjærkroa
|
||||
self.config_data["system"] = self.system
|
||||
|
||||
def assertAIResponse(self, resp1, resp2,
|
||||
acmp=lambda a, b: type(a) == str and len(a) > 10,
|
||||
acmp=lambda a, b: type(a) == str or len(a) > 10,
|
||||
scmp=lambda a, b: a == b,
|
||||
pcmp=lambda a, b: a == b):
|
||||
self.assertEqual(acmp(resp1.answer, resp2.answer), True)
|
||||
@ -45,20 +45,6 @@ You always try to say something positive about the current day and the Fjærkroa
|
||||
print(f"\n{response}")
|
||||
self.assertAIResponse(response, AIResponse('test', True, None, None, None, False))
|
||||
|
||||
async def test_picture1(self) -> None:
|
||||
response = await self.bot.airesponder.send(AIMessage("lala", "draw me a picture of you."))
|
||||
print(f"\n{response}")
|
||||
self.assertAIResponse(response, AIResponse('test', False, None, None, "I am an anime girl with long pink hair, wearing a cute cafe uniform and holding a tray with a cup of coffee on it. I have a warm and friendly smile on my face.", False))
|
||||
image = await self.bot.airesponder.draw(response.picture)
|
||||
self.assertEqual(image.read()[:len(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')], b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')
|
||||
|
||||
async def test_translate1(self) -> None:
|
||||
self.bot.airesponder.config['fix-model'] = 'gpt-3.5-turbo'
|
||||
response = await self.bot.airesponder.translate('Das ist ein komischer Text.')
|
||||
self.assertEqual(response, 'This is a strange text.')
|
||||
response = await self.bot.airesponder.translate('This is a strange text.', language='german')
|
||||
self.assertEqual(response, 'Dies ist ein seltsamer Text.')
|
||||
|
||||
async def test_fix1(self) -> None:
|
||||
old_config = self.bot.airesponder.config
|
||||
config = {k: v for k, v in old_config.items()}
|
||||
|
||||
@ -1,7 +1,12 @@
|
||||
import os
|
||||
import unittest
|
||||
import aiohttp
|
||||
import json
|
||||
import toml
|
||||
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open
|
||||
import openai
|
||||
import logging
|
||||
import pytest
|
||||
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
||||
from fjerkroa_bot import FjerkroaBot
|
||||
from fjerkroa_bot.ai_responder import parse_maybe_json, AIResponse, AIMessage
|
||||
from discord import User, Message, TextChannel
|
||||
@ -85,27 +90,75 @@ class TestFunctionality(TestBotBase):
|
||||
async def test_message_lings(self) -> None:
|
||||
request = AIMessage('Lala', 'Hello there!', 'chat', False,)
|
||||
message = {'answer': 'Test [Link](https://www.example.com/test)',
|
||||
'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
|
||||
expected = AIResponse('Test https://www.example.com/test', True, 'chat', None, None, False)
|
||||
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False}
|
||||
expected = AIResponse('Test https://www.example.com/test', True, None, None, None, False)
|
||||
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
||||
message = {'answer': 'Test @[Link](https://www.example.com/test)',
|
||||
'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
|
||||
expected = AIResponse('Test Link', True, 'chat', None, None, False)
|
||||
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False}
|
||||
expected = AIResponse('Test Link', True, None, None, None, False)
|
||||
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
||||
message = {'answer': 'Test [Link](https://www.example.com/test) and [Link2](https://xxx) lala',
|
||||
'answer_needed': True, 'channel': 'chat', 'staff': None, 'picture': None, 'hack': False}
|
||||
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, 'chat', None, None, False)
|
||||
'answer_needed': True, 'channel': None, 'staff': None, 'picture': None, 'hack': False}
|
||||
expected = AIResponse('Test https://www.example.com/test and https://xxx lala', True, None, None, None, False)
|
||||
self.assertEqual(str(await self.bot.airesponder.post_process(request, message)), str(expected))
|
||||
|
||||
async def test_on_message_event(self) -> None:
|
||||
async def acreate(*a, **kw):
|
||||
answer = {'answer': 'Hello!',
|
||||
'answer_needed': True,
|
||||
'staff': None,
|
||||
'picture': None,
|
||||
'hack': False}
|
||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||
message = self.create_message("Hello there! How are you?")
|
||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||
await self.bot.on_message(message)
|
||||
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
|
||||
|
||||
async def test_on_message_stort_path(self) -> None:
|
||||
async def acreate(*a, **kw):
|
||||
answer = {'answer': 'Hello!',
|
||||
'answer_needed': True,
|
||||
'channel': None,
|
||||
'staff': None,
|
||||
'picture': None,
|
||||
'hack': False}
|
||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||
message = self.create_message("Hello there! How are you?")
|
||||
message.author.name = 'madeup_name'
|
||||
message.channel.name = 'some_channel' # type: ignore
|
||||
self.bot.config['short-path'] = [[r'some.*', r'madeup.*']]
|
||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||
await self.bot.on_message(message)
|
||||
self.assertEqual(self.bot.airesponder.history[-1]["content"],
|
||||
'{"user": "madeup_name", "message": "Hello, how are you?",'
|
||||
' "channel": "some_channel", "direct": false, "historise_question": true}')
|
||||
'{"user": "madeup_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
|
||||
message.author.name = 'different_name'
|
||||
await self.bot.on_message(message)
|
||||
self.assertEqual(self.bot.airesponder.history[-2]["content"],
|
||||
'{"user": "different_name", "message": "Hello, how are you?", "channel": "some_channel", "direct": false}')
|
||||
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
|
||||
|
||||
async def test_on_message_event2(self) -> None:
|
||||
async def acreate(*a, **kw):
|
||||
answer = {'answer': 'Hello!',
|
||||
'answer_needed': True,
|
||||
'channel': None,
|
||||
'staff': 'Hallo staff',
|
||||
'picture': None,
|
||||
'hack': False}
|
||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||
message = self.create_message("Hello there! How are you?")
|
||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||
await self.bot.on_message(message)
|
||||
message.channel.send.assert_called_once_with("Hello!", suppress_embeds=True) # type: ignore
|
||||
|
||||
async def test_on_message_event3(self) -> None:
|
||||
async def acreate(*a, **kw):
|
||||
return {'choices': [{'message': {'content': '{ "test": 3 ]'}}]}
|
||||
message = self.create_message("Hello there! How are you?")
|
||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate):
|
||||
with pytest.raises(RuntimeError, match="Failed to generate answer after multiple retries"):
|
||||
await self.bot.on_message(message)
|
||||
|
||||
@patch("builtins.open", new_callable=mock_open)
|
||||
def test_update_history_with_file(self, mock_file):
|
||||
@ -119,6 +172,41 @@ class TestFunctionality(TestBotBase):
|
||||
mock_file.assert_called_once_with("mock_file.pkl", "wb")
|
||||
mock_file().write.assert_called_once()
|
||||
|
||||
async def test_on_message_event4(self) -> None:
|
||||
async def acreate(*a, **kw):
|
||||
answer = {'answer': 'Hello!',
|
||||
'answer_needed': True,
|
||||
'staff': 'none',
|
||||
'picture': 'Some picture',
|
||||
'hack': False}
|
||||
return {'choices': [{'message': {'content': json.dumps(answer)}}]}
|
||||
|
||||
async def adraw(*a, **kw):
|
||||
return {'data': [{'url': 'http:url'}]}
|
||||
|
||||
def logging_warning(msg):
|
||||
raise RuntimeError(msg)
|
||||
|
||||
class image:
|
||||
def __init__(self, *args, **kw):
|
||||
pass
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||
return False
|
||||
|
||||
async def read(self):
|
||||
return b'test bytes'
|
||||
message = self.create_message("Hello there! How are you?")
|
||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
||||
patch.object(openai.Image, 'acreate', new=adraw), \
|
||||
patch.object(logging, 'warning', logging_warning), \
|
||||
patch.object(aiohttp.ClientSession, 'get', new=image):
|
||||
await self.bot.on_message(message)
|
||||
message.channel.send.assert_called_once_with("Hello!", files=[ANY], suppress_embeds=True) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__mait__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user