Compare commits
2 Commits
112f03a47a
...
7c52896ad9
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c52896ad9 | |||
| 6d9fa0e718 |
@ -1,4 +1,6 @@
|
||||
import json
|
||||
import asyncio
|
||||
import random
|
||||
import multiline
|
||||
import openai
|
||||
import aiohttp
|
||||
@ -31,6 +33,27 @@ def parse_json(content: str) -> Dict:
|
||||
raise err
|
||||
|
||||
|
||||
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1):
|
||||
"""Generate sleep intervals for exponential backoff with jitter.
|
||||
|
||||
Args:
|
||||
base: Base of the exponentiation operation
|
||||
max_delay: Maximum delay
|
||||
factor: Multiplication factor for each increase in backoff
|
||||
jitter: Additional randomness range to prevent thundering herd problem
|
||||
|
||||
Yields:
|
||||
Delay for backoff as a floating point number.
|
||||
"""
|
||||
attempt = 0
|
||||
while True:
|
||||
sleep = min(max_delay, factor * base ** attempt)
|
||||
jitter_amount = jitter * sleep
|
||||
sleep += random.uniform(-jitter_amount, jitter_amount)
|
||||
yield sleep
|
||||
attempt += 1
|
||||
|
||||
|
||||
def parse_maybe_json(json_string):
|
||||
if json_string is None:
|
||||
return None
|
||||
@ -94,6 +117,7 @@ class AIResponder(object):
|
||||
self.history: List[Dict[str, Any]] = []
|
||||
self.channel = channel if channel is not None else 'system'
|
||||
openai.api_key = self.config['openai-token']
|
||||
self.rate_limit_backoff = exponential_backoff()
|
||||
self.history_file: Optional[Path] = None
|
||||
if 'history-directory' in self.config:
|
||||
self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
|
||||
@ -182,7 +206,8 @@ class AIResponder(object):
|
||||
answer = result['choices'][0]['message']
|
||||
if type(answer) != dict:
|
||||
answer = answer.to_dict()
|
||||
logging.info(f"generated response: {repr(answer)}")
|
||||
self.rate_limit_backoff = exponential_backoff()
|
||||
logging.info(f"generated response {result.get('usage')}: {repr(answer)}")
|
||||
return answer, limit
|
||||
except openai.error.InvalidRequestError as err:
|
||||
if 'maximum context length is' in str(err) and limit > 4:
|
||||
@ -190,6 +215,10 @@ class AIResponder(object):
|
||||
limit -= 1
|
||||
return None, limit
|
||||
raise err
|
||||
except openai.error.RateLimitError as err:
|
||||
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
||||
await asyncio.sleep(rate_limit_sleep)
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to generate response: {repr(err)}")
|
||||
return None, limit
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import sys
|
||||
import argparse
|
||||
import toml
|
||||
import tomlkit
|
||||
import discord
|
||||
import logging
|
||||
from discord import Message, TextChannel, DMChannel
|
||||
@ -76,7 +76,7 @@ class FjerkroaBot(commands.Bot):
|
||||
@classmethod
|
||||
def load_config(self, config_file: str = "config.toml"):
|
||||
with open(config_file, encoding='utf-8') as file:
|
||||
return toml.load(file)
|
||||
return tomlkit.load(file)
|
||||
|
||||
def channel_by_name(self,
|
||||
channel_name: Optional[str],
|
||||
|
||||
@ -43,5 +43,4 @@ pytest = "*"
|
||||
setuptools = "*"
|
||||
wheel = "*"
|
||||
watchdog = "*"
|
||||
toml = "*"
|
||||
types-toml = "*"
|
||||
tomlkit = "*"
|
||||
|
||||
@ -8,6 +8,5 @@ pytest
|
||||
setuptools
|
||||
wheel
|
||||
watchdog
|
||||
toml
|
||||
types-toml
|
||||
tomlkit
|
||||
multiline
|
||||
|
||||
Loading…
Reference in New Issue
Block a user