diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index f9db137..9582037 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -1,3 +1,4 @@ +import os import json import asyncio import random @@ -119,7 +120,7 @@ class AIResponder(object): self.config = config self.history: List[Dict[str, Any]] = [] self.channel = channel if channel is not None else 'system' - openai.api_key = self.config['openai-token'] + self.client = openai.AsyncOpenAI(api_key=self.config['openai-token']) self.rate_limit_backoff = exponential_backoff() self.history_file: Optional[Path] = None if 'history-directory' in self.config: @@ -133,6 +134,11 @@ class AIResponder(object): system = self.config.get(self.channel, self.config['system']) system = system.replace('{date}', time.strftime('%Y-%m-%d'))\ .replace('{time}', time.strftime('%H:%M:%S')) + news_feed = self.config.get('news') + if news_feed and os.path.exists(news_feed): + with open(news_feed) as fd: + news_feed = fd.read().strip() + system = system.replace('{news}', news_feed) messages.append({"role": "system", "content": system}) if limit is not None: while len(self.history) > limit: @@ -150,9 +156,9 @@ class AIResponder(object): async def _draw_openai(self, description: str) -> BytesIO: for _ in range(3): try: - response = await openai.Image.acreate(prompt=description, n=1, size="512x512") + response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3") async with aiohttp.ClientSession() as session: - async with session.get(response['data'][0]['url']) as image: + async with session.get(response.data[0].url) as image: logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}') return BytesIO(await image.read()) except Exception as err: @@ -263,29 +269,27 @@ class AIResponder(object): return True return False - async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]: + 
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[openai.types.chat.ChatCompletionMessage], int]: model = self.config["model"] try: - result = await openai.ChatCompletion.acreate(model=model, - messages=messages, - temperature=self.config["temperature"], - max_tokens=self.config["max-tokens"], - top_p=self.config["top-p"], - presence_penalty=self.config["presence-penalty"], - frequency_penalty=self.config["frequency-penalty"]) - answer = result['choices'][0]['message'] - if type(answer) != dict: - answer = answer.to_dict() + result = await self.client.chat.completions.create(model=model, + messages=messages, + temperature=self.config["temperature"], + max_tokens=self.config["max-tokens"], + top_p=self.config["top-p"], + presence_penalty=self.config["presence-penalty"], + frequency_penalty=self.config["frequency-penalty"]) + answer = result.choices[0].message self.rate_limit_backoff = exponential_backoff() - logging.info(f"generated response {result.get('usage')}: {repr(answer)}") + logging.info(f"generated response {result.usage}: {repr(answer)}") return answer, limit - except openai.error.InvalidRequestError as err: + except openai.BadRequestError as err: if 'maximum context length is' in str(err) and limit > 4: logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}") limit -= 1 return None, limit raise err - except openai.error.RateLimitError as err: + except openai.RateLimitError as err: rate_limit_sleep = next(self.rate_limit_backoff) if "retry-model" in self.config: model = self.config["retry-model"] @@ -301,12 +305,12 @@ class AIResponder(object): messages = [{"role": "system", "content": self.config["fix-description"]}, {"role": "user", "content": answer}] try: - result = await openai.ChatCompletion.acreate(model=self.config["fix-model"], - messages=messages, - temperature=0.2, - max_tokens=2048) - logging.info(f"got this message as fix:\n{pp(result['choices'][0]['message']['content'])}") - 
response = result['choices'][0]['message']['content'] + result = await self.client.chat.completions.create(model=self.config["fix-model"], + messages=messages, + temperature=0.2, + max_tokens=2048) + logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}") + response = result.choices[0].message.content start, end = response.find("{"), response.rfind("}") if start == -1 or end == -1 or (start + 3) >= end: return answer @@ -325,11 +329,11 @@ class AIResponder(object): f" if it is not already in {language}, otherwise you just copy it."}, {"role": "user", "content": text}] try: - result = await openai.ChatCompletion.acreate(model=self.config["fix-model"], - messages=message, - temperature=0.2, - max_tokens=2048) - response = result['choices'][0]['message']['content'] + result = await self.client.chat.completions.create(model=self.config["fix-model"], + messages=message, + temperature=0.2, + max_tokens=2048) + response = result.choices[0].message.content logging.info(f"got this translated message:\n{pp(response)}") return response except Exception as err: @@ -385,16 +389,16 @@ class AIResponder(object): # Attempt to parse the AI's response try: - response = parse_json(answer['content']) + response = parse_json(answer.content) except Exception as err: - logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}") - answer['content'] = await self.fix(answer['content']) + logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer.content)}") + answer.content = await self.fix(answer.content) # Retry parsing the fixed content try: - response = parse_json(answer['content']) + response = parse_json(answer.content) except Exception as err: - logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}") + logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer.content)}") retries -= 1 continue @@ -409,7 +413,7 @@ class AIResponder(object): # Post-process the message and
update the answer's content answer_message = await self.post_process(message, response) - answer['content'] = str(answer_message) + answer.content = str(answer_message) # Update message history self.update_history(messages[-1], answer, limit, message.historise_question)