Fix error handling when JSON cannot be parsed.

This commit is contained in:
Fjerkroa Auto 2023-03-24 15:29:13 +01:00
parent 78591ef13a
commit 72c7d83766
2 changed files with 38 additions and 29 deletions

View File

@ -6,7 +6,7 @@ import time
import re import re
from io import BytesIO from io import BytesIO
from pprint import pformat from pprint import pformat
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any, Tuple
class AIMessageBase(object): class AIMessageBase(object):
@ -85,47 +85,57 @@ class AIResponder(object):
response['picture'], response['picture'],
response['hack']) response['hack'])
def short_path(self, message: AIMessage) -> bool: def short_path(self, message: AIMessage, limit: int) -> bool:
if 'short-path' not in self.config: if 'short-path' not in self.config:
return False return False
for chan_re, user_re in self.config['short-path']: for chan_re, user_re in self.config['short-path']:
chan_ma = re.match(chan_re, message.channel) chan_ma = re.match(chan_re, message.channel)
user_ma = re.match(user_re, message.user) user_ma = re.match(user_re, message.user)
if chan_ma and user_ma: if chan_ma and user_ma:
self.history.append({"role": "user", "content": str(message)})
self.history = self.history[-limit:]
return True return True
return False return False
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
try:
result = await openai.ChatCompletion.acreate(model=self.config["model"],
messages=messages,
temperature=self.config["temperature"],
top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer = result['choices'][0]['message']
if type(answer) != dict:
answer = answer.to_dict()
return answer, limit
except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 4:
limit -= 1
return None, limit
raise err
except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}")
return None, limit
async def send(self, message: AIMessage) -> AIResponse: async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"] limit = self.config["history-limit"]
if self.short_path(message): if self.short_path(message, limit):
self.history.append({"role": "user", "content": str(message)})
self.history = self.history[-limit:]
return AIResponse(None, False, None, None, False) return AIResponse(None, False, None, None, False)
for _ in range(14): for _ in range(14):
messages = self._message(message, limit) messages = self._message(message, limit)
logging.info(f"try to send this messages:\n{pformat(messages)}") logging.info(f"try to send this messages:\n{pformat(messages)}")
try: answer, limit = await self._acreate(messages, limit)
result = await openai.ChatCompletion.acreate(model=self.config["model"], if answer is None:
messages=messages,
temperature=self.config["temperature"],
top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer = result['choices'][0]['message']
if type(answer) != dict:
answer = answer.to_dict()
response = json.loads(answer['content'])
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
continue
logging.info(f"got this answer:\n{pformat(response)}")
except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 4:
limit -= 1
continue
raise err
except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}")
continue continue
try:
response = json.loads(answer['content'])
except Exception as err:
logging.error(f"failed to parse the answer: {pformat(err)}\n{pformat(answer['content'])}")
return AIResponse(None, False, f"ERROR: I could not parse this answer: {pformat(answer['content'])}", None, False)
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
continue
logging.info(f"got this answer:\n{pformat(response)}")
self.history.append(messages[-1]) self.history.append(messages[-1])
self.history.append(answer) self.history.append(answer)
if len(self.history) > limit: if len(self.history) > limit:

View File

@ -1,6 +1,5 @@
import os import os
import unittest import unittest
import pytest
import aiohttp import aiohttp
import json import json
import toml import toml
@ -118,8 +117,8 @@ class TestFunctionality(TestBotBase):
message = self.create_message("Hello there! How are you?") message = self.create_message("Hello there! How are you?")
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \ with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
patch.object(logging, 'warning', logging_warning): patch.object(logging, 'warning', logging_warning):
with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'): await self.bot.on_message(message)
await self.bot.on_message(message) self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)
async def test_on_message_event4(self) -> None: async def test_on_message_event4(self) -> None:
async def acreate(*a, **kw): async def acreate(*a, **kw):