Compare commits
2 Commits
112f03a47a
...
7c52896ad9
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c52896ad9 | |||
| 6d9fa0e718 |
@ -1,4 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
import asyncio
|
||||||
|
import random
|
||||||
import multiline
|
import multiline
|
||||||
import openai
|
import openai
|
||||||
import aiohttp
|
import aiohttp
|
||||||
@ -31,6 +33,27 @@ def parse_json(content: str) -> Dict:
|
|||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
|
||||||
|
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1):
|
||||||
|
"""Generate sleep intervals for exponential backoff with jitter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base: Base of the exponentiation operation
|
||||||
|
max_delay: Maximum delay
|
||||||
|
factor: Multiplication factor for each increase in backoff
|
||||||
|
jitter: Additional randomness range to prevent thundering herd problem
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Delay for backoff as a floating point number.
|
||||||
|
"""
|
||||||
|
attempt = 0
|
||||||
|
while True:
|
||||||
|
sleep = min(max_delay, factor * base ** attempt)
|
||||||
|
jitter_amount = jitter * sleep
|
||||||
|
sleep += random.uniform(-jitter_amount, jitter_amount)
|
||||||
|
yield sleep
|
||||||
|
attempt += 1
|
||||||
|
|
||||||
|
|
||||||
def parse_maybe_json(json_string):
|
def parse_maybe_json(json_string):
|
||||||
if json_string is None:
|
if json_string is None:
|
||||||
return None
|
return None
|
||||||
@ -94,6 +117,7 @@ class AIResponder(object):
|
|||||||
self.history: List[Dict[str, Any]] = []
|
self.history: List[Dict[str, Any]] = []
|
||||||
self.channel = channel if channel is not None else 'system'
|
self.channel = channel if channel is not None else 'system'
|
||||||
openai.api_key = self.config['openai-token']
|
openai.api_key = self.config['openai-token']
|
||||||
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
self.history_file: Optional[Path] = None
|
self.history_file: Optional[Path] = None
|
||||||
if 'history-directory' in self.config:
|
if 'history-directory' in self.config:
|
||||||
self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
|
self.history_file = Path(self.config['history-directory']).expanduser() / f'{self.channel}.dat'
|
||||||
@ -182,7 +206,8 @@ class AIResponder(object):
|
|||||||
answer = result['choices'][0]['message']
|
answer = result['choices'][0]['message']
|
||||||
if type(answer) != dict:
|
if type(answer) != dict:
|
||||||
answer = answer.to_dict()
|
answer = answer.to_dict()
|
||||||
logging.info(f"generated response: {repr(answer)}")
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
|
logging.info(f"generated response {result.get('usage')}: {repr(answer)}")
|
||||||
return answer, limit
|
return answer, limit
|
||||||
except openai.error.InvalidRequestError as err:
|
except openai.error.InvalidRequestError as err:
|
||||||
if 'maximum context length is' in str(err) and limit > 4:
|
if 'maximum context length is' in str(err) and limit > 4:
|
||||||
@ -190,6 +215,10 @@ class AIResponder(object):
|
|||||||
limit -= 1
|
limit -= 1
|
||||||
return None, limit
|
return None, limit
|
||||||
raise err
|
raise err
|
||||||
|
except openai.error.RateLimitError as err:
|
||||||
|
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||||
|
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
||||||
|
await asyncio.sleep(rate_limit_sleep)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.warning(f"failed to generate response: {repr(err)}")
|
logging.warning(f"failed to generate response: {repr(err)}")
|
||||||
return None, limit
|
return None, limit
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import toml
|
import tomlkit
|
||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
from discord import Message, TextChannel, DMChannel
|
from discord import Message, TextChannel, DMChannel
|
||||||
@ -76,7 +76,7 @@ class FjerkroaBot(commands.Bot):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def load_config(self, config_file: str = "config.toml"):
|
def load_config(self, config_file: str = "config.toml"):
|
||||||
with open(config_file, encoding='utf-8') as file:
|
with open(config_file, encoding='utf-8') as file:
|
||||||
return toml.load(file)
|
return tomlkit.load(file)
|
||||||
|
|
||||||
def channel_by_name(self,
|
def channel_by_name(self,
|
||||||
channel_name: Optional[str],
|
channel_name: Optional[str],
|
||||||
|
|||||||
@ -43,5 +43,4 @@ pytest = "*"
|
|||||||
setuptools = "*"
|
setuptools = "*"
|
||||||
wheel = "*"
|
wheel = "*"
|
||||||
watchdog = "*"
|
watchdog = "*"
|
||||||
toml = "*"
|
tomlkit = "*"
|
||||||
types-toml = "*"
|
|
||||||
|
|||||||
@ -8,6 +8,5 @@ pytest
|
|||||||
setuptools
|
setuptools
|
||||||
wheel
|
wheel
|
||||||
watchdog
|
watchdog
|
||||||
toml
|
tomlkit
|
||||||
types-toml
|
|
||||||
multiline
|
multiline
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user