This commit is contained in:
OK 2024-03-16 13:44:46 +01:00
parent 39d7c48ee3
commit 0c23da7c9e
7 changed files with 297 additions and 163 deletions

View File

@ -1,10 +1,7 @@
import os
import json
import asyncio
import random
import multiline
import openai
import aiohttp
import logging
import time
import re
@ -84,16 +81,6 @@ def async_cache_to_file(filename):
return decorator
@async_cache_to_file('openai_chat.dat')
async def openai_chat(client, *args, **kwargs):
return await client.chat.completions.create(*args, **kwargs)
@async_cache_to_file('openai_chat.dat')
async def openai_image(client, *args, **kwargs):
return await client.images.generate(*args, **kwargs)
def parse_maybe_json(json_string):
if json_string is None:
return None
@ -152,12 +139,17 @@ class AIResponse(AIMessageBase):
self.hack = hack
class AIResponder(object):
class AIResponderBase(object):
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
super().__init__()
self.config = config
self.history: List[Dict[str, Any]] = []
self.channel = channel if channel is not None else 'system'
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
class AIResponder(AIResponderBase):
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
super().__init__(config, channel)
self.history: List[Dict[str, Any]] = []
self.rate_limit_backoff = exponential_backoff()
self.history_file: Optional[Path] = None
if 'history-directory' in self.config:
@ -166,7 +158,7 @@ class AIResponder(object):
with open(self.history_file, 'rb') as fd:
self.history = pickle.load(fd)
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
def message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
messages = []
system = self.config.get(self.channel, self.config['system'])
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
@ -187,79 +179,14 @@ class AIResponder(object):
async def draw(self, description: str) -> BytesIO:
if self.config.get('leonardo-token') is not None:
return await self._draw_leonardo(description)
return await self._draw_openai(description)
return await self.draw_leonardo(description)
return await self.draw_openai(description)
async def _draw_openai(self, description: str) -> BytesIO:
for _ in range(3):
try:
response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
async with aiohttp.ClientSession() as session:
async with session.get(response.data[0].url) as image:
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
return BytesIO(await image.read())
except Exception as err:
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
async def draw_leonardo(self, description: str) -> BytesIO:
raise NotImplementedError()
async def _draw_leonardo(self, description: str) -> BytesIO:
error_backoff = exponential_backoff(max_attempts=12)
generation_id = None
image_url = None
image_bytes = None
while True:
error_sleep = next(error_backoff)
try:
async with aiohttp.ClientSession() as session:
if generation_id is None:
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
json={"prompt": description,
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
"num_images": 1,
"sd_version": "v2",
"promptMagic": True,
"unzoomAmount": 1,
"width": 512,
"height": 512},
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json",
"Content-Type": "application/json"},
) as response:
response = await response.json()
if "sdGenerationJob" not in response:
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
await asyncio.sleep(error_sleep)
continue
generation_id = response["sdGenerationJob"]["generationId"]
if image_url is None:
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json"},
) as response:
response = await response.json()
if "generations_by_pk" not in response:
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
await asyncio.sleep(error_sleep)
continue
if len(response["generations_by_pk"]["generated_images"]) == 0:
await asyncio.sleep(error_sleep)
continue
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
if image_bytes is None:
async with session.get(image_url) as response:
image_bytes = BytesIO(await response.read())
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
) as response:
await response.json()
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
return image_bytes
except Exception as err:
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
else:
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
await asyncio.sleep(error_sleep)
raise RuntimeError(f"Failed to generate image {repr(description)}")
async def draw_openai(self, description: str) -> BytesIO:
raise NotImplementedError()
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
@ -307,80 +234,14 @@ class AIResponder(object):
return True
return False
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
model = self.config["model"]
try:
result = await openai_chat(self.client,
model=model,
messages=messages,
temperature=self.config["temperature"],
max_tokens=self.config["max-tokens"],
top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer_obj = result.choices[0].message
answer = {'content': answer_obj.content, 'role': answer_obj.role}
self.rate_limit_backoff = exponential_backoff()
logging.info(f"generated response {result.usage}: {repr(answer)}")
return answer, limit
except openai.BadRequestError as err:
if 'maximum context length is' in str(err) and limit > 4:
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
limit -= 1
return None, limit
raise err
except openai.RateLimitError as err:
rate_limit_sleep = next(self.rate_limit_backoff)
if "retry-model" in self.config:
model = self.config["retry-model"]
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
await asyncio.sleep(rate_limit_sleep)
except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}")
return None, limit
async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
raise NotImplementedError()
async def fix(self, answer: str) -> str:
if 'fix-model' not in self.config:
return answer
messages = [{"role": "system", "content": self.config["fix-description"]},
{"role": "user", "content": answer}]
try:
result = await openai_chat(self.client,
model=self.config["fix-model"],
messages=messages,
temperature=0.2,
max_tokens=2048)
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
response = result.choices[0].message.content
start, end = response.find("{"), response.rfind("}")
if start == -1 or end == -1 or (start + 3) >= end:
return answer
response = response[start:end + 1]
logging.info(f"fixed answer:\n{pp(response)}")
return response
except Exception as err:
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
return answer
raise NotImplementedError()
async def translate(self, text: str, language: str = "english") -> str:
if 'fix-model' not in self.config:
return text
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
f" you translate everything you get directly to {language}"
f" if it is not already in {language}, otherwise you just copy it."},
{"role": "user", "content": text}]
try:
result = await openai_chat(self.client,
model=self.config["fix-model"],
messages=message,
temperature=0.2,
max_tokens=2048)
response = result.choices[0].message.content
logging.info(f"got this translated message:\n{pp(response)}")
return response
except Exception as err:
logging.warning(f"failed to translate the text: {repr(err)}")
return text
raise NotImplementedError()
def shrink_history_by_one(self, index: int = 0) -> None:
if index >= len(self.history):
@ -420,11 +281,11 @@ class AIResponder(object):
while retries > 0:
# Get the message queue
messages = self._message(message, limit)
messages = self.message(message, limit)
logging.info(f"try to send this messages:\n{pp(messages)}")
# Attempt to send the message to the AI
answer, limit = await self._acreate(messages, limit)
answer, limit = await self.chat(messages, limit)
if answer is None:
continue

View File

@ -12,7 +12,8 @@ from discord import Message, TextChannel, DMChannel
from discord.ext import commands
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from .ai_responder import AIResponder, AIMessage
from .ai_responder import AIMessage
from .openai_responder import OpenAIResponder
from typing import Optional, Union
@ -45,8 +46,8 @@ class FjerkroaBot(commands.Bot):
self.observer.start()
def init_aichannels(self):
self.airesponder = AIResponder(self.config)
self.aichannels = {chan_name: AIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
self.airesponder = OpenAIResponder(self.config)
self.aichannels = {chan_name: OpenAIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
def init_channels(self):
if 'chat-channel' in self.config:

80
fjerkroa_bot/igdblib.py Normal file
View File

@ -0,0 +1,80 @@
import requests
from functools import cache
class IGDBQuery(object):
def __init__(self, client_id, igdb_api_key):
self.client_id = client_id
self.igdb_api_key = igdb_api_key
def send_igdb_request(self, endpoint, query_body):
igdb_url = f'https://api.igdb.com/v4/{endpoint}'
headers = {
'Client-ID': self.client_id,
'Authorization': f'Bearer {self.igdb_api_key}'
}
try:
response = requests.post(igdb_url, headers=headers, data=query_body)
response.raise_for_status()
return response.json()
except requests.RequestException as e:
print(f"Error during IGDB API request: {e}")
return None
@staticmethod
def build_query(fields, filters=None, limit=10, offset=None):
query = f"fields {','.join(fields) if fields is not None and len(fields) > 0 else '*'}; limit {limit};"
if offset is not None:
query += f' offset {offset};'
if filters:
filter_statements = [f"{key} {value}" for key, value in filters.items()]
query += " where " + " & ".join(filter_statements) + ";"
return query
def generalized_igdb_query(self, params, endpoint, fields, additional_filters=None, limit=10, offset=None):
all_filters = {key: f'~ "{value}"*' for key, value in params.items() if value}
if additional_filters:
all_filters.update(additional_filters)
query = self.build_query(fields, all_filters, limit, offset)
data = self.send_igdb_request(endpoint, query)
print(f'{endpoint}: {query} -> {data}')
return data
def create_query_function(self, name, description, parameters, endpoint, fields, additional_filters=None, limit=10):
return {
"name": name,
"description": description,
"parameters": {"type": "object", "properties": parameters},
"function": lambda params: self.generalized_igdb_query(params, endpoint, fields, additional_filters, limit)
}
@cache
def platform_families(self):
families = self.generalized_igdb_query({}, 'platform_families', ['id', 'name'], limit=500)
return {v['id']: v['name'] for v in families}
@cache
def platforms(self):
platforms = self.generalized_igdb_query({}, 'platforms',
['id', 'name', 'alternative_name', 'abbreviation', 'platform_family'],
limit=500)
ret = {}
for p in platforms:
names = p['name']
if 'alternative_name' in p:
names.append(p['alternative_name'])
if 'abbreviation' in p:
names.append(p['abbreviation'])
family = self.platform_families()[p['id']] if 'platform_family' in p else None
ret[p['id']] = {'names': names, 'family': family}
return ret
def game_info(self, name):
game_info = self.generalized_igdb_query({'name': name},
['id', 'name', 'alternative_names', 'category',
'release_dates', 'franchise', 'language_supports',
'keywords', 'platforms', 'rating', 'summary'],
limit=100)
return game_info

View File

@ -0,0 +1,66 @@
import logging
import asyncio
import aiohttp
from .ai_responder import exponential_backoff, AIResponderBase
from io import BytesIO
class LeonardoAIDrawMixIn(AIResponderBase):
async def draw_leonardo(self, description: str) -> BytesIO:
error_backoff = exponential_backoff(max_attempts=12)
generation_id = None
image_url = None
image_bytes = None
while True:
error_sleep = next(error_backoff)
try:
async with aiohttp.ClientSession() as session:
if generation_id is None:
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
json={"prompt": description,
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
"num_images": 1,
"sd_version": "v2",
"promptMagic": True,
"unzoomAmount": 1,
"width": 512,
"height": 512},
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json",
"Content-Type": "application/json"},
) as response:
response = await response.json()
if "sdGenerationJob" not in response:
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
await asyncio.sleep(error_sleep)
continue
generation_id = response["sdGenerationJob"]["generationId"]
if image_url is None:
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json"},
) as response:
response = await response.json()
if "generations_by_pk" not in response:
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
await asyncio.sleep(error_sleep)
continue
if len(response["generations_by_pk"]["generated_images"]) == 0:
await asyncio.sleep(error_sleep)
continue
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
if image_bytes is None:
async with session.get(image_url) as response:
image_bytes = BytesIO(await response.read())
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
) as response:
await response.json()
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
return image_bytes
except Exception as err:
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
else:
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
await asyncio.sleep(error_sleep)
raise RuntimeError(f"Failed to generate image {repr(description)}")

View File

@ -0,0 +1,112 @@
import openai
import aiohttp
import logging
import asyncio
from .ai_responder import AIResponder, async_cache_to_file, exponential_backoff, pp
from .leonardo_draw import LeonardoAIDrawMixIn
from io import BytesIO
from typing import Dict, Any, Optional, List, Tuple
@async_cache_to_file('openai_chat.dat')
async def openai_chat(client, *args, **kwargs):
return await client.chat.completions.create(*args, **kwargs)
@async_cache_to_file('openai_chat.dat')
async def openai_image(client, *args, **kwargs):
response = await client.images.generate(*args, **kwargs)
async with aiohttp.ClientSession() as session:
async with session.get(response.data[0].url) as image:
return BytesIO(await image.read())
class OpenAIResponder(AIResponder, LeonardoAIDrawMixIn):
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
super().__init__(config, channel)
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
async def draw_openai(self, description: str) -> BytesIO:
for _ in range(3):
try:
response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
return response
except Exception as err:
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
model = self.config["model"]
try:
result = await openai_chat(self.client,
model=model,
messages=messages,
temperature=self.config["temperature"],
max_tokens=self.config["max-tokens"],
top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer_obj = result.choices[0].message
answer = {'content': answer_obj.content, 'role': answer_obj.role}
self.rate_limit_backoff = exponential_backoff()
logging.info(f"generated response {result.usage}: {repr(answer)}")
return answer, limit
except openai.BadRequestError as err:
if 'maximum context length is' in str(err) and limit > 4:
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
limit -= 1
return None, limit
raise err
except openai.RateLimitError as err:
rate_limit_sleep = next(self.rate_limit_backoff)
if "retry-model" in self.config:
model = self.config["retry-model"]
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
await asyncio.sleep(rate_limit_sleep)
except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}")
return None, limit
async def fix(self, answer: str) -> str:
if 'fix-model' not in self.config:
return answer
messages = [{"role": "system", "content": self.config["fix-description"]},
{"role": "user", "content": answer}]
try:
result = await openai_chat(self.client,
model=self.config["fix-model"],
messages=messages,
temperature=0.2,
max_tokens=2048)
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
response = result.choices[0].message.content
start, end = response.find("{"), response.rfind("}")
if start == -1 or end == -1 or (start + 3) >= end:
return answer
response = response[start:end + 1]
logging.info(f"fixed answer:\n{pp(response)}")
return response
except Exception as err:
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
return answer
async def translate(self, text: str, language: str = "english") -> str:
if 'fix-model' not in self.config:
return text
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
f" you translate everything you get directly to {language}"
f" if it is not already in {language}, otherwise you just copy it."},
{"role": "user", "content": text}]
try:
result = await openai_chat(self.client,
model=self.config["fix-model"],
messages=message,
temperature=0.2,
max_tokens=2048)
response = result.choices[0].message.content
logging.info(f"got this translated message:\n{pp(response)}")
return response
except Exception as err:
logging.warning(f"failed to translate the text: {repr(err)}")
return text

Binary file not shown.

View File

@ -45,6 +45,20 @@ You always try to say something positive about the current day and the Fjærkroa
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, None, None, None, False))
async def test_picture1(self) -> None:
response = await self.bot.airesponder.send(AIMessage("lala", "draw me a picture of you."))
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', False, None, None, "I am an anime girl with long pink hair, wearing a cute cafe uniform and holding a tray with a cup of coffee on it. I have a warm and friendly smile on my face.", False))
image = await self.bot.airesponder.draw(response.picture)
self.assertEqual(image.read()[:len(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')], b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')
async def test_translate1(self) -> None:
self.bot.airesponder.config['fix-model'] = 'gpt-3.5-turbo'
response = await self.bot.airesponder.translate('Das ist ein komischer Text.')
self.assertEqual(response, 'This is a strange text.')
response = await self.bot.airesponder.translate('This is a strange text.', language='german')
self.assertEqual(response, 'Dies ist ein seltsamer Text.')
async def test_fix1(self) -> None:
old_config = self.bot.airesponder.config
config = {k: v for k, v in old_config.items()}