Compare commits

...

2 Commits

4 changed files with 34 additions and 7 deletions

View File

@@ -1,4 +1,6 @@
 import json
+import asyncio
+import random
 import multiline
 import openai
 import aiohttp
@@ -31,6 +33,27 @@ def parse_json(content: str) -> Dict:
        raise err
 
 
+def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1):
+    """Generate sleep intervals for exponential backoff with jitter.
+
+    Args:
+        base: Base of the exponentiation operation
+        max_delay: Maximum delay in seconds
+        factor: Multiplication factor for each increase in backoff
+        jitter: Additional randomness range to prevent the thundering herd problem
+
+    Yields:
+        Delay for backoff as a floating point number.
+    """
+    attempt = 0
+    while True:
+        sleep = min(max_delay, factor * base ** attempt)
+        jitter_amount = jitter * sleep
+        sleep += random.uniform(-jitter_amount, jitter_amount)
+        yield sleep
+        attempt += 1
+
+
 def parse_maybe_json(json_string):
     if json_string is None:
         return None
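
With the defaults, the generator yields delays that roughly double per attempt (about 1, 2, 4, 8, ... seconds) until they saturate at max_delay, each perturbed by up to ±10% jitter. A minimal standalone sketch of driving it, duplicating the generator from the hunk above so it runs on its own (the loop and output format are illustrative only):

import random

def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1):
    # Same generator as added above, minus the docstring.
    attempt = 0
    while True:
        sleep = min(max_delay, factor * base ** attempt)
        jitter_amount = jitter * sleep
        sleep += random.uniform(-jitter_amount, jitter_amount)
        yield sleep
        attempt += 1

# First eight delays: roughly 1, 2, 4, 8, 16, 32, 60, 60 seconds, each ±10%.
backoff = exponential_backoff()
for attempt, delay in zip(range(8), backoff):
    print(f"attempt {attempt}: sleep {delay:.2f}s")
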
@@ -94,6 +117,7 @@ class AIResponder(object):
        self.history: List[Dict[str, Any]] = []
        self.channel = channel if channel is not None else 'system'
        openai.api_key = self.config['openai-token']
+        self.rate_limit_backoff = exponential_backoff()
        self.history_file: Optional[Path] = None
        if 'history-directory' in self.config:
            self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
@@ -182,7 +206,8 @@ class AIResponder(object):
                answer = result['choices'][0]['message']
                if type(answer) != dict:
                    answer = answer.to_dict()
-                logging.info(f"generated response: {repr(answer)}")
+                self.rate_limit_backoff = exponential_backoff()
+                logging.info(f"generated response {result.get('usage')}: {repr(answer)}")
                return answer, limit
            except openai.error.InvalidRequestError as err:
                if 'maximum context length is' in str(err) and limit > 4:
@@ -190,6 +215,10 @@ class AIResponder(object):
                    limit -= 1
                    return None, limit
                raise err
+            except openai.error.RateLimitError as err:
+                rate_limit_sleep = next(self.rate_limit_backoff)
+                logging.warning(f"got a rate limit error, sleeping for {rate_limit_sleep} seconds: {str(err)}")
+                await asyncio.sleep(rate_limit_sleep)
            except Exception as err:
                logging.warning(f"failed to generate response: {repr(err)}")
                return None, limit
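
The two hunks cooperate: whenever OpenAI returns a rate limit error, the request loop pulls the next delay from self.rate_limit_backoff and sleeps, while a successful response swaps in a fresh generator so the next incident starts from the shortest delay again. A schematic of that control flow, assuming the exponential_backoff generator above is in scope; Responder, make_call, and RateLimitedError are hypothetical stand-ins for the real class, the API call, and openai.error.RateLimitError:

import asyncio
import logging

class RateLimitedError(Exception):
    """Hypothetical stand-in for openai.error.RateLimitError."""

class Responder:
    def __init__(self):
        # Kept on the instance so consecutive rate limit errors across
        # retries keep drawing longer delays from the same schedule.
        self.rate_limit_backoff = exponential_backoff()

    async def request(self, make_call):
        while True:
            try:
                result = await make_call()
                # Success: swap in a fresh generator so the next incident
                # starts from the shortest delay again.
                self.rate_limit_backoff = exponential_backoff()
                return result
            except RateLimitedError as err:
                delay = next(self.rate_limit_backoff)
                logging.warning(f"rate limited, sleeping {delay:.1f}s: {err}")
                await asyncio.sleep(delay)
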

View File

@@ -1,6 +1,6 @@
 import sys
 import argparse
-import toml
+import tomlkit
 import discord
 import logging
 from discord import Message, TextChannel, DMChannel
@ -76,7 +76,7 @@ class FjerkroaBot(commands.Bot):
@classmethod @classmethod
def load_config(self, config_file: str = "config.toml"): def load_config(self, config_file: str = "config.toml"):
with open(config_file, encoding='utf-8') as file: with open(config_file, encoding='utf-8') as file:
return toml.load(file) return tomlkit.load(file)
def channel_by_name(self, def channel_by_name(self,
channel_name: Optional[str], channel_name: Optional[str],
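
For this read-only use tomlkit is a drop-in replacement: tomlkit.load returns a TOMLDocument that supports ordinary mapping access, and the package ships its own type annotations, which is why types-toml disappears from the dependency files below. A small sketch under that assumption; the file contents are hypothetical, only the keys come from the responder code above:

import tomlkit

# config.toml is expected to hold entries like
#   openai-token = "sk-..."
#   history-directory = "~/.local/share/bot"
with open("config.toml", encoding="utf-8") as file:
    config = tomlkit.load(file)

print(config["openai-token"])          # plain dict-style access
print("history-directory" in config)   # membership tests work too
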

View File

@@ -43,5 +43,4 @@ pytest = "*"
 setuptools = "*"
 wheel = "*"
 watchdog = "*"
-toml = "*"
-types-toml = "*"
+tomlkit = "*"

View File

@@ -8,6 +8,5 @@ pytest
 setuptools
 wheel
 watchdog
-toml
-types-toml
+tomlkit
 multiline