From b23a9f663f3aa9cedadb7700b3f7332d449c5529 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 7 Apr 2023 17:40:24 +0200 Subject: [PATCH] Splain main.py to several files. --- README.md | 29 +++--- chatmastermind/api_client.py | 24 +++++ chatmastermind/main.py | 172 +++++++---------------------------- chatmastermind/storage.py | 57 ++++++++++++ chatmastermind/utils.py | 65 +++++++++++++ tests/test_main.py | 56 ++++++++---- 6 files changed, 228 insertions(+), 175 deletions(-) create mode 100644 chatmastermind/api_client.py create mode 100644 chatmastermind/storage.py create mode 100644 chatmastermind/utils.py diff --git a/README.md b/README.md index 13142a8..57067d2 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # ChatMastermind -ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes a relevant chat history for the next question. +ChatMastermind is a Python application that automates conversation with AI, stores question-answer pairs with tags, and composes relevant chat history for the next question. -The project uses the OpenAI API to generate responses, and stores the data in YAML files. It also allows you to filter the chat history based on tags, and supports autocompletion for tags. +The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags. ## Requirements @@ -13,7 +13,7 @@ The project uses the OpenAI API to generate responses, and stores the data in YA You can install these requirements using `pip`: -``` +```bash pip install -r requirements.txt ``` @@ -21,13 +21,13 @@ pip install -r requirements.txt You can install the package with the requirements using `pip`: -``` +```bash pip install . ``` ## Usage -``` +```bash cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]] ``` @@ -50,37 +50,37 @@ cmm [-h] [-p PRINT | -q QUESTION | -D | -d] [-c CONFIG] [-m MAX_TOKENS] [-T TEMP 1. Print the contents of a YAML file: -``` +```bash cmm -p example.yaml ``` 2. Ask a question: -``` +```bash cmm -q "What is the meaning of life?" -t philosophy -e religion ``` 3. Display the chat history as a Python structure: -``` +```bash cmm -D ``` 4. Display the chat history as readable text: -``` +```bash cmm -d ``` 5. Filter chat history by tags: -``` +```bash cmm -d -t tag1 tag2 ``` 6. Exclude chat history by tags: -``` +```bash cmm -d -e tag3 tag4 ``` @@ -103,13 +103,12 @@ The configuration file (`.config.yaml`) should contain the following fields: To activate autocompletion for tags, add the following line to your shell's configuration file (e.g., `.bashrc`, `.zshrc`, or `.profile`): -``` +```bash eval "$(register-python-argcomplete cmm)" ``` -After adding this line, restart your shell or run `source ` to enable autocompletion for the `chatmastermind` script. +After adding this line, restart your shell or run `source ` to enable autocompletion for the `cmm` script. ## License -This project is licensed under the terms of the WTFPL License. - +This project is licensed under the terms of the WTFPL License. \ No newline at end of file diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py new file mode 100644 index 0000000..2ff8c59 --- /dev/null +++ b/chatmastermind/api_client.py @@ -0,0 +1,24 @@ +import openai + + +def openai_api_key(api_key: str) -> None: + openai.api_key = api_key + + +def ai(chat: list[dict[str, str]], + config: dict, + number: int + ) -> tuple[list[str], dict[str, int]]: + response = openai.ChatCompletion.create( + model=config['openai']['model'], + messages=chat, + temperature=config['openai']['temperature'], + max_tokens=config['openai']['max_tokens'], + top_p=config['openai']['top_p'], + n=number, + frequency_penalty=config['openai']['frequency_penalty'], + presence_penalty=config['openai']['presence_penalty']) + result = [] + for choice in response['choices']: # type: ignore + result.append(choice['message']['content'].strip()) + return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 33b1114..e3ddda6 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -3,19 +3,12 @@ # vim: set fileencoding=utf-8 : import yaml -import io import sys -import shutil -import openai -import pathlib import argcomplete import argparse -from pprint import PrettyPrinter -from typing import List, Dict, Any, Optional - -terminal_size = shutil.get_terminal_size() -terminal_width = terminal_size.columns -pp = PrettyPrinter(width=terminal_width).pprint +from .utils import terminal_width, pp, tags_completer, process_tags, display_chat +from .storage import save_answers, create_chat +from .api_client import ai, openai_api_key def run_print_command(args: argparse.Namespace, config: dict) -> None: @@ -24,143 +17,56 @@ def run_print_command(args: argparse.Namespace, config: dict) -> None: pp(data) -def process_tags(config: dict, tags: list, extags: list) -> None: - print(f"Tags: {', '.join(tags)}") - if len(extags) > 0: - print(f"Excluding tags: {', '.join(extags)}") - print() - - -def append_message(chat: List[Dict[str, str]], - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: Dict[str, str], - chat: List[Dict[str, str]] - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - - -def create_chat(question: Optional[str], - tags: Optional[List[str]], - extags: Optional[List[str]], - config: Dict[str, Any] - ) -> List[Dict[str, str]]: - chat = [] - append_message(chat, 'system', config['system'].strip()) - for file in sorted(pathlib.Path(config['db']).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data_tags = set(data.get('tags', [])) - tags_match = \ - not tags or data_tags.intersection(tags) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat) - if question: - append_message(chat, 'user', question) - return chat - - -def ai(chat: list[dict[str, str]], - config: dict, - number: int - ) -> tuple[list[str], dict[str, int]]: - response = openai.ChatCompletion.create( - model=config['openai']['model'], - messages=chat, - temperature=config['openai']['temperature'], - max_tokens=config['openai']['max_tokens'], - top_p=config['openai']['top_p'], - n=number, - frequency_penalty=config['openai']['frequency_penalty'], - presence_penalty=config['openai']['presence_penalty']) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore - - def process_and_display_chat(args: argparse.Namespace, config: dict, dump: bool = False - ) -> tuple[list[dict[str, str]], list[str]]: + ) -> tuple[list[dict[str, str]], str, list[str]]: tags = args.tags or [] extags = args.extags or [] process_tags(config, tags, extags) - chat = create_chat(args.question, tags, extags, config) + + question_parts = [] + question_list = args.question if args.question is not None else [] + source_list = args.source if args.source is not None else [] + + for question, source in zip(question_list, source_list): + with open(source) as r: + question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") + + if len(question_list) > len(source_list): + for question in question_list[len(source_list):]: + question_parts.append(question) + else: + for source in source_list[len(question_list):]: + with open(source) as r: + question_parts.append(f"```\n{r.read().strip()}\n```") + + question = '\n\n'.join(question_parts) + + chat = create_chat(question, tags, extags, config) display_chat(chat, dump) - return chat, tags - - -def display_chat(chat, dump=False) -> None: - if dump: - pp(chat) - return - for message in chat: - if message['role'] == 'user': - print('-' * terminal_width) - if len(message['content']) > terminal_width-len(message['role'])-2: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") + return chat, question, tags def handle_question(args: argparse.Namespace, config: dict, dump: bool = False ) -> None: - chat, tags = process_and_display_chat(args, config, dump) + chat, question, tags = process_and_display_chat(args, config, dump) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) - save_answers(args.question, answers, tags, otags) - print("-" * terminal_width) + save_answers(question, answers, tags, otags) + print("-" * terminal_width()) print(f"Usage: {usage}") -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]] - ) -> None: - wtags = otags or tags - for num, answer in enumerate(answers, start=1): - title = f'-- ANSWER {num} ' - title_end = '-' * (terminal_width - len(title)) - print(f'{title}{title_end}') - print(answer) - with open(f"{num:02d}.yaml", "w") as fd: - with io.StringIO() as f: - yaml.dump({'question': question}, - f, - default_style="|", - default_flow_style=False) - fd.write(f.getvalue().replace('"question":', "question:", 1)) - with io.StringIO() as f: - yaml.dump({'answer': answer}, - f, - default_style="|", - default_flow_style=False) - fd.write(f.getvalue().replace('"answer":', "answer:", 1)) - yaml.dump({'tags': wtags}, - fd, - default_flow_style=False) - - def create_parser() -> argparse.ArgumentParser: default_config = '.config.yaml' parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") group = parser.add_mutually_exclusive_group(required=True) group.add_argument('-p', '--print', help='YAML file to print') - group.add_argument('-q', '--question', help='Question to ask') + group.add_argument('-q', '--question', nargs='*', help='Question to ask') group.add_argument('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') group.add_argument('-d', '--chat', help="Print chat as readable text", action='store_true') parser.add_argument('-c', '--config', help='Config file name.', default=default_config) @@ -168,6 +74,7 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) parser.add_argument('-M', '--model', help='Model to use') parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=3) + parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS') tags_arg.completer = tags_completer # type: ignore extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS') @@ -185,7 +92,7 @@ def main() -> int: with open(args.config, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) - openai.api_key = config['openai']['api_key'] + openai_api_key(config['openai']['api_key']) if args.max_tokens: config['openai']['max_tokens'] = args.max_tokens @@ -208,22 +115,5 @@ def main() -> int: return 0 -def tags_completer(prefix, parsed_args, **kwargs): - with open(parsed_args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - result = [] - for file in sorted(pathlib.Path(config['db']).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return list(set(result)) - - if __name__ == '__main__': sys.exit(main()) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py new file mode 100644 index 0000000..7b7e17b --- /dev/null +++ b/chatmastermind/storage.py @@ -0,0 +1,57 @@ +import yaml +import io +import pathlib +from .utils import terminal_width, append_message, message_to_chat +from typing import List, Dict, Any, Optional + + +def save_answers(question: str, + answers: list[str], + tags: list[str], + otags: Optional[list[str]] + ) -> None: + wtags = otags or tags + for num, answer in enumerate(answers, start=1): + title = f'-- ANSWER {num} ' + title_end = '-' * (terminal_width() - len(title)) + print(f'{title}{title_end}') + print(answer) + with open(f"{num:02d}.yaml", "w") as fd: + with io.StringIO() as f: + yaml.dump({'question': question}, + f, + default_style="|", + default_flow_style=False) + fd.write(f.getvalue().replace('"question":', "question:", 1)) + with io.StringIO() as f: + yaml.dump({'answer': answer}, + f, + default_style="|", + default_flow_style=False) + fd.write(f.getvalue().replace('"answer":', "answer:", 1)) + yaml.dump({'tags': wtags}, + fd, + default_flow_style=False) + + +def create_chat(question: Optional[str], + tags: Optional[List[str]], + extags: Optional[List[str]], + config: Dict[str, Any] + ) -> List[Dict[str, str]]: + chat = [] + append_message(chat, 'system', config['system'].strip()) + for file in sorted(pathlib.Path(config['db']).iterdir()): + if file.suffix == '.yaml': + with open(file, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + data_tags = set(data.get('tags', [])) + tags_match = \ + not tags or data_tags.intersection(tags) + extags_do_not_match = \ + not extags or not data_tags.intersection(extags) + if tags_match and extags_do_not_match: + message_to_chat(data, chat) + if question: + append_message(chat, 'user', question) + return chat diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py new file mode 100644 index 0000000..3db408b --- /dev/null +++ b/chatmastermind/utils.py @@ -0,0 +1,65 @@ +import shutil +import yaml +import pathlib +from pprint import PrettyPrinter +from typing import List, Dict + + +def terminal_width() -> int: + return shutil.get_terminal_size().columns + + +def pp(*args, **kwargs) -> None: + return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) + + +def process_tags(config: dict, tags: list, extags: list) -> None: + print(f"Tags: {', '.join(tags)}") + if len(extags) > 0: + print(f"Excluding tags: {', '.join(extags)}") + print() + + +def append_message(chat: List[Dict[str, str]], + role: str, + content: str + ) -> None: + chat.append({'role': role, 'content': content.replace("''", "'")}) + + +def message_to_chat(message: Dict[str, str], + chat: List[Dict[str, str]] + ) -> None: + append_message(chat, 'user', message['question']) + append_message(chat, 'assistant', message['answer']) + + +def display_chat(chat, dump=False) -> None: + if dump: + pp(chat) + return + for message in chat: + if message['role'] == 'user': + print('-' * (terminal_width())) + if len(message['content']) > terminal_width() - len(message['role']) - 2: + print(f"{message['role'].upper()}:") + print(message['content']) + else: + print(f"{message['role'].upper()}: {message['content']}") + + +def tags_completer(prefix, parsed_args, **kwargs): + with open(parsed_args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + result = [] + for file in sorted(pathlib.Path(config['db']).iterdir()): + if file.suffix == '.yaml': + with open(file, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + for tag in data.get('tags', []): + if prefix and len(prefix) > 0: + if tag.startswith(prefix): + result.append(tag) + else: + result.append(tag) + return list(set(result)) diff --git a/tests/test_main.py b/tests/test_main.py index 95b0ef2..19386b2 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,10 +3,12 @@ import io import os import yaml import argparse -import chatmastermind.main -from chatmastermind.main import create_chat, ai, handle_question, save_answers +from chatmastermind.utils import terminal_width +from chatmastermind.main import create_parser, handle_question +from chatmastermind.api_client import ai +from chatmastermind.storage import create_chat, save_answers from unittest import mock -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, Mock class TestCreateChat(unittest.TestCase): @@ -86,11 +88,13 @@ class TestCreateChat(unittest.TestCase): class TestHandleQuestion(unittest.TestCase): def setUp(self): + self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], extags=['extag1'], output_tags=None, - question='test question', + question=[self.question], + source=None, number=3 ) self.config = { @@ -100,20 +104,19 @@ class TestHandleQuestion(unittest.TestCase): @patch("chatmastermind.main.create_chat", return_value="test_chat") @patch("chatmastermind.main.process_tags") - @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], - "test_usage")) - @patch("chatmastermind.main.pp") - @patch("chatmastermind.main.print") - @patch("chatmastermind.main.yaml.dump") + @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) + @patch("chatmastermind.utils.pp") + @patch("builtins.print") + @patch("chatmastermind.storage.yaml.dump") def test_handle_question(self, _, mock_print, mock_pp, mock_ai, mock_process_tags, mock_create_chat): open_mock = MagicMock() - with patch("chatmastermind.main.open", open_mock): + with patch("chatmastermind.storage.open", open_mock): handle_question(self.args, self.config, True) mock_process_tags.assert_called_once_with(self.config, self.args.tags, self.args.extags) - mock_create_chat.assert_called_once_with(self.args.question, + mock_create_chat.assert_called_once_with(self.question, self.args.tags, self.args.extags, self.config) @@ -124,15 +127,14 @@ class TestHandleQuestion(unittest.TestCase): expected_calls = [] for num, answer in enumerate(mock_ai.return_value[0], start=1): title = f'-- ANSWER {num} ' - title_end = '-' * (chatmastermind.main.terminal_width - len(title)) + title_end = '-' * (terminal_width() - len(title)) expected_calls.append(((f'{title}{title_end}',),)) expected_calls.append(((answer,),)) - expected_calls.append((("-" * chatmastermind.main.terminal_width,),)) + expected_calls.append((("-" * terminal_width(),),)) expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - open_mock.assert_has_calls([ - mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4) - ] + [mock.call().__enter__(), - mock.call().__exit__(None, None, None)] * 3, + open_mock.assert_has_calls( + [mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)] + [ + mock.call().__enter__(), mock.call().__exit__(None, None, None)] * 3, any_order=True) self.assertEqual(mock_print.call_args_list, expected_calls) @@ -152,9 +154,9 @@ class TestSaveAnswers(unittest.TestCase): def test_save_answers(self): try: - self.assert_stdout(f"-- ANSWER 1 {'-'*(chatmastermind.main.terminal_width-12)}\n" + self.assert_stdout(f"-- ANSWER 1 {'-'*(terminal_width()-12)}\n" "AI is Artificial Intelligence\n" - f"-- ANSWER 2 {'-'*(chatmastermind.main.terminal_width-12)}\n" + f"-- ANSWER 2 {'-'*(terminal_width()-12)}\n" "AI is a simulation of human intelligence\n") for idx, answer in enumerate(self.answers, start=1): with open(f"{idx:02d}.yaml", "r") as file: @@ -198,3 +200,19 @@ class TestAI(unittest.TestCase): expected_result = (['response_text_1', 'response_text_2'], {'tokens': 10}) self.assertEqual(result, expected_result) + + +class TestCreateParser(unittest.TestCase): + def test_create_parser(self): + with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group: + mock_group = Mock() + mock_add_mutually_exclusive_group.return_value = mock_group + parser = create_parser() + self.assertIsInstance(parser, argparse.ArgumentParser) + mock_add_mutually_exclusive_group.assert_called_once_with(required=True) + mock_group.add_argument.assert_any_call('-p', '--print', help='YAML file to print') + mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask') + mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') + mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true') + self.assertTrue('.config.yaml' in parser.get_default('config')) + self.assertEqual(parser.get_default('number'), 3)