Compare commits

...

2 Commits

Author SHA1 Message Date
Fjerkroa Auto
a2c7aec1e3 Merge branch 'master' of stage.fjerkroa.no:Fjerkroa/discord_bot 2023-11-14 10:27:29 +01:00
Fjerkroa Auto
488a8ef174 Support new openai API. 2023-11-14 10:27:03 +01:00

View File

@@ -1,3 +1,4 @@
import os
import json import json
import asyncio import asyncio
import random import random
@@ -119,7 +120,7 @@ class AIResponder(object):
self.config = config self.config = config
self.history: List[Dict[str, Any]] = [] self.history: List[Dict[str, Any]] = []
self.channel = channel if channel is not None else 'system' self.channel = channel if channel is not None else 'system'
openai.api_key = self.config['openai-token'] self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
self.rate_limit_backoff = exponential_backoff() self.rate_limit_backoff = exponential_backoff()
self.history_file: Optional[Path] = None self.history_file: Optional[Path] = None
if 'history-directory' in self.config: if 'history-directory' in self.config:
@@ -133,6 +134,11 @@ class AIResponder(object):
system = self.config.get(self.channel, self.config['system']) system = self.config.get(self.channel, self.config['system'])
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\ system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
.replace('{time}', time.strftime('%H:%M:%S')) .replace('{time}', time.strftime('%H:%M:%S'))
news_feed = self.config.get('news')
if news_feed and os.path.exists(news_feed):
with open(news_feed) as fd:
news_feed = fd.read().strip()
system = system.replace('{news}', news_feed)
messages.append({"role": "system", "content": system}) messages.append({"role": "system", "content": system})
if limit is not None: if limit is not None:
while len(self.history) > limit: while len(self.history) > limit:
@@ -150,9 +156,9 @@ class AIResponder(object):
async def _draw_openai(self, description: str) -> BytesIO: async def _draw_openai(self, description: str) -> BytesIO:
for _ in range(3): for _ in range(3):
try: try:
response = await openai.Image.acreate(prompt=description, n=1, size="512x512") response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3")
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(response['data'][0]['url']) as image: async with session.get(response.data[0].url) as image:
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}') logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
return BytesIO(await image.read()) return BytesIO(await image.read())
except Exception as err: except Exception as err:
@@ -264,29 +270,27 @@ class AIResponder(object):
return True return True
return False return False
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]: async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[openai.types.chat.ChatCompletionMessage], int]:
model = self.config["model"] model = self.config["model"]
try: try:
result = await openai.ChatCompletion.acreate(model=model, result = await self.client.chat.completions.create(model=model,
messages=messages, messages=messages,
temperature=self.config["temperature"], temperature=self.config["temperature"],
max_tokens=self.config["max-tokens"], max_tokens=self.config["max-tokens"],
top_p=self.config["top-p"], top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"], presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"]) frequency_penalty=self.config["frequency-penalty"])
answer = result['choices'][0]['message'] answer = result.choices[0].message
if type(answer) != dict:
answer = answer.to_dict()
self.rate_limit_backoff = exponential_backoff() self.rate_limit_backoff = exponential_backoff()
logging.info(f"generated response {result.get('usage')}: {repr(answer)}") logging.info(f"generated response {result.usage}: {repr(answer)}")
return answer, limit return answer, limit
except openai.error.InvalidRequestError as err: except openai.BadRequestError as err:
if 'maximum context length is' in str(err) and limit > 4: if 'maximum context length is' in str(err) and limit > 4:
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}") logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
limit -= 1 limit -= 1
return None, limit return None, limit
raise err raise err
except openai.error.RateLimitError as err: except openai.RateLimitError as err:
rate_limit_sleep = next(self.rate_limit_backoff) rate_limit_sleep = next(self.rate_limit_backoff)
if "retry-model" in self.config: if "retry-model" in self.config:
model = self.config["retry-model"] model = self.config["retry-model"]
@@ -302,12 +306,12 @@ class AIResponder(object):
messages = [{"role": "system", "content": self.config["fix-description"]}, messages = [{"role": "system", "content": self.config["fix-description"]},
{"role": "user", "content": answer}] {"role": "user", "content": answer}]
try: try:
result = await openai.ChatCompletion.acreate(model=self.config["fix-model"], result = await self.client.chat.completions.create(model=self.config["fix-model"],
messages=messages, messages=messages,
temperature=0.2, temperature=0.2,
max_tokens=2048) max_tokens=2048)
logging.info(f"got this message as fix:\n{pp(result['choices'][0]['message']['content'])}") logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
response = result['choices'][0]['message']['content'] response = result.choices[0].message.content
start, end = response.find("{"), response.rfind("}") start, end = response.find("{"), response.rfind("}")
if start == -1 or end == -1 or (start + 3) >= end: if start == -1 or end == -1 or (start + 3) >= end:
return answer return answer
@@ -326,11 +330,11 @@ class AIResponder(object):
f" if it is not already in {language}, otherwise you just copy it."}, f" if it is not already in {language}, otherwise you just copy it."},
{"role": "user", "content": text}] {"role": "user", "content": text}]
try: try:
result = await openai.ChatCompletion.acreate(model=self.config["fix-model"], result = await self.client.chat.completions.create(model=self.config["fix-model"],
messages=message, messages=message,
temperature=0.2, temperature=0.2,
max_tokens=2048) max_tokens=2048)
response = result['choices'][0]['message']['content'] response = result.choices[0].message.content
logging.info(f"got this translated message:\n{pp(response)}") logging.info(f"got this translated message:\n{pp(response)}")
return response return response
except Exception as err: except Exception as err:
@@ -386,16 +390,16 @@ class AIResponder(object):
# Attempt to parse the AI's response # Attempt to parse the AI's response
try: try:
response = parse_json(answer['content']) response = parse_json(answer.content)
except Exception as err: except Exception as err:
logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}") logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer.content)}")
answer['content'] = await self.fix(answer['content']) answer.content = await self.fix(answer['content'])
# Retry parsing the fixed content # Retry parsing the fixed content
try: try:
response = parse_json(answer['content']) response = parse_json(answer.content)
except Exception as err: except Exception as err:
logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}") logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer.content)}")
retries -= 1 retries -= 1
continue continue
@@ -410,7 +414,7 @@ class AIResponder(object):
# Post-process the message and update the answer's content # Post-process the message and update the answer's content
answer_message = await self.post_process(message, response) answer_message = await self.post_process(message, response)
answer['content'] = str(answer_message) answer.content = str(answer_message)
# Update message history # Update message history
self.update_history(messages[-1], answer, limit, message.historise_question) self.update_history(messages[-1], answer, limit, message.historise_question)