Add the option to set a retry-model that is used when an AI request fails because of rate limits. Improve error handling in Leonardo AI drawing.
This commit is contained in:
parent
17386af950
commit
95bc6ce041
@ -33,7 +33,7 @@ def parse_json(content: str) -> Dict:
|
|||||||
raise err
|
raise err
|
||||||
|
|
||||||
|
|
||||||
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1):
|
def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1, max_attempts=None):
|
||||||
"""Generate sleep intervals for exponential backoff with jitter.
|
"""Generate sleep intervals for exponential backoff with jitter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -52,6 +52,8 @@ def exponential_backoff(base=2, max_delay=60, factor=1, jitter=0.1):
|
|||||||
sleep += random.uniform(-jitter_amount, jitter_amount)
|
sleep += random.uniform(-jitter_amount, jitter_amount)
|
||||||
yield sleep
|
yield sleep
|
||||||
attempt += 1
|
attempt += 1
|
||||||
|
if max_attempts is not None and attempt > max_attempts:
|
||||||
|
raise RuntimeError("Max attempts reached in exponential backoff.")
|
||||||
|
|
||||||
|
|
||||||
def parse_maybe_json(json_string):
|
def parse_maybe_json(json_string):
|
||||||
@ -158,36 +160,51 @@ class AIResponder(object):
|
|||||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||||
|
|
||||||
async def _draw_leonardo(self, description: str) -> BytesIO:
|
async def _draw_leonardo(self, description: str) -> BytesIO:
|
||||||
for _ in range(3):
|
error_backoff = exponential_backoff(max_attempts=12)
|
||||||
|
generation_id = None
|
||||||
|
image_url = None
|
||||||
|
image_bytes = None
|
||||||
|
while True:
|
||||||
|
error_sleep = next(error_backoff)
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
if generation_id is None:
|
||||||
json={"prompt": description,
|
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
||||||
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
json={"prompt": description,
|
||||||
"num_images": 1,
|
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
||||||
"sd_version": "v2",
|
"num_images": 1,
|
||||||
"promptMagic": True,
|
"sd_version": "v2",
|
||||||
"width": 512,
|
"promptMagic": True,
|
||||||
"height": 512},
|
"unzoomAmount": 1,
|
||||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
"width": 512,
|
||||||
"Accept": "application/json",
|
"height": 512},
|
||||||
"Content-Type": "application/json"},
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||||
) as response:
|
"Accept": "application/json",
|
||||||
response = await response.json()
|
"Content-Type": "application/json"},
|
||||||
generation_id = response["sdGenerationJob"]["generationId"]
|
) as response:
|
||||||
while True:
|
response = await response.json()
|
||||||
|
if "sdGenerationJob" not in response:
|
||||||
|
logging.warning(f"No 'sdGenerationJob' found in response: {repr(response)}")
|
||||||
|
await asyncio.sleep(error_sleep)
|
||||||
|
continue
|
||||||
|
generation_id = response["sdGenerationJob"]["generationId"]
|
||||||
|
if image_url is None:
|
||||||
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||||
"Accept": "application/json"},
|
"Accept": "application/json"},
|
||||||
) as response:
|
) as response:
|
||||||
response = await response.json()
|
response = await response.json()
|
||||||
|
if "generations_by_pk" not in response:
|
||||||
|
logging.warning(f"Unexpected response: {repr(response)}")
|
||||||
|
await asyncio.sleep(error_sleep)
|
||||||
|
continue
|
||||||
if len(response["generations_by_pk"]["generated_images"]) == 0:
|
if len(response["generations_by_pk"]["generated_images"]) == 0:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(error_sleep)
|
||||||
continue
|
continue
|
||||||
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
|
image_url = response["generations_by_pk"]["generated_images"][0]["url"]
|
||||||
break
|
if image_bytes is None:
|
||||||
async with session.get(image_url) as response:
|
async with session.get(image_url) as response:
|
||||||
image_bytes = BytesIO(await response.read())
|
image_bytes = BytesIO(await response.read())
|
||||||
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
async with session.delete(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}"},
|
||||||
) as response:
|
) as response:
|
||||||
@ -195,7 +212,9 @@ class AIResponder(object):
|
|||||||
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
|
logging.info(f'Drawed a picture with leonardo AI on this description: {repr(description)}')
|
||||||
return image_bytes
|
return image_bytes
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.warning(f"Failed to generate image: {repr(err)}")
|
logging.warning(f"Failed to generate image: {repr(description)}\n{repr(err)}")
|
||||||
|
else:
|
||||||
|
logging.warning(f"Failed to generate image: {repr(description)}")
|
||||||
raise RuntimeError(f"Failed to generate image {repr(description)}")
|
raise RuntimeError(f"Failed to generate image {repr(description)}")
|
||||||
|
|
||||||
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
||||||
@ -245,8 +264,9 @@ class AIResponder(object):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||||
|
model = self.config["model"]
|
||||||
try:
|
try:
|
||||||
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
result = await openai.ChatCompletion.acreate(model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=self.config["temperature"],
|
temperature=self.config["temperature"],
|
||||||
max_tokens=self.config["max-tokens"],
|
max_tokens=self.config["max-tokens"],
|
||||||
@ -267,6 +287,8 @@ class AIResponder(object):
|
|||||||
raise err
|
raise err
|
||||||
except openai.error.RateLimitError as err:
|
except openai.error.RateLimitError as err:
|
||||||
rate_limit_sleep = next(self.rate_limit_backoff)
|
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||||
|
if "retry-model" in self.config:
|
||||||
|
model = self.config["retry-model"]
|
||||||
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
logging.warning(f"got an rate limit error, sleep for {rate_limit_sleep} seconds: {str(err)}")
|
||||||
await asyncio.sleep(rate_limit_sleep)
|
await asyncio.sleep(rate_limit_sleep)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user