Switch to TOML for config file, fix bugs.
This commit is contained in:
parent
7ed9049892
commit
c85153c490
@ -3,7 +3,7 @@ repos:
|
||||
rev: 'v1.1.1'
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config-file=mypy.ini]
|
||||
args: [--config-file=mypy.ini, --install-types, --non-interactive]
|
||||
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 6.0.0
|
||||
|
||||
10
config.json
10
config.json
@ -1,10 +0,0 @@
|
||||
{ "openai_key": "OPENAIKEY",
|
||||
"discord_token": "DISCORDTOKEN",
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.9,
|
||||
"top-p": 1.0,
|
||||
"presence-penalty": 1.0,
|
||||
"frequency-penalty": 1.0,
|
||||
"history-limit": 10,
|
||||
"system": "You are an smart AI" }
|
||||
13
config.toml
Normal file
13
config.toml
Normal file
@ -0,0 +1,13 @@
|
||||
openai-key = "OPENAIKEY"
|
||||
discord-token = "DISCORDTOKEN"
|
||||
model = "gpt-4"
|
||||
max-tokens = 1024
|
||||
temperature = 0.9
|
||||
top-p = 1.0
|
||||
presence-penalty = 1.0
|
||||
frequency-penalty = 1.0
|
||||
history-limit = 10
|
||||
welcome-channel = "welcome"
|
||||
staff-channel = "staff"
|
||||
join-message = "Hi! I am {name}, and I am new here."
|
||||
system = "You are an smart AI"
|
||||
@ -1,2 +1,3 @@
|
||||
from .discord_bot import FjerkroaBot, main
|
||||
from .ai_responder import AIMessage, AIResponse, AIResponder
|
||||
from .bot_logging import setup_logging
|
||||
|
||||
@ -2,8 +2,10 @@ import json
|
||||
import openai
|
||||
import aiohttp
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from pprint import pformat
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
|
||||
class AIMessageBase(object):
|
||||
@ -33,11 +35,13 @@ class AIResponder(object):
|
||||
def __init__(self, config: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.history: List[Dict[str, Any]] = []
|
||||
openai.api_key = self.config['openai_token']
|
||||
openai.api_key = self.config['openai-token']
|
||||
|
||||
def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||
messages = []
|
||||
messages.append({"role": "system", "content": self.config["system"]})
|
||||
system = self.config["system"].replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||
.replace('{time}', time.strftime('%H-%M-%S'))
|
||||
messages.append({"role": "system", "content": system})
|
||||
if limit is None:
|
||||
history = self.history[:]
|
||||
else:
|
||||
@ -45,7 +49,7 @@ class AIResponder(object):
|
||||
history.append({"role": "user", "content": str(message)})
|
||||
for msg in history:
|
||||
messages.append(msg)
|
||||
return messages, history
|
||||
return messages
|
||||
|
||||
async def draw(self, description: str) -> BytesIO:
|
||||
for _ in range(7):
|
||||
@ -80,7 +84,8 @@ class AIResponder(object):
|
||||
async def send(self, message: AIMessage) -> AIResponse:
|
||||
limit = self.config["history-limit"]
|
||||
for _ in range(14):
|
||||
messages, history = self._message(message, limit)
|
||||
messages = self._message(message, limit)
|
||||
logging.info(f"try to send this messages:\n{pformat(messages)}")
|
||||
try:
|
||||
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
||||
messages=messages,
|
||||
@ -89,7 +94,10 @@ class AIResponder(object):
|
||||
presence_penalty=self.config["presence-penalty"],
|
||||
frequency_penalty=self.config["frequency-penalty"])
|
||||
answer = result['choices'][0]['message']
|
||||
if type(answer) != dict:
|
||||
answer = answer.to_dict()
|
||||
response = json.loads(answer['content'])
|
||||
logging.info(f"got this answer:\n{pformat(response)}")
|
||||
except openai.error.InvalidRequestError as err:
|
||||
if 'maximum context length is' in str(err) and limit > 4:
|
||||
limit -= 1
|
||||
@ -98,7 +106,9 @@ class AIResponder(object):
|
||||
except Exception as err:
|
||||
logging.warning(f"failed to generate response: {repr(err)}")
|
||||
continue
|
||||
history.append(answer)
|
||||
self.history = history
|
||||
self.history.append(messages[-1])
|
||||
self.history.append(answer)
|
||||
if len(self.history) > limit:
|
||||
self.history = self.history[-limit:]
|
||||
return await self.post_process(response)
|
||||
raise RuntimeError("Failed to generate answer after multiple retries")
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import toml
|
||||
import discord
|
||||
import logging
|
||||
from discord import Message, TextChannel
|
||||
@ -28,7 +28,7 @@ class FjerkroaBot(commands.Bot):
|
||||
|
||||
self.observer = Observer()
|
||||
self.file_handler = ConfigFileHandler(self.on_config_file_modified)
|
||||
self.observer.schedule(self.file_handler, path=".", recursive=False)
|
||||
self.observer.schedule(self.file_handler, path=config_file, recursive=False)
|
||||
self.observer.start()
|
||||
|
||||
self.airesponder = AIResponder(self.config)
|
||||
@ -36,9 +36,9 @@ class FjerkroaBot(commands.Bot):
|
||||
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
|
||||
|
||||
@classmethod
|
||||
def load_config(self, config_file: str = "config.json"):
|
||||
with open(config_file, "r") as file:
|
||||
return json.load(file)
|
||||
def load_config(self, config_file: str = "config.toml"):
|
||||
with open(config_file, encoding='utf-8') as file:
|
||||
return toml.load(file)
|
||||
|
||||
def on_config_file_modified(self, event):
|
||||
if event.src_path == self.config_file:
|
||||
@ -57,23 +57,28 @@ class FjerkroaBot(commands.Bot):
|
||||
|
||||
async def on_member_join(self, member):
|
||||
logging.info(f"User {member.name} joined")
|
||||
msg = AIMessage(member.name, self.config['join-message'])
|
||||
msg = AIMessage(member.name, self.config['join-message'].replace('{name}', member.name))
|
||||
if self.welcome_channel is not None:
|
||||
await self.respond(msg, self.welcome_channel)
|
||||
|
||||
async def on_message(self, message: Message) -> None:
|
||||
if self.user is not None and message.author.id == self.user.id:
|
||||
return
|
||||
msg = AIMessage(message.author.name, str(message.content).strip())
|
||||
await self.respond(msg, message.channel)
|
||||
|
||||
async def respond(self, message: AIMessage, channel: TextChannel) -> None:
|
||||
logging.info(f"handle message {str(message)} for channel {channel.name}")
|
||||
response = await self.airesponder.send(message)
|
||||
if response.staff is not None and self.staff_channel is not None:
|
||||
async with self.staff_channel.typing():
|
||||
await self.staff_channel.send(response.staff)
|
||||
if not response.answer_needed:
|
||||
return
|
||||
async with channel.typing():
|
||||
response = await self.airesponder.send(message)
|
||||
if response.staff is not None and self.staff_channel is not None:
|
||||
async with self.staff_channel.typing():
|
||||
await self.staff_channel.send(response.staff)
|
||||
if not response.answer_needed:
|
||||
return
|
||||
if response.hack:
|
||||
logging.warning(f"User {message.user} tried to hack the system.")
|
||||
return
|
||||
if response.picture is not None:
|
||||
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
|
||||
await channel.send(response.answer, files=images)
|
||||
@ -89,12 +94,12 @@ def main() -> int:
|
||||
from .bot_logging import setup_logging
|
||||
setup_logging()
|
||||
parser = argparse.ArgumentParser(description='Fjerkroa AI bot')
|
||||
parser.add_argument('--config', type=str, default='config.json', help='Config file.')
|
||||
parser.add_argument('--config', type=str, default='config.toml', help='Config file.')
|
||||
args = parser.parse_args()
|
||||
|
||||
config = FjerkroaBot.load_config(args.config)
|
||||
bot = FjerkroaBot(args.config)
|
||||
bot.run(config["discord_token"])
|
||||
bot.run(config["discord-token"])
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
@ -8,3 +8,5 @@ pytest
|
||||
setuptools
|
||||
wheel
|
||||
watchdog
|
||||
toml
|
||||
types-toml
|
||||
|
||||
@ -8,7 +8,7 @@ class TestAIResponder(TestBotBase):
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
self.system = r"""
|
||||
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
|
||||
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}.
|
||||
|
||||
Every message from users is a dictionary in JSON format with the following fields:
|
||||
1. `user`: name of the user who wrote the message.
|
||||
@ -38,9 +38,19 @@ You always try to say something positive about the current day and the Fjærkroa
|
||||
|
||||
async def test_responder1(self) -> None:
|
||||
response = await self.bot.airesponder.send(AIMessage("lala", "who are you?"))
|
||||
print(response)
|
||||
print(f"\n{response}")
|
||||
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
|
||||
|
||||
async def test_history(self) -> None:
|
||||
self.bot.airesponder.history = []
|
||||
response = await self.bot.airesponder.send(AIMessage("lala", "which date is today?"))
|
||||
print(f"\n{response}")
|
||||
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
|
||||
response = await self.bot.airesponder.send(AIMessage("lala", "can I have an espresso please?"))
|
||||
print(f"\n{response}")
|
||||
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
|
||||
print(f"\n{self.bot.airesponder.history}")
|
||||
|
||||
|
||||
if __name__ == "__mait__":
|
||||
unittest.main()
|
||||
|
||||
@ -3,6 +3,7 @@ import unittest
|
||||
import pytest
|
||||
import aiohttp
|
||||
import json
|
||||
import toml
|
||||
import openai
|
||||
import logging
|
||||
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
||||
@ -18,7 +19,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
||||
Mock(text="Nice day today!")
|
||||
]
|
||||
self.config_data = {
|
||||
"openai_token": os.environ['OPENAI_TOKEN'],
|
||||
"openai-token": os.environ['OPENAI_TOKEN'],
|
||||
"model": "gpt-4",
|
||||
"max_tokens": 1024,
|
||||
"temperature": 0.9,
|
||||
@ -55,7 +56,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
||||
class TestFunctionality(TestBotBase):
|
||||
|
||||
def test_load_config(self) -> None:
|
||||
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
|
||||
with patch('builtins.open', mock_open(read_data=toml.dumps(self.config_data))):
|
||||
result = FjerkroaBot.load_config('config.json')
|
||||
self.assertEqual(result, self.config_data)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user