diff --git a/fjerkroa_bot/ai_responder.py b/fjerkroa_bot/ai_responder.py index faed325..cafdf55 100644 --- a/fjerkroa_bot/ai_responder.py +++ b/fjerkroa_bot/ai_responder.py @@ -140,6 +140,11 @@ class AIResponder(object): return messages async def draw(self, description: str) -> BytesIO: + if self.config.get('leonardo-token') is not None: + return await self._draw_leonardo(description) + return await self._draw_openai(description) + + async def _draw_openai(self, description: str) -> BytesIO: for _ in range(3): try: response = await openai.Image.acreate(prompt=description, n=1, size="512x512") @@ -150,6 +155,32 @@ class AIResponder(object): logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}") raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries") + async def _draw_leonardo(self, description: str) -> BytesIO: + for _ in range(3): + try: + async with aiohttp.ClientSession() as session: + async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations", + json={"prompt": description, + "modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3", + "num_images": 1, + "width": 512, + "height": 512}, + headers={"Authorization": f"Bearer {self.config['leonardo-token']}", + "Accept": "application/json", + "Content-Type": "application/json"}, + ) as response: + generation_id = (await response.json())["sdGenerationJob"]["generationId"] + async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}", + headers={"Authorization": f"Bearer {self.config['leonardo-token']}", + "Accept": "application/json"}, + ) as response: + image_url = (await response.json())["generations_by_pk"]["generated_images"][0]["url"] + async with session.get(image_url) as response: + return BytesIO(await response.read()) + except Exception as err: + logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}") + raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries") + async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse: for fld in ('answer', 'channel', 'staff', 'picture', 'hack'): if str(response.get(fld)).strip().lower() in \