Add ai_responder, some more stuff.

This commit is contained in:
Fjerkroa Auto 2023-03-21 20:43:46 +01:00
parent 5054cac49b
commit c0b51f947c
3 changed files with 110 additions and 26 deletions

View File

@ -0,0 +1,66 @@
import json
import openai
from typing import Optional, List, Dict, Any, Tuple
class AIMessageBase(object):
def __init__(self) -> None:
pass
def __str__(self) -> str:
return json.dumps(vars(self))
class AIMessage(AIMessageBase):
def __init__(self, user: str, message: str) -> None:
self.user = user
self.message = message
class AIResponse(AIMessageBase):
def __init__(self, answer: str, answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None:
self.answer = answer
self.answer_needed = answer_needed
self.staff = staff
self.picture = picture
self.hack = hack
class AIResponder(object):
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.history: List[Dict[str, Any]] = []
def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
messages = []
messages.append({"role": "system", "content": self.config["system"]})
if limit is None:
history = self.history[:]
else:
history = self.history[-limit:]
history.append({"role": "user", "content": str(message)})
for msg in history:
messages.append(msg)
return messages, history
async def send(self, message: AIMessage) -> AIResponse:
limit = self.config["history-limit"]
while True:
messages, history = self._message(message, limit)
try:
result = await openai.ChatCompletion.acreate(model=self.config["model"],
messages=messages,
temperature=self.config["temperature"],
top_p=self.config["top-p"],
presence_penalty=self.config["presence-penalty"],
frequency_penalty=self.config["frequency-penalty"])
except openai.error.InvalidRequestError as err:
if 'maximum context length is' in str(err) and limit > 2:
limit -= 1
continue
raise err
answer = result['choices'][0]['message']
response = json.loads(answer['content'])
history.append(answer)
self.history = history
return response

View File

@ -6,6 +6,7 @@ from discord import Message
from discord.ext import commands
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from .ai_responder import AIResponder, AIMessage
class ConfigFileHandler(FileSystemEventHandler):
@ -18,11 +19,6 @@ class ConfigFileHandler(FileSystemEventHandler):
class FjerkroaBot(commands.Bot):
def __init__(self, config_file: str):
assert not hasattr(self, 'config_file')
assert not hasattr(self, 'config')
assert not hasattr(self, 'observer')
assert not hasattr(self, 'file_handler')
self.config_file = config_file
self.config = self.load_config(self.config_file)
intents = discord.Intents.default()
@ -34,6 +30,8 @@ class FjerkroaBot(commands.Bot):
self.observer.schedule(self.file_handler, path=".", recursive=False)
self.observer.start()
self.airesponder = AIResponder(self.config)
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
@classmethod
@ -44,12 +42,28 @@ class FjerkroaBot(commands.Bot):
def on_config_file_modified(self, event):
if event.src_path == self.config_file:
self.config = self.load_config(self.config_file)
self.airesponder.config = self.config
async def on_ready(self):
print(f"We have logged in as {self.user}")
self.staff_channel = None
self.welcome_channel = None
for guild in self.guilds:
if self.staff_channel is None:
self.staff_channel = discord.utils.get(guild.channels, name=self.config['staff-channel'])
if self.welcome_channel is None:
self.welcome_channel = discord.utils.get(guild.channels, name=self.config['welcome-channel'])
async def on_message(self, message: Message) -> None:
await message.channel.send("Hello!")
msg = AIMessage(message.author.name, str(message.content).strip())
response = await self.airesponder.send(msg)
if response.staff is not None:
async with self.staff_channel.typing():
self.staff_channel.send(response.staff)
if not response.answer_needed:
return
async with message.channel.typing():
message.channel.send(response.answer)
async def close(self):
self.observer.stop()

View File

@ -20,21 +20,13 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
"temperature": 0.9,
}
self.history_data = []
class TestFunctionality(TestBotBase):
def test_load_config(self):
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
result = FjerkroaBot.load_config('config.json')
self.assertEqual(result, self.config_data)
async def test_on_message_event(self):
with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \
patch.object(FjerkroaBot, 'user', new_callable=PropertyMock) as mock_user:
mock_user.return_value = MagicMock(spec=User)
mock_user.return_value.id = 12
bot = FjerkroaBot('config.json')
self.bot = FjerkroaBot('config.json')
def create_message(self, message: str) -> Message:
message = MagicMock(spec=Message)
message.content = "Hello, how are you?"
message.author = AsyncMock(spec=User)
@ -42,7 +34,19 @@ class TestFunctionality(TestBotBase):
message.author.bot = False
message.channel = AsyncMock(spec=TextChannel)
message.channel.send = AsyncMock()
await bot.on_message(message)
return message
class TestFunctionality(TestBotBase):
def test_load_config(self) -> None:
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
result = FjerkroaBot.load_config('config.json')
self.assertEqual(result, self.config_data)
async def test_on_message_event(self) -> None:
message = self.create_message("Hello there! How are you?")
await self.bot.on_message(message)
message.channel.send.assert_called_once_with("Hello!")