Fix error handling when JSON can not be parsed.

This commit is contained in:
Fjerkroa Auto 2023-03-24 15:29:13 +01:00
parent 78591ef13a
commit 72c7d83766
2 changed files with 38 additions and 29 deletions

View File

@ -6,7 +6,7 @@ import time
import re import re
from io import BytesIO from io import BytesIO
from pprint import pformat from pprint import pformat
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any, Tuple
class AIMessageBase(object): class AIMessageBase(object):
@ -85,25 +85,19 @@ class AIResponder(object):
response['picture'], response['picture'],
response['hack']) response['hack'])
def short_path(self, message: AIMessage) -> bool: def short_path(self, message: AIMessage, limit: int) -> bool:
if 'short-path' not in self.config: if 'short-path' not in self.config:
return False return False
for chan_re, user_re in self.config['short-path']: for chan_re, user_re in self.config['short-path']:
chan_ma = re.match(chan_re, message.channel) chan_ma = re.match(chan_re, message.channel)
user_ma = re.match(user_re, message.user) user_ma = re.match(user_re, message.user)
if chan_ma and user_ma: if chan_ma and user_ma:
self.history.append({"role": "user", "content": str(message)})
self.history = self.history[-limit:]
return True return True
return False return False
async def send(self, message: AIMessage) -> AIResponse: async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
limit = self.config["history-limit"]
if self.short_path(message):
self.history.append({"role": "user", "content": str(message)})
self.history = self.history[-limit:]
return AIResponse(None, False, None, None, False)
for _ in range(14):
messages = self._message(message, limit)
logging.info(f"try to send this messages:\n{pformat(messages)}")
try: try:
result = await openai.ChatCompletion.acreate(model=self.config["model"], result = await openai.ChatCompletion.acreate(model=self.config["model"],
messages=messages, messages=messages,
@ -114,18 +108,34 @@ class AIResponder(object):
answer = result['choices'][0]['message'] answer = result['choices'][0]['message']
if type(answer) != dict: if type(answer) != dict:
answer = answer.to_dict() answer = answer.to_dict()
response = json.loads(answer['content']) return answer, limit
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
continue
logging.info(f"got this answer:\n{pformat(response)}")
except openai.error.InvalidRequestError as err: except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 4: if 'maximum context length is' in str(err) and limit > 4:
limit -= 1 limit -= 1
continue return None, limit
raise err raise err
except Exception as err: except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}") logging.warning(f"failed to generate response: {repr(err)}")
return None, limit
async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"]
if self.short_path(message, limit):
return AIResponse(None, False, None, None, False)
for _ in range(14):
messages = self._message(message, limit)
logging.info(f"try to send this messages:\n{pformat(messages)}")
answer, limit = await self._acreate(messages, limit)
if answer is None:
continue continue
try:
response = json.loads(answer['content'])
except Exception as err:
logging.error(f"failed to parse the answer: {pformat(err)}\n{pformat(answer['content'])}")
return AIResponse(None, False, f"ERROR: I could not parse this answer: {pformat(answer['content'])}", None, False)
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
continue
logging.info(f"got this answer:\n{pformat(response)}")
self.history.append(messages[-1]) self.history.append(messages[-1])
self.history.append(answer) self.history.append(answer)
if len(self.history) > limit: if len(self.history) > limit:

View File

@ -1,6 +1,5 @@
import os import os
import unittest import unittest
import pytest
import aiohttp import aiohttp
import json import json
import toml import toml
@ -118,8 +117,8 @@ class TestFunctionality(TestBotBase):
message = self.create_message("Hello there! How are you?") message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(logging, 'warning', logging_warning): patch.object(logging, 'warning', logging_warning):
with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'):
await self.bot.on_message(message) await self.bot.on_message(message)
self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)
async def test_on_message_event4(self) -> None: async def test_on_message_event4(self) -> None:
async def acreate(*a, **kw): async def acreate(*a, **kw):