Support Leonardo.ai to generate images.
This commit is contained in:
parent
771d965e8c
commit
e237800348
@ -140,6 +140,11 @@ class AIResponder(object):
|
||||
return messages
|
||||
|
||||
async def draw(self, description: str) -> BytesIO:
|
||||
if self.config.get('leonardo-token') is not None:
|
||||
return await self._draw_leonardo(description)
|
||||
return await self._draw_openai(description)
|
||||
|
||||
async def _draw_openai(self, description: str) -> BytesIO:
|
||||
for _ in range(3):
|
||||
try:
|
||||
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
|
||||
@ -150,6 +155,32 @@ class AIResponder(object):
|
||||
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||
|
||||
async def _draw_leonardo(self, description: str) -> BytesIO:
|
||||
for _ in range(3):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
||||
json={"prompt": description,
|
||||
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
||||
"num_images": 1,
|
||||
"width": 512,
|
||||
"height": 512},
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json"},
|
||||
) as response:
|
||||
generation_id = (await response.json())["sdGenerationJob"]["generationId"]
|
||||
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||
"Accept": "application/json"},
|
||||
) as response:
|
||||
image_url = (await response.json())["generations_by_pk"]["generated_images"][0]["url"]
|
||||
async with session.get(image_url) as response:
|
||||
return BytesIO(await response.read())
|
||||
except Exception as err:
|
||||
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||
|
||||
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
||||
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
|
||||
if str(response.get(fld)).strip().lower() in \
|
||||
|
||||
Loading…
Reference in New Issue
Block a user