Switch to TOML for config file, fix bugs.

This commit is contained in:
OK 2023-03-22 21:37:33 +01:00
parent 7ed9049892
commit c85153c490
9 changed files with 69 additions and 37 deletions

View File

@ -3,7 +3,7 @@ repos:
rev: 'v1.1.1'
hooks:
- id: mypy
args: [--config-file=mypy.ini]
args: [--config-file=mypy.ini, --install-types, --non-interactive]
- repo: https://github.com/pycqa/flake8
rev: 6.0.0

View File

@ -1,10 +0,0 @@
{ "openai_key": "OPENAIKEY",
"discord_token": "DISCORDTOKEN",
"model": "gpt-4",
"max_tokens": 1024,
"temperature": 0.9,
"top-p": 1.0,
"presence-penalty": 1.0,
"frequency-penalty": 1.0,
"history-limit": 10,
"system": "You are an smart AI" }

13
config.toml Normal file
View File

@ -0,0 +1,13 @@
openai-key = "OPENAIKEY"
discord-token = "DISCORDTOKEN"
model = "gpt-4"
max-tokens = 1024
temperature = 0.9
top-p = 1.0
presence-penalty = 1.0
frequency-penalty = 1.0
history-limit = 10
welcome-channel = "welcome"
staff-channel = "staff"
join-message = "Hi! I am {name}, and I am new here."
system = "You are an smart AI"

View File

@ -1,2 +1,3 @@
from .discord_bot import FjerkroaBot, main
from .ai_responder import AIMessage, AIResponse, AIResponder
from .bot_logging import setup_logging

View File

@ -2,8 +2,10 @@ import json
import openai
import aiohttp
import logging
import time
from io import BytesIO
from typing import Optional, List, Dict, Any, Tuple
from pprint import pformat
from typing import Optional, List, Dict, Any
class AIMessageBase(object):
@ -33,11 +35,13 @@ class AIResponder(object):
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.history: List[Dict[str, Any]] = []
openai.api_key = self.config['openai_token']
openai.api_key = self.config['openai-token']
def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
messages = []
messages.append({"role": "system", "content": self.config["system"]})
system = self.config["system"].replace('{date}', time.strftime('%Y-%m-%d'))\
.replace('{time}', time.strftime('%H-%M-%S'))
messages.append({"role": "system", "content": system})
if limit is None:
history = self.history[:]
else:
@ -45,7 +49,7 @@ class AIResponder(object):
history.append({"role": "user", "content": str(message)})
for msg in history:
messages.append(msg)
return messages, history
return messages
async def draw(self, description: str) -> BytesIO:
for _ in range(7):
@ -80,7 +84,8 @@ class AIResponder(object):
async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"]
for _ in range(14):
messages, history = self._message(message, limit)
messages = self._message(message, limit)
logging.info(f"try to send this messages:\n{pformat(messages)}")
try:
result = await openai.ChatCompletion.acreate(model=self.config["model"],
messages=messages,
@ -89,7 +94,10 @@ class AIResponder(object):
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
answer = result['choices'][0]['message']
if type(answer) != dict:
answer = answer.to_dict()
response = json.loads(answer['content'])
logging.info(f"got this answer:\n{pformat(response)}")
except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 4:
limit -= 1
@ -98,7 +106,9 @@ class AIResponder(object):
except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}")
continue
history.append(answer)
self.history = history
self.history.append(messages[-1])
self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
return await self.post_process(response)
raise RuntimeError("Failed to generate answer after multiple retries")

View File

@ -1,6 +1,6 @@
import sys
import argparse
import json
import toml
import discord
import logging
from discord import Message, TextChannel
@ -28,7 +28,7 @@ class FjerkroaBot(commands.Bot):
self.observer = Observer()
self.file_handler = ConfigFileHandler(self.on_config_file_modified)
self.observer.schedule(self.file_handler, path=".", recursive=False)
self.observer.schedule(self.file_handler, path=config_file, recursive=False)
self.observer.start()
self.airesponder = AIResponder(self.config)
@ -36,9 +36,9 @@ class FjerkroaBot(commands.Bot):
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
@classmethod
def load_config(self, config_file: str = "config.json"):
with open(config_file, "r") as file:
return json.load(file)
def load_config(self, config_file: str = "config.toml"):
with open(config_file, encoding='utf-8') as file:
return toml.load(file)
def on_config_file_modified(self, event):
if event.src_path == self.config_file:
@ -57,23 +57,28 @@ class FjerkroaBot(commands.Bot):
async def on_member_join(self, member):
logging.info(f"User {member.name} joined")
msg = AIMessage(member.name, self.config['join-message'])
msg = AIMessage(member.name, self.config['join-message'].replace('{name}', member.name))
if self.welcome_channel is not None:
await self.respond(msg, self.welcome_channel)
async def on_message(self, message: Message) -> None:
if self.user is not None and message.author.id == self.user.id:
return
msg = AIMessage(message.author.name, str(message.content).strip())
await self.respond(msg, message.channel)
async def respond(self, message: AIMessage, channel: TextChannel) -> None:
logging.info(f"handle message {str(message)} for channel {channel.name}")
async with channel.typing():
response = await self.airesponder.send(message)
if response.staff is not None and self.staff_channel is not None:
async with self.staff_channel.typing():
await self.staff_channel.send(response.staff)
if not response.answer_needed:
return
async with channel.typing():
if response.hack:
logging.warning(f"User {message.user} tried to hack the system.")
return
if response.picture is not None:
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
await channel.send(response.answer, files=images)
@ -89,12 +94,12 @@ def main() -> int:
from .bot_logging import setup_logging
setup_logging()
parser = argparse.ArgumentParser(description='Fjerkroa AI bot')
parser.add_argument('--config', type=str, default='config.json', help='Config file.')
parser.add_argument('--config', type=str, default='config.toml', help='Config file.')
args = parser.parse_args()
config = FjerkroaBot.load_config(args.config)
bot = FjerkroaBot(args.config)
bot.run(config["discord_token"])
bot.run(config["discord-token"])
return 0

View File

@ -8,3 +8,5 @@ pytest
setuptools
wheel
watchdog
toml
types-toml

View File

@ -8,7 +8,7 @@ class TestAIResponder(TestBotBase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.system = r"""
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}.
Every message from users is a dictionary in JSON format with the following fields:
1. `user`: name of the user who wrote the message.
@ -38,9 +38,19 @@ You always try to say something positive about the current day and the Fjærkroa
async def test_responder1(self) -> None:
response = await self.bot.airesponder.send(AIMessage("lala", "who are you?"))
print(response)
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
async def test_history(self) -> None:
self.bot.airesponder.history = []
response = await self.bot.airesponder.send(AIMessage("lala", "which date is today?"))
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
response = await self.bot.airesponder.send(AIMessage("lala", "can I have an espresso please?"))
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
print(f"\n{self.bot.airesponder.history}")
if __name__ == "__mait__":
unittest.main()

View File

@ -3,6 +3,7 @@ import unittest
import pytest
import aiohttp
import json
import toml
import openai
import logging
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
@ -18,7 +19,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
Mock(text="Nice day today!")
]
self.config_data = {
"openai_token": os.environ['OPENAI_TOKEN'],
"openai-token": os.environ['OPENAI_TOKEN'],
"model": "gpt-4",
"max_tokens": 1024,
"temperature": 0.9,
@ -55,7 +56,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
class TestFunctionality(TestBotBase):
def test_load_config(self) -> None:
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
with patch('builtins.open', mock_open(read_data=toml.dumps(self.config_data))):
result = FjerkroaBot.load_config('config.json')
self.assertEqual(result, self.config_data)