Switch to TOML for config file, fix bugs.

This commit is contained in:
OK 2023-03-22 21:37:33 +01:00
parent 7ed9049892
commit c85153c490
9 changed files with 69 additions and 37 deletions

View File

@ -3,7 +3,7 @@ repos:
rev: 'v1.1.1' rev: 'v1.1.1'
hooks: hooks:
- id: mypy - id: mypy
args: [--config-file=mypy.ini] args: [--config-file=mypy.ini, --install-types, --non-interactive]
- repo: https://github.com/pycqa/flake8 - repo: https://github.com/pycqa/flake8
rev: 6.0.0 rev: 6.0.0

View File

@ -1,10 +0,0 @@
{ "openai_key": "OPENAIKEY",
"discord_token": "DISCORDTOKEN",
"model": "gpt-4",
"max_tokens": 1024,
"temperature": 0.9,
"top-p": 1.0,
"presence-penalty": 1.0,
"frequency-penalty": 1.0,
"history-limit": 10,
"system": "You are an smart AI" }

13
config.toml Normal file
View File

@ -0,0 +1,13 @@
openai-key = "OPENAIKEY"
discord-token = "DISCORDTOKEN"
model = "gpt-4"
max-tokens = 1024
temperature = 0.9
top-p = 1.0
presence-penalty = 1.0
frequency-penalty = 1.0
history-limit = 10
welcome-channel = "welcome"
staff-channel = "staff"
join-message = "Hi! I am {name}, and I am new here."
system = "You are an smart AI"

View File

@ -1,2 +1,3 @@
from .discord_bot import FjerkroaBot, main from .discord_bot import FjerkroaBot, main
from .ai_responder import AIMessage, AIResponse, AIResponder from .ai_responder import AIMessage, AIResponse, AIResponder
from .bot_logging import setup_logging

View File

@ -2,8 +2,10 @@ import json
import openai import openai
import aiohttp import aiohttp
import logging import logging
import time
from io import BytesIO from io import BytesIO
from typing import Optional, List, Dict, Any, Tuple from pprint import pformat
from typing import Optional, List, Dict, Any
class AIMessageBase(object): class AIMessageBase(object):
@ -33,11 +35,13 @@ class AIResponder(object):
def __init__(self, config: Dict[str, Any]) -> None: def __init__(self, config: Dict[str, Any]) -> None:
self.config = config self.config = config
self.history: List[Dict[str, Any]] = [] self.history: List[Dict[str, Any]] = []
openai.api_key = self.config['openai_token'] openai.api_key = self.config['openai-token']
def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
messages = [] messages = []
messages.append({"role": "system", "content": self.config["system"]}) system = self.config["system"].replace('{date}', time.strftime('%Y-%m-%d'))\
.replace('{time}', time.strftime('%H-%M-%S'))
messages.append({"role": "system", "content": system})
if limit is None: if limit is None:
history = self.history[:] history = self.history[:]
else: else:
@ -45,7 +49,7 @@ class AIResponder(object):
history.append({"role": "user", "content": str(message)}) history.append({"role": "user", "content": str(message)})
for msg in history: for msg in history:
messages.append(msg) messages.append(msg)
return messages, history return messages
async def draw(self, description: str) -> BytesIO: async def draw(self, description: str) -> BytesIO:
for _ in range(7): for _ in range(7):
@ -80,7 +84,8 @@ class AIResponder(object):
async def send(self, message: AIMessage) -> AIResponse: async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"] limit = self.config["history-limit"]
for _ in range(14): for _ in range(14):
messages, history = self._message(message, limit) messages = self._message(message, limit)
logging.info(f"try to send this messages:\n{pformat(messages)}")
try: try:
result = await openai.ChatCompletion.acreate(model=self.config["model"], result = await openai.ChatCompletion.acreate(model=self.config["model"],
messages=messages, messages=messages,
@ -89,7 +94,10 @@ class AIResponder(object):
presence_penalty=self.config["presence-penalty"], presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"]) frequency_penalty=self.config["frequency-penalty"])
answer = result['choices'][0]['message'] answer = result['choices'][0]['message']
if type(answer) != dict:
answer = answer.to_dict()
response = json.loads(answer['content']) response = json.loads(answer['content'])
logging.info(f"got this answer:\n{pformat(response)}")
except openai.error.InvalidRequestError as err: except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 4: if 'maximum context length is' in str(err) and limit > 4:
limit -= 1 limit -= 1
@ -98,7 +106,9 @@ class AIResponder(object):
except Exception as err: except Exception as err:
logging.warning(f"failed to generate response: {repr(err)}") logging.warning(f"failed to generate response: {repr(err)}")
continue continue
history.append(answer) self.history.append(messages[-1])
self.history = history self.history.append(answer)
if len(self.history) > limit:
self.history = self.history[-limit:]
return await self.post_process(response) return await self.post_process(response)
raise RuntimeError("Failed to generate answer after multiple retries") raise RuntimeError("Failed to generate answer after multiple retries")

View File

@ -1,6 +1,6 @@
import sys import sys
import argparse import argparse
import json import toml
import discord import discord
import logging import logging
from discord import Message, TextChannel from discord import Message, TextChannel
@ -28,7 +28,7 @@ class FjerkroaBot(commands.Bot):
self.observer = Observer() self.observer = Observer()
self.file_handler = ConfigFileHandler(self.on_config_file_modified) self.file_handler = ConfigFileHandler(self.on_config_file_modified)
self.observer.schedule(self.file_handler, path=".", recursive=False) self.observer.schedule(self.file_handler, path=config_file, recursive=False)
self.observer.start() self.observer.start()
self.airesponder = AIResponder(self.config) self.airesponder = AIResponder(self.config)
@ -36,9 +36,9 @@ class FjerkroaBot(commands.Bot):
super().__init__(command_prefix="!", case_insensitive=True, intents=intents) super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
@classmethod @classmethod
def load_config(self, config_file: str = "config.json"): def load_config(self, config_file: str = "config.toml"):
with open(config_file, "r") as file: with open(config_file, encoding='utf-8') as file:
return json.load(file) return toml.load(file)
def on_config_file_modified(self, event): def on_config_file_modified(self, event):
if event.src_path == self.config_file: if event.src_path == self.config_file:
@ -57,23 +57,28 @@ class FjerkroaBot(commands.Bot):
async def on_member_join(self, member): async def on_member_join(self, member):
logging.info(f"User {member.name} joined") logging.info(f"User {member.name} joined")
msg = AIMessage(member.name, self.config['join-message']) msg = AIMessage(member.name, self.config['join-message'].replace('{name}', member.name))
if self.welcome_channel is not None: if self.welcome_channel is not None:
await self.respond(msg, self.welcome_channel) await self.respond(msg, self.welcome_channel)
async def on_message(self, message: Message) -> None: async def on_message(self, message: Message) -> None:
if self.user is not None and message.author.id == self.user.id:
return
msg = AIMessage(message.author.name, str(message.content).strip()) msg = AIMessage(message.author.name, str(message.content).strip())
await self.respond(msg, message.channel) await self.respond(msg, message.channel)
async def respond(self, message: AIMessage, channel: TextChannel) -> None: async def respond(self, message: AIMessage, channel: TextChannel) -> None:
logging.info(f"handle message {str(message)} for channel {channel.name}") logging.info(f"handle message {str(message)} for channel {channel.name}")
async with channel.typing():
response = await self.airesponder.send(message) response = await self.airesponder.send(message)
if response.staff is not None and self.staff_channel is not None: if response.staff is not None and self.staff_channel is not None:
async with self.staff_channel.typing(): async with self.staff_channel.typing():
await self.staff_channel.send(response.staff) await self.staff_channel.send(response.staff)
if not response.answer_needed: if not response.answer_needed:
return return
async with channel.typing(): if response.hack:
logging.warning(f"User {message.user} tried to hack the system.")
return
if response.picture is not None: if response.picture is not None:
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")] images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
await channel.send(response.answer, files=images) await channel.send(response.answer, files=images)
@ -89,12 +94,12 @@ def main() -> int:
from .bot_logging import setup_logging from .bot_logging import setup_logging
setup_logging() setup_logging()
parser = argparse.ArgumentParser(description='Fjerkroa AI bot') parser = argparse.ArgumentParser(description='Fjerkroa AI bot')
parser.add_argument('--config', type=str, default='config.json', help='Config file.') parser.add_argument('--config', type=str, default='config.toml', help='Config file.')
args = parser.parse_args() args = parser.parse_args()
config = FjerkroaBot.load_config(args.config) config = FjerkroaBot.load_config(args.config)
bot = FjerkroaBot(args.config) bot = FjerkroaBot(args.config)
bot.run(config["discord_token"]) bot.run(config["discord-token"])
return 0 return 0

View File

@ -8,3 +8,5 @@ pytest
setuptools setuptools
wheel wheel
watchdog watchdog
toml
types-toml

View File

@ -8,7 +8,7 @@ class TestAIResponder(TestBotBase):
async def asyncSetUp(self): async def asyncSetUp(self):
await super().asyncSetUp() await super().asyncSetUp()
self.system = r""" self.system = r"""
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}.
Every message from users is a dictionary in JSON format with the following fields: Every message from users is a dictionary in JSON format with the following fields:
1. `user`: name of the user who wrote the message. 1. `user`: name of the user who wrote the message.
@ -38,9 +38,19 @@ You always try to say something positive about the current day and the Fjærkroa
async def test_responder1(self) -> None: async def test_responder1(self) -> None:
response = await self.bot.airesponder.send(AIMessage("lala", "who are you?")) response = await self.bot.airesponder.send(AIMessage("lala", "who are you?"))
print(response) print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, None, None, False)) self.assertAIResponse(response, AIResponse('test', True, None, None, False))
async def test_history(self) -> None:
self.bot.airesponder.history = []
response = await self.bot.airesponder.send(AIMessage("lala", "which date is today?"))
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
response = await self.bot.airesponder.send(AIMessage("lala", "can I have an espresso please?"))
print(f"\n{response}")
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
print(f"\n{self.bot.airesponder.history}")
if __name__ == "__mait__": if __name__ == "__mait__":
unittest.main() unittest.main()

View File

@ -3,6 +3,7 @@ import unittest
import pytest import pytest
import aiohttp import aiohttp
import json import json
import toml
import openai import openai
import logging import logging
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
@ -18,7 +19,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
Mock(text="Nice day today!") Mock(text="Nice day today!")
] ]
self.config_data = { self.config_data = {
"openai_token": os.environ['OPENAI_TOKEN'], "openai-token": os.environ['OPENAI_TOKEN'],
"model": "gpt-4", "model": "gpt-4",
"max_tokens": 1024, "max_tokens": 1024,
"temperature": 0.9, "temperature": 0.9,
@ -55,7 +56,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
class TestFunctionality(TestBotBase): class TestFunctionality(TestBotBase):
def test_load_config(self) -> None: def test_load_config(self) -> None:
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))): with patch('builtins.open', mock_open(read_data=toml.dumps(self.config_data))):
result = FjerkroaBot.load_config('config.json') result = FjerkroaBot.load_config('config.json')
self.assertEqual(result, self.config_data) self.assertEqual(result, self.config_data)