Support Leonardo.ai to generate images.

This commit is contained in:
OK 2023-08-19 14:33:04 +02:00
parent 771d965e8c
commit e237800348

View File

@ -140,6 +140,11 @@ class AIResponder(object):
return messages
async def draw(self, description: str) -> BytesIO:
if self.config.get('leonardo-token') is not None:
return await self._draw_leonardo(description)
return await self._draw_openai(description)
async def _draw_openai(self, description: str) -> BytesIO:
for _ in range(3):
try:
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
@ -150,6 +155,32 @@ class AIResponder(object):
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
async def _draw_leonardo(self, description: str) -> BytesIO:
for _ in range(3):
try:
async with aiohttp.ClientSession() as session:
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
json={"prompt": description,
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
"num_images": 1,
"width": 512,
"height": 512},
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json",
"Content-Type": "application/json"},
) as response:
generation_id = (await response.json())["sdGenerationJob"]["generationId"]
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
"Accept": "application/json"},
) as response:
image_url = (await response.json())["generations_by_pk"]["generated_images"][0]["url"]
async with session.get(image_url) as response:
return BytesIO(await response.read())
except Exception as err:
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
if str(response.get(fld)).strip().lower() in \