Switch to TOML for config file, fix bugs.
This commit is contained in:
parent
7ed9049892
commit
c85153c490
@ -3,7 +3,7 @@ repos:
|
|||||||
rev: 'v1.1.1'
|
rev: 'v1.1.1'
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
args: [--config-file=mypy.ini]
|
args: [--config-file=mypy.ini, --install-types, --non-interactive]
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/flake8
|
- repo: https://github.com/pycqa/flake8
|
||||||
rev: 6.0.0
|
rev: 6.0.0
|
||||||
|
|||||||
10
config.json
10
config.json
@ -1,10 +0,0 @@
|
|||||||
{ "openai_key": "OPENAIKEY",
|
|
||||||
"discord_token": "DISCORDTOKEN",
|
|
||||||
"model": "gpt-4",
|
|
||||||
"max_tokens": 1024,
|
|
||||||
"temperature": 0.9,
|
|
||||||
"top-p": 1.0,
|
|
||||||
"presence-penalty": 1.0,
|
|
||||||
"frequency-penalty": 1.0,
|
|
||||||
"history-limit": 10,
|
|
||||||
"system": "You are an smart AI" }
|
|
||||||
13
config.toml
Normal file
13
config.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
openai-key = "OPENAIKEY"
|
||||||
|
discord-token = "DISCORDTOKEN"
|
||||||
|
model = "gpt-4"
|
||||||
|
max-tokens = 1024
|
||||||
|
temperature = 0.9
|
||||||
|
top-p = 1.0
|
||||||
|
presence-penalty = 1.0
|
||||||
|
frequency-penalty = 1.0
|
||||||
|
history-limit = 10
|
||||||
|
welcome-channel = "welcome"
|
||||||
|
staff-channel = "staff"
|
||||||
|
join-message = "Hi! I am {name}, and I am new here."
|
||||||
|
system = "You are an smart AI"
|
||||||
@ -1,2 +1,3 @@
|
|||||||
from .discord_bot import FjerkroaBot, main
|
from .discord_bot import FjerkroaBot, main
|
||||||
from .ai_responder import AIMessage, AIResponse, AIResponder
|
from .ai_responder import AIMessage, AIResponse, AIResponder
|
||||||
|
from .bot_logging import setup_logging
|
||||||
|
|||||||
@ -2,8 +2,10 @@ import json
|
|||||||
import openai
|
import openai
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional, List, Dict, Any, Tuple
|
from pprint import pformat
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
class AIMessageBase(object):
|
class AIMessageBase(object):
|
||||||
@ -33,11 +35,13 @@ class AIResponder(object):
|
|||||||
def __init__(self, config: Dict[str, Any]) -> None:
|
def __init__(self, config: Dict[str, Any]) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.history: List[Dict[str, Any]] = []
|
self.history: List[Dict[str, Any]] = []
|
||||||
openai.api_key = self.config['openai_token']
|
openai.api_key = self.config['openai-token']
|
||||||
|
|
||||||
def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
def _message(self, message: AIMessage, limit: Optional[int] = None) -> List[Dict[str, Any]]:
|
||||||
messages = []
|
messages = []
|
||||||
messages.append({"role": "system", "content": self.config["system"]})
|
system = self.config["system"].replace('{date}', time.strftime('%Y-%m-%d'))\
|
||||||
|
.replace('{time}', time.strftime('%H-%M-%S'))
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
if limit is None:
|
if limit is None:
|
||||||
history = self.history[:]
|
history = self.history[:]
|
||||||
else:
|
else:
|
||||||
@ -45,7 +49,7 @@ class AIResponder(object):
|
|||||||
history.append({"role": "user", "content": str(message)})
|
history.append({"role": "user", "content": str(message)})
|
||||||
for msg in history:
|
for msg in history:
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
return messages, history
|
return messages
|
||||||
|
|
||||||
async def draw(self, description: str) -> BytesIO:
|
async def draw(self, description: str) -> BytesIO:
|
||||||
for _ in range(7):
|
for _ in range(7):
|
||||||
@ -80,7 +84,8 @@ class AIResponder(object):
|
|||||||
async def send(self, message: AIMessage) -> AIResponse:
|
async def send(self, message: AIMessage) -> AIResponse:
|
||||||
limit = self.config["history-limit"]
|
limit = self.config["history-limit"]
|
||||||
for _ in range(14):
|
for _ in range(14):
|
||||||
messages, history = self._message(message, limit)
|
messages = self._message(message, limit)
|
||||||
|
logging.info(f"try to send this messages:\n{pformat(messages)}")
|
||||||
try:
|
try:
|
||||||
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@ -89,7 +94,10 @@ class AIResponder(object):
|
|||||||
presence_penalty=self.config["presence-penalty"],
|
presence_penalty=self.config["presence-penalty"],
|
||||||
frequency_penalty=self.config["frequency-penalty"])
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
answer = result['choices'][0]['message']
|
answer = result['choices'][0]['message']
|
||||||
|
if type(answer) != dict:
|
||||||
|
answer = answer.to_dict()
|
||||||
response = json.loads(answer['content'])
|
response = json.loads(answer['content'])
|
||||||
|
logging.info(f"got this answer:\n{pformat(response)}")
|
||||||
except openai.error.InvalidRequestError as err:
|
except openai.error.InvalidRequestError as err:
|
||||||
if 'maximum context length is' in str(err) and limit > 4:
|
if 'maximum context length is' in str(err) and limit > 4:
|
||||||
limit -= 1
|
limit -= 1
|
||||||
@ -98,7 +106,9 @@ class AIResponder(object):
|
|||||||
except Exception as err:
|
except Exception as err:
|
||||||
logging.warning(f"failed to generate response: {repr(err)}")
|
logging.warning(f"failed to generate response: {repr(err)}")
|
||||||
continue
|
continue
|
||||||
history.append(answer)
|
self.history.append(messages[-1])
|
||||||
self.history = history
|
self.history.append(answer)
|
||||||
|
if len(self.history) > limit:
|
||||||
|
self.history = self.history[-limit:]
|
||||||
return await self.post_process(response)
|
return await self.post_process(response)
|
||||||
raise RuntimeError("Failed to generate answer after multiple retries")
|
raise RuntimeError("Failed to generate answer after multiple retries")
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import toml
|
||||||
import discord
|
import discord
|
||||||
import logging
|
import logging
|
||||||
from discord import Message, TextChannel
|
from discord import Message, TextChannel
|
||||||
@ -28,7 +28,7 @@ class FjerkroaBot(commands.Bot):
|
|||||||
|
|
||||||
self.observer = Observer()
|
self.observer = Observer()
|
||||||
self.file_handler = ConfigFileHandler(self.on_config_file_modified)
|
self.file_handler = ConfigFileHandler(self.on_config_file_modified)
|
||||||
self.observer.schedule(self.file_handler, path=".", recursive=False)
|
self.observer.schedule(self.file_handler, path=config_file, recursive=False)
|
||||||
self.observer.start()
|
self.observer.start()
|
||||||
|
|
||||||
self.airesponder = AIResponder(self.config)
|
self.airesponder = AIResponder(self.config)
|
||||||
@ -36,9 +36,9 @@ class FjerkroaBot(commands.Bot):
|
|||||||
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
|
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_config(self, config_file: str = "config.json"):
|
def load_config(self, config_file: str = "config.toml"):
|
||||||
with open(config_file, "r") as file:
|
with open(config_file, encoding='utf-8') as file:
|
||||||
return json.load(file)
|
return toml.load(file)
|
||||||
|
|
||||||
def on_config_file_modified(self, event):
|
def on_config_file_modified(self, event):
|
||||||
if event.src_path == self.config_file:
|
if event.src_path == self.config_file:
|
||||||
@ -57,23 +57,28 @@ class FjerkroaBot(commands.Bot):
|
|||||||
|
|
||||||
async def on_member_join(self, member):
|
async def on_member_join(self, member):
|
||||||
logging.info(f"User {member.name} joined")
|
logging.info(f"User {member.name} joined")
|
||||||
msg = AIMessage(member.name, self.config['join-message'])
|
msg = AIMessage(member.name, self.config['join-message'].replace('{name}', member.name))
|
||||||
if self.welcome_channel is not None:
|
if self.welcome_channel is not None:
|
||||||
await self.respond(msg, self.welcome_channel)
|
await self.respond(msg, self.welcome_channel)
|
||||||
|
|
||||||
async def on_message(self, message: Message) -> None:
|
async def on_message(self, message: Message) -> None:
|
||||||
|
if self.user is not None and message.author.id == self.user.id:
|
||||||
|
return
|
||||||
msg = AIMessage(message.author.name, str(message.content).strip())
|
msg = AIMessage(message.author.name, str(message.content).strip())
|
||||||
await self.respond(msg, message.channel)
|
await self.respond(msg, message.channel)
|
||||||
|
|
||||||
async def respond(self, message: AIMessage, channel: TextChannel) -> None:
|
async def respond(self, message: AIMessage, channel: TextChannel) -> None:
|
||||||
logging.info(f"handle message {str(message)} for channel {channel.name}")
|
logging.info(f"handle message {str(message)} for channel {channel.name}")
|
||||||
|
async with channel.typing():
|
||||||
response = await self.airesponder.send(message)
|
response = await self.airesponder.send(message)
|
||||||
if response.staff is not None and self.staff_channel is not None:
|
if response.staff is not None and self.staff_channel is not None:
|
||||||
async with self.staff_channel.typing():
|
async with self.staff_channel.typing():
|
||||||
await self.staff_channel.send(response.staff)
|
await self.staff_channel.send(response.staff)
|
||||||
if not response.answer_needed:
|
if not response.answer_needed:
|
||||||
return
|
return
|
||||||
async with channel.typing():
|
if response.hack:
|
||||||
|
logging.warning(f"User {message.user} tried to hack the system.")
|
||||||
|
return
|
||||||
if response.picture is not None:
|
if response.picture is not None:
|
||||||
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
|
images = [discord.File(fp=await self.airesponder.draw(response.picture), filename="image.png")]
|
||||||
await channel.send(response.answer, files=images)
|
await channel.send(response.answer, files=images)
|
||||||
@ -89,12 +94,12 @@ def main() -> int:
|
|||||||
from .bot_logging import setup_logging
|
from .bot_logging import setup_logging
|
||||||
setup_logging()
|
setup_logging()
|
||||||
parser = argparse.ArgumentParser(description='Fjerkroa AI bot')
|
parser = argparse.ArgumentParser(description='Fjerkroa AI bot')
|
||||||
parser.add_argument('--config', type=str, default='config.json', help='Config file.')
|
parser.add_argument('--config', type=str, default='config.toml', help='Config file.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config = FjerkroaBot.load_config(args.config)
|
config = FjerkroaBot.load_config(args.config)
|
||||||
bot = FjerkroaBot(args.config)
|
bot = FjerkroaBot(args.config)
|
||||||
bot.run(config["discord_token"])
|
bot.run(config["discord-token"])
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
@ -8,3 +8,5 @@ pytest
|
|||||||
setuptools
|
setuptools
|
||||||
wheel
|
wheel
|
||||||
watchdog
|
watchdog
|
||||||
|
toml
|
||||||
|
types-toml
|
||||||
|
|||||||
@ -8,7 +8,7 @@ class TestAIResponder(TestBotBase):
|
|||||||
async def asyncSetUp(self):
|
async def asyncSetUp(self):
|
||||||
await super().asyncSetUp()
|
await super().asyncSetUp()
|
||||||
self.system = r"""
|
self.system = r"""
|
||||||
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you.
|
You are a female AI named Fjærkroa, a cafe on the island Sleneset in Luroy kommune in Norway, with Address Straumen 6, 8762 Sleneset, Norway. Fjærkroa also offers rooms for rent. You chat with several people in different languages. Please welcome everyone to Fjærkroa. Here multiple users speak with each other and sometimes with you. Current date is {date} and time is {time}.
|
||||||
|
|
||||||
Every message from users is a dictionary in JSON format with the following fields:
|
Every message from users is a dictionary in JSON format with the following fields:
|
||||||
1. `user`: name of the user who wrote the message.
|
1. `user`: name of the user who wrote the message.
|
||||||
@ -38,9 +38,19 @@ You always try to say something positive about the current day and the Fjærkroa
|
|||||||
|
|
||||||
async def test_responder1(self) -> None:
|
async def test_responder1(self) -> None:
|
||||||
response = await self.bot.airesponder.send(AIMessage("lala", "who are you?"))
|
response = await self.bot.airesponder.send(AIMessage("lala", "who are you?"))
|
||||||
print(response)
|
print(f"\n{response}")
|
||||||
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
|
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
|
||||||
|
|
||||||
|
async def test_history(self) -> None:
|
||||||
|
self.bot.airesponder.history = []
|
||||||
|
response = await self.bot.airesponder.send(AIMessage("lala", "which date is today?"))
|
||||||
|
print(f"\n{response}")
|
||||||
|
self.assertAIResponse(response, AIResponse('test', True, None, None, False))
|
||||||
|
response = await self.bot.airesponder.send(AIMessage("lala", "can I have an espresso please?"))
|
||||||
|
print(f"\n{response}")
|
||||||
|
self.assertAIResponse(response, AIResponse('test', True, 'something', None, False), scmp=lambda a, b: type(a) == str and len(a) > 5)
|
||||||
|
print(f"\n{self.bot.airesponder.history}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__mait__":
|
if __name__ == "__mait__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import unittest
|
|||||||
import pytest
|
import pytest
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import json
|
import json
|
||||||
|
import toml
|
||||||
import openai
|
import openai
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
from unittest.mock import Mock, PropertyMock, MagicMock, AsyncMock, patch, mock_open, ANY
|
||||||
@ -18,7 +19,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
Mock(text="Nice day today!")
|
Mock(text="Nice day today!")
|
||||||
]
|
]
|
||||||
self.config_data = {
|
self.config_data = {
|
||||||
"openai_token": os.environ['OPENAI_TOKEN'],
|
"openai-token": os.environ['OPENAI_TOKEN'],
|
||||||
"model": "gpt-4",
|
"model": "gpt-4",
|
||||||
"max_tokens": 1024,
|
"max_tokens": 1024,
|
||||||
"temperature": 0.9,
|
"temperature": 0.9,
|
||||||
@ -55,7 +56,7 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
class TestFunctionality(TestBotBase):
|
class TestFunctionality(TestBotBase):
|
||||||
|
|
||||||
def test_load_config(self) -> None:
|
def test_load_config(self) -> None:
|
||||||
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
|
with patch('builtins.open', mock_open(read_data=toml.dumps(self.config_data))):
|
||||||
result = FjerkroaBot.load_config('config.json')
|
result = FjerkroaBot.load_config('config.json')
|
||||||
self.assertEqual(result, self.config_data)
|
self.assertEqual(result, self.config_data)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user