142 lines
7.3 KiB
Python
142 lines
7.3 KiB
Python
import openai
|
|
import aiohttp
|
|
import logging
|
|
import asyncio
|
|
from .ai_responder import AIResponder, async_cache_to_file, exponential_backoff, pp
|
|
from .leonardo_draw import LeonardoAIDrawMixIn
|
|
from io import BytesIO
|
|
from typing import Dict, Any, Optional, List, Tuple
|
|
|
|
|
|
@async_cache_to_file('openai_chat.dat')
async def openai_chat(client, *args, **kwargs):
    """Send a chat-completion request through ``client`` and return the result.

    Every positional and keyword argument is forwarded unchanged to
    ``client.chat.completions.create``. The coroutine is wrapped by
    ``async_cache_to_file`` using the file name ``openai_chat.dat``.
    """
    completion = await client.chat.completions.create(*args, **kwargs)
    return completion
|
|
|
|
|
|
# NOTE(review): this uses the same cache file name as ``openai_chat``
# ('openai_chat.dat'). Confirm that ``async_cache_to_file`` keys its entries so
# chat and image results cannot collide before renaming or relying on it.
@async_cache_to_file('openai_chat.dat')
async def openai_image(client, *args, **kwargs):
    """Generate an image through ``client`` and download it into memory.

    All arguments are forwarded unchanged to ``client.images.generate``. The
    URL of the first entry in the response's ``data`` list is then fetched
    over HTTP.

    Returns:
        A ``BytesIO`` holding the downloaded response body.

    Raises:
        aiohttp.ClientResponseError: if the download answers with an HTTP
            error status (>= 400).
    """
    response = await client.images.generate(*args, **kwargs)
    image_url = response.data[0].url
    async with aiohttp.ClientSession() as session:
        async with session.get(image_url) as image:
            # Fail loudly on HTTP errors instead of returning the error body
            # as if it were image bytes.
            image.raise_for_status()
            return BytesIO(await image.read())
|
|
|
|
|
|
class OpenAIResponder(AIResponder, LeonardoAIDrawMixIn):
    """Responder that talks to the OpenAI API for chat, drawing and helper tasks."""

    def __init__(self, config: Dict[str, Any], channel: Optional[str] = None) -> None:
        """Initialise the base classes and create the asynchronous OpenAI client.

        Args:
            config: Configuration mapping, passed on to the base initialiser.
            channel: Optional channel name, passed on to the base initialiser.

        The API key is read from ``self.config['openai-token']`` after the base
        initialiser has run (``self.config`` is presumably set there).
        """
        super().__init__(config, channel)
        self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])

    async def draw_openai(self, description: str) -> BytesIO:
        """Generate a 1024x1024 picture with the ``dall-e-3`` model.

        The request is attempted up to three times; every failure is logged.

        Args:
            description: Prompt describing the picture.

        Returns:
            The value returned by ``openai_image`` (the downloaded image data).

        Raises:
            RuntimeError: if all three attempts fail; the last underlying
                error is attached as the cause.
        """
        last_err: Optional[Exception] = None
        for _ in range(3):
            try:
                response = await openai_image(self.client,
                                              prompt=description,
                                              n=1,
                                              size="1024x1024",
                                              model="dall-e-3")
                logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
                return response
            except Exception as err:
                last_err = err
                logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
        raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries") from last_err

    async def chat(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
        """Request a chat completion for ``messages``.

        Model selection depends on the content of the last message:

        * plain string content -> ``config["model"]``;
        * non-string content and ``"model-vision"`` configured ->
          ``config["model-vision"]``;
        * non-string content without a vision model -> the content is replaced
          in place by the ``'text'`` field of its first element and
          ``config["model"]`` is used.

        Args:
            messages: Conversation messages; the last entry may be modified in
                place as described above.
            limit: Current history limit, returned (possibly reduced) to the
                caller.

        Returns:
            ``(answer, limit)`` with ``answer`` being a dict with ``'content'``
            and ``'role'`` on success. ``(None, limit - 1)`` when the request
            was rejected for exceeding the maximum context length and
            ``limit > 4``. ``(None, limit)`` after a rate-limit error (once the
            back-off sleep is over) or any other logged failure.

        Raises:
            openai.BadRequestError: for bad requests that are not a
                context-length problem, or when ``limit`` is already <= 4.
        """
        last_content = messages[-1]['content']
        if isinstance(last_content, str):
            model = self.config["model"]
        elif 'model-vision' in self.config:
            model = self.config["model-vision"]
        else:
            # No vision model available: keep only the text of the first part
            # and fall back to the regular model. Previously ``model`` stayed
            # unassigned here, so the request always failed with a swallowed
            # UnboundLocalError.
            messages[-1]['content'] = last_content[0]['text']
            model = self.config["model"]
        try:
            result = await openai_chat(self.client,
                                       model=model,
                                       messages=messages,
                                       temperature=self.config["temperature"],
                                       max_tokens=self.config["max-tokens"],
                                       top_p=self.config["top-p"],
                                       presence_penalty=self.config["presence-penalty"],
                                       frequency_penalty=self.config["frequency-penalty"])
            answer_obj = result.choices[0].message
            answer = {'content': answer_obj.content, 'role': answer_obj.role}
            # A successful call restarts the rate-limit back-off sequence.
            self.rate_limit_backoff = exponential_backoff()
            logging.info(f"generated response {result.usage}: {repr(answer)}")
            return answer, limit
        except openai.BadRequestError as err:
            if 'maximum context length is' in str(err) and limit > 4:
                logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
                limit -= 1
                return None, limit
            raise
        except openai.RateLimitError as err:
            rate_limit_sleep = next(self.rate_limit_backoff)
            if "retry-model" in self.config:
                # NOTE(review): this assignment has no effect, because the
                # method returns right after sleeping and ``model`` is a local
                # name. Confirm whether the retry model was meant to be used.
                model = self.config["retry-model"]
            logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
            await asyncio.sleep(rate_limit_sleep)
        except Exception as err:
            logging.warning(f"failed to generate response: {repr(err)}")
        return None, limit

    async def fix(self, answer: str) -> str:
        """Ask the configured fix model to repair ``answer``.

        The reply is trimmed to the span from its first ``{`` to its last
        ``}``. The original ``answer`` is returned unchanged when no
        ``'fix-model'`` is configured, when no sufficiently long brace span is
        found in the reply, or when the request fails.

        Args:
            answer: Text to be repaired.

        Returns:
            The extracted ``{...}`` span of the model reply, or ``answer``.
        """
        if 'fix-model' not in self.config:
            return answer
        messages = [{"role": "system", "content": self.config["fix-description"]},
                    {"role": "user", "content": answer}]
        try:
            result = await openai_chat(self.client,
                                       model=self.config["fix-model"],
                                       messages=messages,
                                       temperature=0.2,
                                       max_tokens=2048)
            response = result.choices[0].message.content
            logging.info(f"got this message as fix:\n{pp(response)}")
            start, end = response.find("{"), response.rfind("}")
            # Reject replies without braces or with a span too short to be useful.
            if start == -1 or end == -1 or (start + 3) >= end:
                return answer
            response = response[start:end + 1]
            logging.info(f"fixed answer:\n{pp(response)}")
            return response
        except Exception as err:
            logging.warning(f"failed to execute a fix for the answer: {repr(err)}")
            return answer

    async def translate(self, text: str, language: str = "english") -> str:
        """Translate ``text`` into ``language`` using the configured fix model.

        Args:
            text: Text to translate.
            language: Target language name inserted into the system prompt.

        Returns:
            The model's reply, or ``text`` unchanged when no ``'fix-model'`` is
            configured, the reply has no content, or the request fails.
        """
        if 'fix-model' not in self.config:
            return text
        message = [{"role": "system", "content": f"You are an professional translator to {language} language,"
                                                 f" you translate everything you get directly to {language}"
                                                 f" if it is not already in {language}, otherwise you just copy it."},
                   {"role": "user", "content": text}]
        try:
            result = await openai_chat(self.client,
                                       model=self.config["fix-model"],
                                       messages=message,
                                       temperature=0.2,
                                       max_tokens=2048)
            response = result.choices[0].message.content
            logging.info(f"got this translated message:\n{pp(response)}")
            if response is None:
                # Keep the declared ``str`` return type: fall back to the input.
                return text
            return response
        except Exception as err:
            logging.warning(f"failed to translate the text: {repr(err)}")
            return text

    async def memory_rewrite(self, memory: str, message_user: str, answer_user: str, question: str, answer: str) -> str:
        """Ask the memory model to fold the latest exchange into ``memory``.

        Args:
            memory: Previous memory text.
            message_user: Name shown in front of ``question`` in the prompt.
            answer_user: Name shown in front of ``answer`` in the prompt.
            question: The incoming message of the exchange.
            answer: The reply given in the exchange.

        Returns:
            The rewritten memory, or ``memory`` unchanged when no
            ``'memory-model'`` is configured, the reply has no content, or the
            request fails.
        """
        if 'memory-model' not in self.config:
            return memory
        messages = [{'role': 'system', 'content': self.config.get('memory-system', 'You are an memory assistant.')},
                    {'role': 'user', 'content': f'Here is my previous memory:\n```\n{memory}\n```\n\n'
                                                f'Here is my conversanion:\n```\n{message_user}: {question}\n\n{answer_user}: {answer}\n```\n\n'
                                                f'Please rewrite the memory in a way, that it contain the content mentioned in conversation. '
                                                f'Summarize the memory if required, try to keep important information. '
                                                f'Write just new memory data without any comments.'}]
        logging.info(f'Rewrite memory:\n{pp(messages)}')
        try:
            result = await openai_chat(self.client,
                                       model=self.config['memory-model'],
                                       messages=messages,
                                       temperature=0.6,
                                       max_tokens=4096)
            new_memory = result.choices[0].message.content
            logging.info(f'new memory:\n{new_memory}')
            if new_memory is None:
                # Do not replace the existing memory with ``None``.
                return memory
            return new_memory
        except Exception as err:
            logging.warning(f"failed to create new memory: {repr(err)}")
            return memory
|