Add ai_responder, some more stuff.
This commit is contained in:
parent
5054cac49b
commit
c0b51f947c
66
fjerkroa_bot/ai_responder.py
Normal file
66
fjerkroa_bot/ai_responder.py
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
import json
|
||||||
|
import openai
|
||||||
|
from typing import Optional, List, Dict, Any, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class AIMessageBase(object):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return json.dumps(vars(self))
|
||||||
|
|
||||||
|
|
||||||
|
class AIMessage(AIMessageBase):
|
||||||
|
def __init__(self, user: str, message: str) -> None:
|
||||||
|
self.user = user
|
||||||
|
self.message = message
|
||||||
|
|
||||||
|
|
||||||
|
class AIResponse(AIMessageBase):
|
||||||
|
def __init__(self, answer: str, answer_needed: bool, staff: Optional[str], picture: Optional[str], hack: bool) -> None:
|
||||||
|
self.answer = answer
|
||||||
|
self.answer_needed = answer_needed
|
||||||
|
self.staff = staff
|
||||||
|
self.picture = picture
|
||||||
|
self.hack = hack
|
||||||
|
|
||||||
|
|
||||||
|
class AIResponder(object):
|
||||||
|
def __init__(self, config: Dict[str, Any]) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.history: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _message(self, message: AIMessage, limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
|
messages = []
|
||||||
|
messages.append({"role": "system", "content": self.config["system"]})
|
||||||
|
if limit is None:
|
||||||
|
history = self.history[:]
|
||||||
|
else:
|
||||||
|
history = self.history[-limit:]
|
||||||
|
history.append({"role": "user", "content": str(message)})
|
||||||
|
for msg in history:
|
||||||
|
messages.append(msg)
|
||||||
|
return messages, history
|
||||||
|
|
||||||
|
async def send(self, message: AIMessage) -> AIResponse:
|
||||||
|
limit = self.config["history-limit"]
|
||||||
|
while True:
|
||||||
|
messages, history = self._message(message, limit)
|
||||||
|
try:
|
||||||
|
result = await openai.ChatCompletion.acreate(model=self.config["model"],
|
||||||
|
messages=messages,
|
||||||
|
temperature=self.config["temperature"],
|
||||||
|
top_p=self.config["top-p"],
|
||||||
|
presence_penalty=self.config["presence-penalty"],
|
||||||
|
frequency_penalty=self.config["frequency-penalty"])
|
||||||
|
except openai.error.InvalidRequestError as err:
|
||||||
|
if 'maximum context length is' in str(err) and limit > 2:
|
||||||
|
limit -= 1
|
||||||
|
continue
|
||||||
|
raise err
|
||||||
|
answer = result['choices'][0]['message']
|
||||||
|
response = json.loads(answer['content'])
|
||||||
|
history.append(answer)
|
||||||
|
self.history = history
|
||||||
|
return response
|
||||||
@ -6,6 +6,7 @@ from discord import Message
|
|||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from watchdog.observers import Observer
|
from watchdog.observers import Observer
|
||||||
from watchdog.events import FileSystemEventHandler
|
from watchdog.events import FileSystemEventHandler
|
||||||
|
from .ai_responder import AIResponder, AIMessage
|
||||||
|
|
||||||
|
|
||||||
class ConfigFileHandler(FileSystemEventHandler):
|
class ConfigFileHandler(FileSystemEventHandler):
|
||||||
@ -18,11 +19,6 @@ class ConfigFileHandler(FileSystemEventHandler):
|
|||||||
|
|
||||||
class FjerkroaBot(commands.Bot):
|
class FjerkroaBot(commands.Bot):
|
||||||
def __init__(self, config_file: str):
|
def __init__(self, config_file: str):
|
||||||
assert not hasattr(self, 'config_file')
|
|
||||||
assert not hasattr(self, 'config')
|
|
||||||
assert not hasattr(self, 'observer')
|
|
||||||
assert not hasattr(self, 'file_handler')
|
|
||||||
|
|
||||||
self.config_file = config_file
|
self.config_file = config_file
|
||||||
self.config = self.load_config(self.config_file)
|
self.config = self.load_config(self.config_file)
|
||||||
intents = discord.Intents.default()
|
intents = discord.Intents.default()
|
||||||
@ -34,6 +30,8 @@ class FjerkroaBot(commands.Bot):
|
|||||||
self.observer.schedule(self.file_handler, path=".", recursive=False)
|
self.observer.schedule(self.file_handler, path=".", recursive=False)
|
||||||
self.observer.start()
|
self.observer.start()
|
||||||
|
|
||||||
|
self.airesponder = AIResponder(self.config)
|
||||||
|
|
||||||
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
|
super().__init__(command_prefix="!", case_insensitive=True, intents=intents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -44,12 +42,28 @@ class FjerkroaBot(commands.Bot):
|
|||||||
def on_config_file_modified(self, event):
|
def on_config_file_modified(self, event):
|
||||||
if event.src_path == self.config_file:
|
if event.src_path == self.config_file:
|
||||||
self.config = self.load_config(self.config_file)
|
self.config = self.load_config(self.config_file)
|
||||||
|
self.airesponder.config = self.config
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
print(f"We have logged in as {self.user}")
|
print(f"We have logged in as {self.user}")
|
||||||
|
self.staff_channel = None
|
||||||
|
self.welcome_channel = None
|
||||||
|
for guild in self.guilds:
|
||||||
|
if self.staff_channel is None:
|
||||||
|
self.staff_channel = discord.utils.get(guild.channels, name=self.config['staff-channel'])
|
||||||
|
if self.welcome_channel is None:
|
||||||
|
self.welcome_channel = discord.utils.get(guild.channels, name=self.config['welcome-channel'])
|
||||||
|
|
||||||
async def on_message(self, message: Message) -> None:
|
async def on_message(self, message: Message) -> None:
|
||||||
await message.channel.send("Hello!")
|
msg = AIMessage(message.author.name, str(message.content).strip())
|
||||||
|
response = await self.airesponder.send(msg)
|
||||||
|
if response.staff is not None:
|
||||||
|
async with self.staff_channel.typing():
|
||||||
|
self.staff_channel.send(response.staff)
|
||||||
|
if not response.answer_needed:
|
||||||
|
return
|
||||||
|
async with message.channel.typing():
|
||||||
|
message.channel.send(response.answer)
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
self.observer.stop()
|
self.observer.stop()
|
||||||
|
|||||||
@ -20,30 +20,34 @@ class TestBotBase(unittest.IsolatedAsyncioTestCase):
|
|||||||
"temperature": 0.9,
|
"temperature": 0.9,
|
||||||
}
|
}
|
||||||
self.history_data = []
|
self.history_data = []
|
||||||
|
|
||||||
|
|
||||||
class TestFunctionality(TestBotBase):
|
|
||||||
|
|
||||||
def test_load_config(self):
|
|
||||||
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
|
|
||||||
result = FjerkroaBot.load_config('config.json')
|
|
||||||
self.assertEqual(result, self.config_data)
|
|
||||||
|
|
||||||
async def test_on_message_event(self):
|
|
||||||
with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \
|
with patch.object(FjerkroaBot, 'load_config', new=lambda s, c: self.config_data), \
|
||||||
patch.object(FjerkroaBot, 'user', new_callable=PropertyMock) as mock_user:
|
patch.object(FjerkroaBot, 'user', new_callable=PropertyMock) as mock_user:
|
||||||
mock_user.return_value = MagicMock(spec=User)
|
mock_user.return_value = MagicMock(spec=User)
|
||||||
mock_user.return_value.id = 12
|
mock_user.return_value.id = 12
|
||||||
bot = FjerkroaBot('config.json')
|
self.bot = FjerkroaBot('config.json')
|
||||||
message = MagicMock(spec=Message)
|
|
||||||
message.content = "Hello, how are you?"
|
def create_message(self, message: str) -> Message:
|
||||||
message.author = AsyncMock(spec=User)
|
message = MagicMock(spec=Message)
|
||||||
message.author.id = 123
|
message.content = "Hello, how are you?"
|
||||||
message.author.bot = False
|
message.author = AsyncMock(spec=User)
|
||||||
message.channel = AsyncMock(spec=TextChannel)
|
message.author.id = 123
|
||||||
message.channel.send = AsyncMock()
|
message.author.bot = False
|
||||||
await bot.on_message(message)
|
message.channel = AsyncMock(spec=TextChannel)
|
||||||
message.channel.send.assert_called_once_with("Hello!")
|
message.channel.send = AsyncMock()
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
class TestFunctionality(TestBotBase):
|
||||||
|
|
||||||
|
def test_load_config(self) -> None:
|
||||||
|
with patch('builtins.open', mock_open(read_data=json.dumps(self.config_data))):
|
||||||
|
result = FjerkroaBot.load_config('config.json')
|
||||||
|
self.assertEqual(result, self.config_data)
|
||||||
|
|
||||||
|
async def test_on_message_event(self) -> None:
|
||||||
|
message = self.create_message("Hello there! How are you?")
|
||||||
|
await self.bot.on_message(message)
|
||||||
|
message.channel.send.assert_called_once_with("Hello!")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__mait__":
|
if __name__ == "__mait__":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user