Changes.
This commit is contained in:
parent
39d7c48ee3
commit
0c23da7c9e
@ -1,10 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import asyncio
|
|
||||||
import random
|
import random
|
||||||
import multiline
|
import multiline
|
||||||
import openai
|
|
||||||
import aiohttp
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
@ -84,16 +81,6 @@ def async_cache_to_file(filename):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
@async_cache_to_file('openai_chat.dat')
|
|
||||||
async def openai_chat(client, *args, **kwargs):
|
|
||||||
return await client.chat.completions.create(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
@async_cache_to_file('openai_chat.dat')
|
|
||||||
async def openai_image(client, *args, **kwargs):
|
|
||||||
return await client.images.generate(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_maybe_json(json_string):
|
def parse_maybe_json(json_string):
|
||||||
if json_string is None:
|
if json_string is None:
|
||||||
return None
|
return None
|
||||||
@ -152,12 +139,17 @@ class AIResponse(AIMessageBase):
|
|||||||
self.hack = hack
|
self.hack = hack
|
||||||
|
|
||||||
|
|
||||||
class AIResponder(object):
|
class AIResponderBase(object):
|
||||||
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.history: List[Dict[str, Any]] = []
|
|
||||||
self.channel = channel if channel is not None else 'system'
|
self.channel = channel if channel is not None else 'system'
|
||||||
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
|
|
||||||
|
|
||||||
|
class AIResponder(AIResponderBase):
|
||||||
|
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
||||||
|
super().__init__(config, channel)
|
||||||
|
self.history: List[Dict[str, Any]] = []
|
||||||
self.rate_limit_backoff = exponential_backoff()
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
self.history_file: Optional[Path] = None
|
self.history_file: Optional[Path] = None
|
||||||
if 'history-directory' in self.config:
|
if 'history-directory' in self.config:
|
||||||
@ -166,7 +158,7 @@ class AIResponder(object):
|
|||||||
with open(self.history_file, 'rb') as fd:
|
with open(self.history_file, 'rb') as fd:
|
||||||
self.history = pickle.load(fd)
|
self.history = pickle.load(fd)
|
||||||
|
|
||||||
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
def message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||||
messages = []
|
messages = []
|
||||||
system = self.config.get(self.channel, self.config['system'])
|
system = self.config.get(self.channel, self.config['system'])
|
||||||
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||||
@ -187,79 +179,14 @@ class AIResponder(object):
|
|||||||
|
|
||||||
async def draw(self, description: str) -> BytesIO:
|
async def draw(self, description: str) -> BytesIO:
|
||||||
if self.config.get('leonardo-token') is not None:
|
if self.config.get('leonardo-token') is not None:
|
||||||
return await self._draw_leonardo(description)
|
return await self.draw_leonardo(description)
|
||||||
return await self._draw_openai(description)
|
return await self.draw_openai(description)
|
||||||
|
|
||||||
async def _draw_openai(self, description: str) -> BytesIO:
|
async def draw_leonardo(self, description: str) -> BytesIO:
|
||||||
for _ in range(3):
|
raise NotImplementedError()
|
||||||
try:
|
|
||||||
response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response.data[0].url) as image:
|
|
||||||
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
|
||||||
return BytesIO(await image.read())
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
|
||||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
|
||||||
|
|
||||||
async def _draw_leonardo(self, description: str) -> BytesIO:
|
async def draw_openai(self, description: str) -> BytesIO:
|
||||||
error_backoff = exponential_backoff(max_attempts=12)
|
raise NotImplementedError()
|
||||||
generation_id = None
|
|
||||||
image_url = None
|
|
||||||
image_bytes = None
|
|
||||||
while True:
|
|
||||||
error_sleep = next(error_backoff)
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
if generation_id is None:
|
|
||||||
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
|
||||||
json={"prompt": description,
|
|
||||||
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
|
||||||
"num_images": 1,
|
|
||||||
"sd_version": "v2",
|
|
||||||
"promptMagic": True,
|
|
||||||
"unzoomAmount": 1,
|
|
||||||
"width": 512,
|
|
||||||
"height": 512},
|
|
||||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
|
||||||
"Accept": "application/json",
|
|
||||||
"Content-Type": "application/json"},
|
|
||||||
) as response:
|
|
||||||
response = await response.json()
|
|
||||||
if "sdGenerationJob" not in response:
|
|
||||||
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
|
|
||||||
await asyncio.sleep(error_sleep)
|
|
||||||
continue
|
|
||||||
generation_id = response["sdGenerationJob"]["generationId"]
|
|
||||||
if image_url is None:
|
|
||||||
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
|
||||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
|
||||||
"Accept": "application/json"},
|
|
||||||
) as response:
|
|
||||||
response = await response.json()
|
|
||||||
if "generations_by_pk" not in response:
|
|
||||||
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
|
|
||||||
await asyncio.sleep(error_sleep)
|
|
||||||
continue
|
|
||||||
if len(response["generations_by_pk"]["generated_images"]) == 0:
|
|
||||||
await asyncio.sleep(error_sleep)
|
|
||||||
continue
|
|
||||||
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
|
|
||||||
if image_bytes is None:
|
|
||||||
async with session.get(image_url) as response:
|
|
||||||
image_bytes = BytesIO(await response.read())
|
|
||||||
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
|
||||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
|
|
||||||
) as response:
|
|
||||||
await response.json()
|
|
||||||
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
|
|
||||||
return image_bytes
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
|
|
||||||
else:
|
|
||||||
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
|
|
||||||
await asyncio.sleep(error_sleep)
|
|
||||||
raise RuntimeError(f"Failed to generate image {repr(description)}")
|
|
||||||
|
|
||||||
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
||||||
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
|
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
|
||||||
@ -307,80 +234,14 @@ class AIResponder(object):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||||
model = self.config["model"]
|
raise NotImplementedError()
|
||||||
try:
|
|
||||||
result = await openai_chat(self.client,
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
temperature=self.config["temperature"],
|
|
||||||
max_tokens=self.config["max-tokens"],
|
|
||||||
top_p=self.config["top-p"],
|
|
||||||
presence_penalty=self.config["presence-penalty"],
|
|
||||||
frequency_penalty=self.config["frequency-penalty"])
|
|
||||||
answer_obj = result.choices[0].message
|
|
||||||
answer = {'content': answer_obj.content, 'role': answer_obj.role}
|
|
||||||
self.rate_limit_backoff = exponential_backoff()
|
|
||||||
logging.info(f"generated response {result.usage}: {repr(answer)}")
|
|
||||||
return answer, limit
|
|
||||||
except openai.BadRequestError as err:
|
|
||||||
if 'maximum context length is' in str(err) and limit > 4:
|
|
||||||
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
|
||||||
limit -= 1
|
|
||||||
return None, limit
|
|
||||||
raise err
|
|
||||||
except openai.RateLimitError as err:
|
|
||||||
rate_limit_sleep = next(self.rate_limit_backoff)
|
|
||||||
if "retry-model" in self.config:
|
|
||||||
model = self.config["retry-model"]
|
|
||||||
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
|
||||||
await asyncio.sleep(rate_limit_sleep)
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning(f"failed to generate response: {repr(err)}")
|
|
||||||
return None, limit
|
|
||||||
|
|
||||||
async def fix(self, answer: str) -> str:
|
async def fix(self, answer: str) -> str:
|
||||||
if 'fix-model' not in self.config:
|
raise NotImplementedError()
|
||||||
return answer
|
|
||||||
messages = [{"role": "system", "content": self.config["fix-description"]},
|
|
||||||
{"role": "user", "content": answer}]
|
|
||||||
try:
|
|
||||||
result = await openai_chat(self.client,
|
|
||||||
model=self.config["fix-model"],
|
|
||||||
messages=messages,
|
|
||||||
temperature=0.2,
|
|
||||||
max_tokens=2048)
|
|
||||||
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
|
||||||
response = result.choices[0].message.content
|
|
||||||
start, end = response.find("{"), response.rfind("}")
|
|
||||||
if start == -1 or end == -1 or (start + 3) >= end:
|
|
||||||
return answer
|
|
||||||
response = response[start:end + 1]
|
|
||||||
logging.info(f"fixed answer:\n{pp(response)}")
|
|
||||||
return response
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
|
||||||
return answer
|
|
||||||
|
|
||||||
async def translate(self, text: str, language: str = "english") -> str:
|
async def translate(self, text: str, language: str = "english") -> str:
|
||||||
if 'fix-model' not in self.config:
|
raise NotImplementedError()
|
||||||
return text
|
|
||||||
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
|
|
||||||
f" you translate everything you get directly to {language}"
|
|
||||||
f" if it is not already in {language}, otherwise you just copy it."},
|
|
||||||
{"role": "user", "content": text}]
|
|
||||||
try:
|
|
||||||
result = await openai_chat(self.client,
|
|
||||||
model=self.config["fix-model"],
|
|
||||||
messages=message,
|
|
||||||
temperature=0.2,
|
|
||||||
max_tokens=2048)
|
|
||||||
response = result.choices[0].message.content
|
|
||||||
logging.info(f"got this translated message:\n{pp(response)}")
|
|
||||||
return response
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning(f"failed to translate the text: {repr(err)}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def shrink_history_by_one(self, index: int = 0) -> None:
|
def shrink_history_by_one(self, index: int = 0) -> None:
|
||||||
if index >= len(self.history):
|
if index >= len(self.history):
|
||||||
@ -420,11 +281,11 @@ class AIResponder(object):
|
|||||||
|
|
||||||
while retries > 0:
|
while retries > 0:
|
||||||
# Get the message queue
|
# Get the message queue
|
||||||
messages = self._message(message, limit)
|
messages = self.message(message, limit)
|
||||||
logging.info(f"try to send this messages:\n{pp(messages)}")
|
logging.info(f"try to send this messages:\n{pp(messages)}")
|
||||||
|
|
||||||
# Attempt to send the message to the AI
|
# Attempt to send the message to the AI
|
||||||
answer, limit = await self._acreate(messages, limit)
|
answer, limit = await self.chat(messages, limit)
|
||||||
|
|
||||||
if answer is None:
|
if answer is None:
|
||||||
continue
|
continue
|
||||||
|
|||||||
@ -12,7 +12,8 @@ from discord import Message, TextChannel, DMChannel
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import FileSystemEventHandler
|
from watchdog.events import FileSystemEventHandler
|
||||||
from .ai_responder import AIResponder, AIMessage
|
from .ai_responder import AIMessage
|
||||||
|
from .openai_responder import OpenAIResponder
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
|
||||||
@ -45,8 +46,8 @@ class FjerkroaBot(commands.Bot):
|
|||||||
self.observer.start()
|
self.observer.start()
|
||||||
|
|
||||||
def init_aichannels(self):
|
def init_aichannels(self):
|
||||||
self.airesponder = AIResponder(self.config)
|
self.airesponder = OpenAIResponder(self.config)
|
||||||
self.aichannels = {chan_name: AIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
|
self.aichannels = {chan_name: OpenAIResponder(self.config, chan_name) for chan_name in self.config['additional-responders']}
|
||||||
|
|
||||||
def init_channels(self):
|
def init_channels(self):
|
||||||
if 'chat-channel' in self.config:
|
if 'chat-channel' in self.config:
|
||||||
|
|||||||
80
fjerkroa_bot/igdblib.py
Normal file
80
fjerkroa_bot/igdblib.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import requests
|
||||||
|
from functools import cache
|
||||||
|
|
||||||
|
|
||||||
|
class IGDBQuery(object):
|
||||||
|
def __init__(self, client_id, igdb_api_key):
|
||||||
|
self.client_id = client_id
|
||||||
|
self.igdb_api_key = igdb_api_key
|
||||||
|
|
||||||
|
def send_igdb_request(self, endpoint, query_body):
|
||||||
|
igdb_url = f'https://api.igdb.com/v4/{endpoint}'
|
||||||
|
headers = {
|
||||||
|
'Client-ID': self.client_id,
|
||||||
|
'Authorization': f'Bearer {self.igdb_api_key}'
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(igdb_url, headers=headers, data=query_body)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
except requests.RequestException as e:
|
||||||
|
print(f"Error during IGDB API request: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_query(fields, filters=None, limit=10, offset=None):
|
||||||
|
query = f"fields {','.join(fields) if fields is not None and len(fields) > 0 else '*'}; limit {limit};"
|
||||||
|
if offset is not None:
|
||||||
|
query += f' offset {offset};'
|
||||||
|
if filters:
|
||||||
|
filter_statements = [f"{key} {value}" for key, value in filters.items()]
|
||||||
|
query += " where " + " & ".join(filter_statements) + ";"
|
||||||
|
return query
|
||||||
|
|
||||||
|
def generalized_igdb_query(self, params, endpoint, fields, additional_filters=None, limit=10, offset=None):
|
||||||
|
all_filters = {key: f'~ "{value}"*' for key, value in params.items() if value}
|
||||||
|
if additional_filters:
|
||||||
|
all_filters.update(additional_filters)
|
||||||
|
|
||||||
|
query = self.build_query(fields, all_filters, limit, offset)
|
||||||
|
data = self.send_igdb_request(endpoint, query)
|
||||||
|
print(f'{endpoint}: {query} -> {data}')
|
||||||
|
return data
|
||||||
|
|
||||||
|
def create_query_function(self, name, description, parameters, endpoint, fields, additional_filters=None, limit=10):
|
||||||
|
return {
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"parameters": {"type": "object", "properties": parameters},
|
||||||
|
"function": lambda params: self.generalized_igdb_query(params, endpoint, fields, additional_filters, limit)
|
||||||
|
}
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def platform_families(self):
|
||||||
|
families = self.generalized_igdb_query({}, 'platform_families', ['id', 'name'], limit=500)
|
||||||
|
return {v['id']: v['name'] for v in families}
|
||||||
|
|
||||||
|
@cache
|
||||||
|
def platforms(self):
|
||||||
|
platforms = self.generalized_igdb_query({}, 'platforms',
|
||||||
|
['id', 'name', 'alternative_name', 'abbreviation', 'platform_family'],
|
||||||
|
limit=500)
|
||||||
|
ret = {}
|
||||||
|
for p in platforms:
|
||||||
|
names = p['name']
|
||||||
|
if 'alternative_name' in p:
|
||||||
|
names.append(p['alternative_name'])
|
||||||
|
if 'abbreviation' in p:
|
||||||
|
names.append(p['abbreviation'])
|
||||||
|
family = self.platform_families()[p['id']] if 'platform_family' in p else None
|
||||||
|
ret[p['id']] = {'names': names, 'family': family}
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def game_info(self, name):
|
||||||
|
game_info = self.generalized_igdb_query({'name': name},
|
||||||
|
['id', 'name', 'alternative_names', 'category',
|
||||||
|
'release_dates', 'franchise', 'language_supports',
|
||||||
|
'keywords', 'platforms', 'rating', 'summary'],
|
||||||
|
limit=100)
|
||||||
|
return game_info
|
||||||
66
fjerkroa_bot/leonardo_draw.py
Normal file
66
fjerkroa_bot/leonardo_draw.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from .ai_responder import exponential_backoff, AIResponderBase
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
class LeonardoAIDrawMixIn(AIResponderBase):
|
||||||
|
async def draw_leonardo(self, description: str) -> BytesIO:
|
||||||
|
error_backoff = exponential_backoff(max_attempts=12)
|
||||||
|
generation_id = None
|
||||||
|
image_url = None
|
||||||
|
image_bytes = None
|
||||||
|
while True:
|
||||||
|
error_sleep = next(error_backoff)
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
if generation_id is None:
|
||||||
|
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
||||||
|
json={"prompt": description,
|
||||||
|
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
||||||
|
"num_images": 1,
|
||||||
|
"sd_version": "v2",
|
||||||
|
"promptMagic": True,
|
||||||
|
"unzoomAmount": 1,
|
||||||
|
"width": 512,
|
||||||
|
"height": 512},
|
||||||
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json"},
|
||||||
|
) as response:
|
||||||
|
response = await response.json()
|
||||||
|
if "sdGenerationJob" not in response:
|
||||||
|
logging.warning(f"No 'sdGenerationJob' found in response, sleep for {error_sleep}s: {repr(response)}")
|
||||||
|
await asyncio.sleep(error_sleep)
|
||||||
|
continue
|
||||||
|
generation_id = response["sdGenerationJob"]["generationId"]
|
||||||
|
if image_url is None:
|
||||||
|
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||||
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||||
|
"Accept": "application/json"},
|
||||||
|
) as response:
|
||||||
|
response = await response.json()
|
||||||
|
if "generations_by_pk" not in response:
|
||||||
|
logging.warning(f"Unexpected response, sleep for {error_sleep}s: {repr(response)}")
|
||||||
|
await asyncio.sleep(error_sleep)
|
||||||
|
continue
|
||||||
|
if len(response["generations_by_pk"]["generated_images"]) == 0:
|
||||||
|
await asyncio.sleep(error_sleep)
|
||||||
|
continue
|
||||||
|
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
|
||||||
|
if image_bytes is None:
|
||||||
|
async with session.get(image_url) as response:
|
||||||
|
image_bytes = BytesIO(await response.read())
|
||||||
|
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||||
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
|
||||||
|
) as response:
|
||||||
|
await response.json()
|
||||||
|
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
|
||||||
|
return image_bytes
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}\n{repr(err)}")
|
||||||
|
else:
|
||||||
|
logging.warning(f"Failed to generate image, sleep for {error_sleep}s: {repr(description)}")
|
||||||
|
await asyncio.sleep(error_sleep)
|
||||||
|
raise RuntimeError(f"Failed to generate image {repr(description)}")
|
||||||
112
fjerkroa_bot/openai_responder.py
Normal file
112
fjerkroa_bot/openai_responder.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import openai
|
||||||
|
import aiohttp
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from .ai_responder import AIResponder, async_cache_to_file, exponential_backoff, pp
|
||||||
|
from .leonardo_draw import LeonardoAIDrawMixIn
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Dict, Any, Optional, List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
@async_cache_to_file('openai_chat.dat')
|
||||||
|
async def openai_chat(client, *args, **kwargs):
|
||||||
|
return await client.chat.completions.create(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@async_cache_to_file('openai_chat.dat')
|
||||||
|
async def openai_image(client, *args, **kwargs):
|
||||||
|
response = await client.images.generate(*args, **kwargs)
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(response.data[0].url) as image:
|
||||||
|
return BytesIO(await image.read())
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponder(AIResponder, LeonardoAIDrawMixIn):
|
||||||
|
def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
|
||||||
|
super().__init__(config, channel)
|
||||||
|
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
|
||||||
|
|
||||||
|
async def draw_openai(self, description: str) -> BytesIO:
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
response = await openai_image(self.client, prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
||||||
|
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
||||||
|
return response
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||||
|
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||||
|
|
||||||
|
async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||||
|
model = self.config["model"]
|
||||||
|
try:
|
||||||
|
result = await openai_chat(self.client,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.config["temperature"],
|
||||||
|
max_tokens=self.config["max-tokens"],
|
||||||
|
top_p=self.config["top-p"],
|
||||||
|
presence_penalty=self.config["presence-penalty"],
|
||||||
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
|
answer_obj = result.choices[0].message
|
||||||
|
answer = {'content': answer_obj.content, 'role': answer_obj.role}
|
||||||
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
|
logging.info(f"generated response {result.usage}: {repr(answer)}")
|
||||||
|
return answer, limit
|
||||||
|
except openai.BadRequestError as err:
|
||||||
|
if 'maximum context length is' in str(err) and limit > 4:
|
||||||
|
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
||||||
|
limit -= 1
|
||||||
|
return None, limit
|
||||||
|
raise err
|
||||||
|
except openai.RateLimitError as err:
|
||||||
|
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||||
|
if "retry-model" in self.config:
|
||||||
|
model = self.config["retry-model"]
|
||||||
|
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
||||||
|
await asyncio.sleep(rate_limit_sleep)
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"failed to generate response: {repr(err)}")
|
||||||
|
return None, limit
|
||||||
|
|
||||||
|
async def fix(self, answer: str) -> str:
|
||||||
|
if 'fix-model' not in self.config:
|
||||||
|
return answer
|
||||||
|
messages = [{"role": "system", "content": self.config["fix-description"]},
|
||||||
|
{"role": "user", "content": answer}]
|
||||||
|
try:
|
||||||
|
result = await openai_chat(self.client,
|
||||||
|
model=self.config["fix-model"],
|
||||||
|
messages=messages,
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=2048)
|
||||||
|
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
||||||
|
response = result.choices[0].message.content
|
||||||
|
start, end = response.find("{"), response.rfind("}")
|
||||||
|
if start == -1 or end == -1 or (start + 3) >= end:
|
||||||
|
return answer
|
||||||
|
response = response[start:end + 1]
|
||||||
|
logging.info(f"fixed answer:\n{pp(response)}")
|
||||||
|
return response
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
|
||||||
|
return answer
|
||||||
|
|
||||||
|
async def translate(self, text: str, language: str = "english") -> str:
|
||||||
|
if 'fix-model' not in self.config:
|
||||||
|
return text
|
||||||
|
message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
|
||||||
|
f" you translate everything you get directly to {language}"
|
||||||
|
f" if it is not already in {language}, otherwise you just copy it."},
|
||||||
|
{"role": "user", "content": text}]
|
||||||
|
try:
|
||||||
|
result = await openai_chat(self.client,
|
||||||
|
model=self.config["fix-model"],
|
||||||
|
messages=message,
|
||||||
|
temperature=0.2,
|
||||||
|
max_tokens=2048)
|
||||||
|
response = result.choices[0].message.content
|
||||||
|
logging.info(f"got this translated message:\n{pp(response)}")
|
||||||
|
return response
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"failed to translate the text: {repr(err)}")
|
||||||
|
return text
|
||||||
BIN
openai_chat.dat
BIN
openai_chat.dat
Binary file not shown.
@ -45,6 +45,20 @@ You always try to say something positive about the current day and the Fjærkroa
|
|||||||
print(f"\n{response}")
|
print(f"\n{response}")
|
||||||
self.assertAIResponse(response, AIResponse('test', True, None, None, None, False))
|
self.assertAIResponse(response, AIResponse('test', True, None, None, None, False))
|
||||||
|
|
||||||
|
async def test_picture1(self) -> None:
|
||||||
|
response = await self.bot.airesponder.send(AIMessage("lala", "draw me a picture of you."))
|
||||||
|
print(f"\n{response}")
|
||||||
|
self.assertAIResponse(response, AIResponse('test', False, None, None, "I am an anime girl with long pink hair, wearing a cute cafe uniform and holding a tray with a cup of coffee on it. I have a warm and friendly smile on my face.", False))
|
||||||
|
image = await self.bot.airesponder.draw(response.picture)
|
||||||
|
self.assertEqual(image.read()[:len(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')], b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR')
|
||||||
|
|
||||||
|
async def test_translate1(self) -> None:
|
||||||
|
self.bot.airesponder.config['fix-model'] = 'gpt-3.5-turbo'
|
||||||
|
response = await self.bot.airesponder.translate('Das ist ein komischer Text.')
|
||||||
|
self.assertEqual(response, 'This is a strange text.')
|
||||||
|
response = await self.bot.airesponder.translate('This is a strange text.', language='german')
|
||||||
|
self.assertEqual(response, 'Dies ist ein seltsamer Text.')
|
||||||
|
|
||||||
async def test_fix1(self) -> None:
|
async def test_fix1(self) -> None:
|
||||||
old_config = self.bot.airesponder.config
|
old_config = self.bot.airesponder.config
|
||||||
config = {k: v for k, v in old_config.items()}
|
config = {k: v for k, v in old_config.items()}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user