Fix error handling when JSON can not be parsed.
This commit is contained in:
parent
78591ef13a
commit
72c7d83766
@ -6,7 +6,7 @@ import time
|
|||||||
import re
|
import re
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
|
||||||
|
|
||||||
class AIMessageBase(object):
|
class AIMessageBase(object):
|
||||||
@ -85,47 +85,57 @@ class AIResponder(object):
|
|||||||
response['picture'],
|
response['picture'],
|
||||||
response['hack'])
|
response['hack'])
|
||||||
|
|
||||||
def short_path(self, message: AIMessage) -> bool:
|
def short_path(self, message: AIMessage, limit: int) -> bool:
|
||||||
if 'short-path' not in self.config:
|
if 'short-path' not in self.config:
|
||||||
return False
|
return False
|
||||||
for chan_re, user_re in self.config['short-path']:
|
for chan_re, user_re in self.config['short-path']:
|
||||||
chan_ma = re.match(chan_re, message.channel)
|
chan_ma = re.match(chan_re, message.channel)
|
||||||
user_ma = re.match(user_re, message.user)
|
user_ma = re.match(user_re, message.user)
|
||||||
if chan_ma and user_ma:
|
if chan_ma and user_ma:
|
||||||
|
self.history.append({"role": "user", "content": str(message)})
|
||||||
|
self.history = self.history[-limit:]
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||||
|
try:
|
||||||
|
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.config["temperature"],
|
||||||
|
top_p=self.config["top-p"],
|
||||||
|
presence_penalty=self.config["presence-penalty"],
|
||||||
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
|
answer = result['choices'][0]['message']
|
||||||
|
if type(answer) != dict:
|
||||||
|
answer = answer.to_dict()
|
||||||
|
return answer, limit
|
||||||
|
except openai.error.InvalidRequestError as err:
|
||||||
|
if 'maximum context length is' in str(err) and limit > 4:
|
||||||
|
limit -= 1
|
||||||
|
return None, limit
|
||||||
|
raise err
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"failed to generate response: {repr(err)}")
|
||||||
|
return None, limit
|
||||||
|
|
||||||
async def send(self, message: AIMessage) -> AIResponse:
|
async def send(self, message: AIMessage) -> AIResponse:
|
||||||
limit = self.config["history-limit"]
|
limit = self.config["history-limit"]
|
||||||
if self.short_path(message):
|
if self.short_path(message, limit):
|
||||||
self.history.append({"role": "user", "content": str(message)})
|
|
||||||
self.history = self.history[-limit:]
|
|
||||||
return AIResponse(None, False, None, None, False)
|
return AIResponse(None, False, None, None, False)
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
messages = self._message(message, limit)
|
messages = self._message(message, limit)
|
||||||
logging.info(f"try to send this messages:\n{pformat(messages)}")
|
logging.info(f"try to send this messages:\n{pformat(messages)}")
|
||||||
try:
|
answer, limit = await self._acreate(messages, limit)
|
||||||
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
if answer is None:
|
||||||
messages=messages,
|
|
||||||
temperature=self.config["temperature"],
|
|
||||||
top_p=self.config["top-p"],
|
|
||||||
presence_penalty=self.config["presence-penalty"],
|
|
||||||
frequency_penalty=self.config["frequency-penalty"])
|
|
||||||
answer = result['choices'][0]['message']
|
|
||||||
if type(answer) != dict:
|
|
||||||
answer = answer.to_dict()
|
|
||||||
response = json.loads(answer['content'])
|
|
||||||
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
|
||||||
continue
|
|
||||||
logging.info(f"got this answer:\n{pformat(response)}")
|
|
||||||
except openai.error.InvalidRequestError as err:
|
|
||||||
if 'maximum context length is' in str(err) and limit > 4:
|
|
||||||
limit -= 1
|
|
||||||
continue
|
|
||||||
raise err
|
|
||||||
except Exception as err:
|
|
||||||
logging.warning(f"failed to generate response: {repr(err)}")
|
|
||||||
continue
|
continue
|
||||||
|
try:
|
||||||
|
response = json.loads(answer['content'])
|
||||||
|
except Exception as err:
|
||||||
|
logging.error(f"failed to parse the answer: {pformat(err)}\n{pformat(answer['content'])}")
|
||||||
|
return AIResponse(None, False, f"ERROR: I could not parse this answer: {pformat(answer['content'])}", None, False)
|
||||||
|
if 'hack' not in response or type(response.get('picture', None)) not in (type(None), str):
|
||||||
|
continue
|
||||||
|
logging.info(f"got this answer:\n{pformat(response)}")
|
||||||
self.history.append(messages[-1])
|
self.history.append(messages[-1])
|
||||||
self.history.append(answer)
|
self.history.append(answer)
|
||||||
if len(self.history) > limit:
|
if len(self.history) > limit:
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import pytest
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
import toml
|
import toml
|
||||||
@ -118,8 +117,8 @@ class TestFunctionality(TestBotBase):
|
|||||||
message = self.create_message("Hello there! How are you?")
|
message = self.create_message("Hello there! How are you?")
|
||||||
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
with patch.object(openai.ChatCompletion, 'acreate', new=acreate), \
|
||||||
patch.object(logging, 'warning', logging_warning):
|
patch.object(logging, 'warning', logging_warning):
|
||||||
with pytest.raises(RuntimeError, match='failed.*JSONDecodeError.*'):
|
await self.bot.on_message(message)
|
||||||
await self.bot.on_message(message)
|
self.bot.staff_channel.send.assert_called_once_with("ERROR: I could not parse this answer: '{ \"test\": 3 ]'", suppress_embeds=True)
|
||||||
|
|
||||||
async def test_on_message_event4(self) -> None:
|
async def test_on_message_event4(self) -> None:
|
||||||
async def acreate(*a, **kw):
|
async def acreate(*a, **kw):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user