Support Leonardo.ai to generate images.
This commit is contained in:
parent
771d965e8c
commit
e237800348
@ -140,6 +140,11 @@ class AIResponder(object):
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
async def draw(self, description: str) -> BytesIO:
|
async def draw(self, description: str) -> BytesIO:
|
||||||
|
if self.config.get('leonardo-token') is not None:
|
||||||
|
return await self._draw_leonardo(description)
|
||||||
|
return await self._draw_openai(description)
|
||||||
|
|
||||||
|
async def _draw_openai(self, description: str) -> BytesIO:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
try:
|
try:
|
||||||
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
|
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
|
||||||
@ -150,6 +155,32 @@ class AIResponder(object):
|
|||||||
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||||
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||||
|
|
||||||
|
async def _draw_leonardo(self, description: str) -> BytesIO:
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post("https://cloud.leonardo.ai/api/rest/v1/generations",
|
||||||
|
json={"prompt": description,
|
||||||
|
"modelId": "6bef9f1b-29cb-40c7-b9df-32b51c1f67d3",
|
||||||
|
"num_images": 1,
|
||||||
|
"width": 512,
|
||||||
|
"height": 512},
|
||||||
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json"},
|
||||||
|
) as response:
|
||||||
|
generation_id = (await response.json())["sdGenerationJob"]["generationId"]
|
||||||
|
async with session.get(f"https://cloud.leonardo.ai/api/rest/v1/generations/{generation_id}",
|
||||||
|
headers={"Authorization": f"Bearer {self.config['leonardo-token']}",
|
||||||
|
"Accept": "application/json"},
|
||||||
|
) as response:
|
||||||
|
image_url = (await response.json())["generations_by_pk"]["generated_images"][0]["url"]
|
||||||
|
async with session.get(image_url) as response:
|
||||||
|
return BytesIO(await response.read())
|
||||||
|
except Exception as err:
|
||||||
|
logging.warning(f"Failed to generate image {repr(description)}: {repr(err)}")
|
||||||
|
raise RuntimeError(f"Failed to generate image {repr(description)} after multiple retries")
|
||||||
|
|
||||||
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
async def post_process(self, message: AIMessage, response: Dict[str, Any]) -> AIResponse:
|
||||||
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
|
for fld in ('answer', 'channel', 'staff', 'picture', 'hack'):
|
||||||
if str(response.get(fld)).strip().lower() in \
|
if str(response.get(fld)).strip().lower() in \
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user