Compare commits
No commits in common. "a2c7aec1e3989d142286eb7c02fe33246d90f65a" and "924daf134f5c516f7658992ea18e2a03f2a50806" have entirely different histories.
a2c7aec1e3
...
924daf134f
@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
@ -120,7 +119,7 @@ class AIResponder(object):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.history: List[Dict[str, Any]] = []
|
self.history: List[Dict[str, Any]] = []
|
||||||
self.channel = channel if channel is not None else 'system'
|
self.channel = channel if channel is not None else 'system'
|
||||||
self.client = openai.AsyncOpenAI(api_key=self.config['openai-token'])
|
openai.api_key = self.config['openai-token']
|
||||||
self.rate_limit_backoff = exponential_backoff()
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
self.history_file: Optional[Path] = None
|
self.history_file: Optional[Path] = None
|
||||||
if 'history-directory' in self.config:
|
if 'history-directory' in self.config:
|
||||||
@ -134,11 +133,6 @@ class AIResponder(object):
|
|||||||
system = self.config.get(self.channel, self.config['system'])
|
system = self.config.get(self.channel, self.config['system'])
|
||||||
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
system = system.replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||||
.replace('{time}', time.strftime('%H:%M:%S'))
|
.replace('{time}', time.strftime('%H:%M:%S'))
|
||||||
news_feed = self.config.get('news')
|
|
||||||
if news_feed and os.path.exists(news_feed):
|
|
||||||
with open(news_feed) as fd:
|
|
||||||
news_feed = fd.read().strip()
|
|
||||||
system = system.replace('{news}', news_feed)
|
|
||||||
messages.append({"role": "system", "content": system})
|
messages.append({"role": "system", "content": system})
|
||||||
if limit is not None:
|
if limit is not None:
|
||||||
while len(self.history) > limit:
|
while len(self.history) > limit:
|
||||||
@ -156,9 +150,9 @@ class AIResponder(object):
|
|||||||
async def _draw_openai(self, description: str) -> BytesIO:
|
async def _draw_openai(self, description: str) -> BytesIO:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
try:
|
try:
|
||||||
response = await self.client.images.generate(prompt=description, n=1, size="1024x1024", model="dall-e-3")
|
response = await openai.Image.acreate(prompt=description, n=1, size="512x512")
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(response.data[0].url) as image:
|
async with session.get(response['data'][0]['url']) as image:
|
||||||
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
logging.info(f'Drawed a picture with DALL-E on this description: {repr(description)}')
|
||||||
return BytesIO(await image.read())
|
return BytesIO(await image.read())
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@ -270,27 +264,29 @@ class AIResponder(object):
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[openai.types.chat.ChatCompletionMessage], int]:
|
async def _acreate(self, messages: List[Dict[str, Any]], limit: int) -> Tuple[Optional[Dict[str, Any]], int]:
|
||||||
model = self.config["model"]
|
model = self.config["model"]
|
||||||
try:
|
try:
|
||||||
result = await self.client.chat.completions.create(model=model,
|
result = await openai.ChatCompletion.acreate(model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=self.config["temperature"],
|
temperature=self.config["temperature"],
|
||||||
max_tokens=self.config["max-tokens"],
|
max_tokens=self.config["max-tokens"],
|
||||||
top_p=self.config["top-p"],
|
top_p=self.config["top-p"],
|
||||||
presence_penalty=self.config["presence-penalty"],
|
presence_penalty=self.config["presence-penalty"],
|
||||||
frequency_penalty=self.config["frequency-penalty"])
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
answer = result.choices[0].message
|
answer = result['choices'][0]['message']
|
||||||
|
if type(answer) != dict:
|
||||||
|
answer = answer.to_dict()
|
||||||
self.rate_limit_backoff = exponential_backoff()
|
self.rate_limit_backoff = exponential_backoff()
|
||||||
logging.info(f"generated response {result.usage}: {repr(answer)}")
|
logging.info(f"generated response {result.get('usage')}: {repr(answer)}")
|
||||||
return answer, limit
|
return answer, limit
|
||||||
except openai.BadRequestError as err:
|
except openai.error.InvalidRequestError as err:
|
||||||
if 'maximum context length is' in str(err) and limit > 4:
|
if 'maximum context length is' in str(err) and limit > 4:
|
||||||
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
logging.warning(f"context length exceeded, reduce the limit {limit}: {str(err)}")
|
||||||
limit -= 1
|
limit -= 1
|
||||||
return None, limit
|
return None, limit
|
||||||
raise err
|
raise err
|
||||||
except openai.RateLimitError as err:
|
except openai.error.RateLimitError as err:
|
||||||
rate_limit_sleep = next(self.rate_limit_backoff)
|
rate_limit_sleep = next(self.rate_limit_backoff)
|
||||||
if "retry-model" in self.config:
|
if "retry-model" in self.config:
|
||||||
model = self.config["retry-model"]
|
model = self.config["retry-model"]
|
||||||
@ -306,12 +302,12 @@ class AIResponder(object):
|
|||||||
messages = [{"role": "system", "content": self.config["fix-description"]},
|
messages = [{"role": "system", "content": self.config["fix-description"]},
|
||||||
{"role": "user", "content": answer}]
|
{"role": "user", "content": answer}]
|
||||||
try:
|
try:
|
||||||
result = await self.client.chat.completions.create(model=self.config["fix-model"],
|
result = await openai.ChatCompletion.acreate(model=self.config["fix-model"],
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=2048)
|
max_tokens=2048)
|
||||||
logging.info(f"got this message as fix:\n{pp(result.choices[0].message.content)}")
|
logging.info(f"got this message as fix:\n{pp(result['choices'][0]['message']['content'])}")
|
||||||
response = result.choices[0].message.content
|
response = result['choices'][0]['message']['content']
|
||||||
start, end = response.find("{"), response.rfind("}")
|
start, end = response.find("{"), response.rfind("}")
|
||||||
if start == -1 or end == -1 or (start + 3) >= end:
|
if start == -1 or end == -1 or (start + 3) >= end:
|
||||||
return answer
|
return answer
|
||||||
@ -330,11 +326,11 @@ class AIResponder(object):
|
|||||||
f" if it is not already in {language}, otherwise you just copy it."},
|
f" if it is not already in {language}, otherwise you just copy it."},
|
||||||
{"role": "user", "content": text}]
|
{"role": "user", "content": text}]
|
||||||
try:
|
try:
|
||||||
result = await self.client.chat.completions.create(model=self.config["fix-model"],
|
result = await openai.ChatCompletion.acreate(model=self.config["fix-model"],
|
||||||
messages=message,
|
messages=message,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
max_tokens=2048)
|
max_tokens=2048)
|
||||||
response = result.choices[0].message.content
|
response = result['choices'][0]['message']['content']
|
||||||
logging.info(f"got this translated message:\n{pp(response)}")
|
logging.info(f"got this translated message:\n{pp(response)}")
|
||||||
return response
|
return response
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
@ -390,16 +386,16 @@ class AIResponder(object):
|
|||||||
|
|
||||||
# Attempt to parse the AI's response
|
# Attempt to parse the AI's response
|
||||||
try:
|
try:
|
||||||
response = parse_json(answer.content)
|
response = parse_json(answer['content'])
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer.content)}")
|
logging.warning(f"failed to parse the answer: {pp(err)}\n{repr(answer['content'])}")
|
||||||
answer.content = await self.fix(answer['content'])
|
answer['content'] = await self.fix(answer['content'])
|
||||||
|
|
||||||
# Retry parsing the fixed content
|
# Retry parsing the fixed content
|
||||||
try:
|
try:
|
||||||
response = parse_json(answer.content)
|
response = parse_json(answer['content'])
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer.content)}")
|
logging.error(f"failed to parse the fixed answer: {pp(err)}\n{repr(answer['content'])}")
|
||||||
retries -= 1
|
retries -= 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -414,7 +410,7 @@ class AIResponder(object):
|
|||||||
|
|
||||||
# Post-process the message and update the answer's content
|
# Post-process the message and update the answer's content
|
||||||
answer_message = await self.post_process(message, response)
|
answer_message = await self.post_process(message, response)
|
||||||
answer.content = str(answer_message)
|
answer['content'] = str(answer_message)
|
||||||
|
|
||||||
# Update message history
|
# Update message history
|
||||||
self.update_history(messages[-1], answer, limit, message.historise_question)
|
self.update_history(messages[-1], answer, limit, message.historise_question)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user