From 72c7d837666997bb33d790de5f8682a419115cf1 Mon Sep 17 00:00:00 2001 From: Fjerkroa Auto Date: Fri, 24 Mar 2023 15:29:13 +0100 Subject: [PATCH] Fix error handling when JSON can not be parsed. --- fjerkroa_bot/ai_responder.py | 62 +++++++++++++++++++++--------------- tests/test_main.py | 5 ++- 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index 902f17c..45a894a 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -6,7 +6,7 @@ import time import re from io import BytesIO from pprint import pformat -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, Tuple class AIMessageBase(object): @@ -85,47 +85,57 @@ class AIResponder(object): response['picture'], response['hack']) - def short_path(self, message: AIMessage) -> bool: + def short_path(self, message: AIMessage, limit: int) -> bool: if 'short-path' not in self.config: return False for chan_re, user_re in self.config['short-path']: chan_ma = re.match(chan_re, message.channel) user_ma = re.match(user_re, message.user) if chan_ma and user_ma: + self.history.append({"role": "user", "content": str(message)}) + self.history = self.history[-limit:] return True return False + async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]: + try: + result = await openai.ChatCompletion.acreate(model=self.config["model"], + messages=messages, + temperature=self.config["temperature"], + top_p=self.config["top-p"], + presence_penalty=self.config["presence-penalty"], + frequency_penalty=self.config["frequency-penalty"]) + answer = result['choices'][0]['message'] + if type(answer) != dict: + answer = answer.to_dict() + return answer, limit + except openai.error.InvalidRequestError as err: + if 'maximum context length is' in str(err) and limit > 4: + limit -= 1 + return None, limit + raise err + except Exception as err: + logging.warning(f"failed to generate response: {repr(err)}") + return None, limit + async def send(self, message: AIMessage) -> AIResponse: limit = self.config["history-limit"] - if self.short_path(message): - self.history.append({"role": "user", "content": str(message)}) - self.history = self.history[-limit:] + if self.short_path(message, limit): return AIResponse(None, False, None, None, False) for _ in range(14): messages = self._message(message, limit) logging.info(f"try to send this messages:\n{pformat(messages)}") - try: - result = await openai.ChatCompletion.acreate(model=self.config["model"], - messages=messages, - temperature=self.config["temperature"], - top_p=self.config["top-p"], - presence_penalty=self.config["presence-penalty"], - frequency_penalty=self.config["frequency-penalty"]) - answer = result['choices'][0]['message'] - if type(answer) != dict: - answer = answer.to_dict() - response = json.loads(answer['content']) - if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str): - continue - logging.info(f"got this answer:\n{pformat(response)}") - except openai.error.InvalidRequestError as err: - if 'maximum context length is' in str(err) and limit > 4: - limit -= 1 - continue - raise err - except Exception as err: - logging.warning(f"failed to generate response: {repr(err)}") + answer, limit = await self._acreate(messages, limit) + if answer is None: continue + try: + response = json.loads(answer['content']) + except Exception as err: + logging.error(f"failed to parse the answer: {pformat(err)}\n{pformat(answer['content'])}") + return AIResponse(None, False, f"ERROR: I could not parse this answer: {pformat(answer['content'])}", None, False) + if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str): + continue + logging.info(f"got this answer:\n{pformat(response)}") self.history.append(messages[-1]) self.history.append(answer) if len(self.history) > limit: diff --git a/tests/test_main.py b/tests/test_main.py index 219dcfb..62fe0f8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,5 @@ import os import unittest -import pytest import aiohttp import json import toml @@ -118,8 +117,8 @@ class TestFunctionality(TestBotBase): message = self.create_message("Hello there! How are you?") with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ patch.object(logging, 'warning', logging_warning): - with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'): - await self.bot.on_message(message) + await self.bot.on_message(message) + self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True) async def test_on_message_event4(self) -> None: async def acreate(*a, **kw):