From bb8aa2f81714d9458a32342c177373395c6b7ae6 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 12:36:04 +0200 Subject: [PATCH 001/170] Fix read_file declaration. --- chatmastermind/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index 4215e5a..44b21fc 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -5,7 +5,7 @@ from .utils import terminal_width, append_message, message_to_chat from typing import List, Dict, Any, Optional -def read_file(fname: str, tags_only: bool = False) -> Dict[str, Any]: +def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: with open(fname, "r") as fd: if tags_only: return {"tags": [x.strip() for x in fd.readline().strip().split(':')[1].strip().split(',')]} -- 2.36.6 From 01de75bef3ba02551f7ef4fcc31245cecfb34289 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 12:39:46 +0200 Subject: [PATCH 002/170] Improve -l output. --- chatmastermind/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 2dceb69..387db7b 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -74,9 +74,10 @@ def display_chat(chat, dump=False, source_code=False) -> None: else: print(f"{message['role'].upper()}: {message['content']}") + def display_tags_frequency(tags: List[str], dump=False) -> None: if dump: pp(tags) return for tag in set(tags): - print(f"-{tag} : {tags.count(tag)}") + print(f"- {tag}: {tags.count(tag)}") -- 2.36.6 From 8bb2a002a6e68330818f78b23da8dbfc68670b47 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 13:21:17 +0200 Subject: [PATCH 003/170] Fixed tests. --- chatmastermind/main.py | 11 ++++++++--- chatmastermind/storage.py | 1 + tests/test_main.py | 12 ++++++++---- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 6fe4984..01cc038 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -61,9 +61,10 @@ def process_and_display_chat(args: argparse.Namespace, display_chat(chat, dump, args.only_source_code) return chat, full_question, tags + def process_and_display_tags(args: argparse.Namespace, config: dict, - dump: bool=False + dump: bool = False ) -> None: display_tags_frequency(get_tags(config, None), dump) @@ -104,8 +105,12 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') - parser.add_argument('-W', '--with-file', help="Print chat history with filename.", action='store_true') - parser.add_argument('-a', '--match-all-tags', help="All given tags must match when selecting chat history entries.", action='store_true') + parser.add_argument('-W', '--with-file', + help="Print chat history with filename.", + action='store_true') + parser.add_argument('-a', '--match-all-tags', + help="All given tags must match when selecting chat history entries.", + action='store_true') tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS') tags_arg.completer = tags_completer # type: ignore extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS') diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index 50e1ddd..69adc58 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -111,5 +111,6 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: result.append(tag) return result + def get_tags_unique(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: return list(set(get_tags(config, prefix))) diff --git a/tests/test_main.py b/tests/test_main.py index eca160f..b572557 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -95,7 +95,10 @@ class TestHandleQuestion(unittest.TestCase): question=[self.question], source=None, only_source_code=False, - number=3 + number=3, + match_all_tags=False, + with_tags=False, + with_file=False, ) self.config = { 'db': 'test_files', @@ -119,7 +122,8 @@ class TestHandleQuestion(unittest.TestCase): mock_create_chat.assert_called_once_with(self.question, self.args.tags, self.args.extags, - self.config) + self.config, + False, False, False) mock_pp.assert_called_once_with("test_chat") mock_ai.assert_called_with("test_chat", self.config, @@ -203,7 +207,7 @@ class TestCreateParser(unittest.TestCase): mock_add_mutually_exclusive_group.assert_called_once_with(required=True) mock_group.add_argument.assert_any_call('-p', '--print', help='File to print') mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask') - mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true') - mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true') + mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true') + mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat history as readable text", action='store_true') self.assertTrue('.config.yaml' in parser.get_default('config')) self.assertEqual(parser.get_default('number'), 1) -- 2.36.6 From 820d938060d94060e0fe7acbb7da43479fa0998d Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 14:42:38 +0200 Subject: [PATCH 004/170] Fix tests for Python 3.10. --- tests/test_main.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index b572557..48d9ea8 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -21,9 +21,11 @@ class TestCreateChat(unittest.TestCase): self.tags = ['test_tag'] @patch('os.listdir') + @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock, listdir_mock): + def test_create_chat_with_tags(self, open_mock, iterdir_mock, listdir_mock): listdir_mock.return_value = ['testfile.txt'] + iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( {'question': 'test_content', 'answer': 'some answer', 'tags': ['test_tag']})) @@ -41,9 +43,11 @@ class TestCreateChat(unittest.TestCase): {'role': 'user', 'content': self.question}) @patch('os.listdir') + @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock, listdir_mock): + def test_create_chat_with_other_tags(self, open_mock, iterdir_mock, listdir_mock): listdir_mock.return_value = ['testfile.txt'] + iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( {'question': 'test_content', 'answer': 'some answer', 'tags': ['other_tag']})) @@ -57,9 +61,11 @@ class TestCreateChat(unittest.TestCase): {'role': 'user', 'content': self.question}) @patch('os.listdir') + @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock, listdir_mock): + def test_create_chat_without_tags(self, open_mock, iterdir_mock, listdir_mock): listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] + iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.side_effect = ( io.StringIO(dump_data({'question': 'test_content', 'answer': 'some answer', -- 2.36.6 From ca3a53e68b34172b9e91b134a4824bb6a6913e89 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 15:35:27 +0200 Subject: [PATCH 005/170] Fix backwards compatibility of -W flag. --- chatmastermind/storage.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index 69adc58..bf6d81c 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -16,7 +16,7 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: question = "\n".join(text[question_idx:answer_idx]).strip() answer = "\n".join(text[answer_idx + 1:]).strip() return {"question": question, "answer": answer, "tags": tags, - "file": pathlib.Path(fname).name} + "file": fname.name} def dump_data(data: Dict[str, Any]) -> str: @@ -74,6 +74,7 @@ def create_chat(question: Optional[str], if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) + data['file'] = file.name elif file.suffix == '.txt': data = read_file(file) else: -- 2.36.6 From caf5244d520aecaf482013d6a04c230de704f9b6 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 16:04:25 +0200 Subject: [PATCH 006/170] Add action -L to list all available models --- chatmastermind/api_client.py | 11 +++++++++++ chatmastermind/main.py | 5 ++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index 2ff8c59..b9b0d05 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -5,6 +5,17 @@ def openai_api_key(api_key: str) -> None: openai.api_key = api_key +def display_models() -> None: + not_ready = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + print(engine['id']) + else: + not_ready.append(engine['id']) + if len(not_ready) > 0: + print('\nNot ready: ' + ', '.join(not_ready)) + + def ai(chat: list[dict[str, str]], config: dict, number: int diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 01cc038..68fe906 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -9,7 +9,7 @@ import argparse import pathlib from .utils import terminal_width, process_tags, display_chat, display_source_code, display_tags_frequency from .storage import save_answers, create_chat, get_tags, get_tags_unique, read_file, dump_data -from .api_client import ai, openai_api_key +from .api_client import ai, openai_api_key, display_models from itertools import zip_longest @@ -97,6 +97,7 @@ def create_parser() -> argparse.ArgumentParser: group.add_argument('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true') group.add_argument('-d', '--chat', help="Print chat history as readable text", action='store_true') group.add_argument('-l', '--list-tags', help="List all tags and their frequency", action='store_true') + group.add_argument('-L', '--list-models', help="List all available models", action='store_true') parser.add_argument('-c', '--config', help='Config file name.', default=default_config) parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) @@ -149,6 +150,8 @@ def main() -> int: process_and_display_chat(args, config) elif args.list_tags: process_and_display_tags(args, config) + elif args.list_models: + display_models() return 0 -- 2.36.6 From f8ed0e36367b71746e22c1e9665a41fd7c161c6e Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 5 Aug 2023 17:45:43 +0200 Subject: [PATCH 007/170] tags are now separated by ' ' (old format is still readable) --- chatmastermind/storage.py | 11 +++++++---- chatmastermind/utils.py | 8 ++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index bf6d81c..aa6288c 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -7,10 +7,13 @@ from typing import List, Dict, Any, Optional def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: with open(fname, "r") as fd: + tagline = fd.readline().strip().split(':')[1].strip() + # also support tags separated by ',' (old format) + separator = ',' if ',' in tagline else ' ' + tags = [t.strip() for t in tagline.split(separator)] if tags_only: - return {"tags": [x.strip() for x in fd.readline().strip().split(':')[1].strip().split(',')]} + return {"tags": tags} text = fd.read().strip().split('\n') - tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')] question_idx = text.index("=== QUESTION ===") + 1 answer_idx = text.index("==== ANSWER ====") question = "\n".join(text[question_idx:answer_idx]).strip() @@ -21,7 +24,7 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: def dump_data(data: Dict[str, Any]) -> str: with io.StringIO() as fd: - fd.write(f'TAGS: {", ".join(data["tags"])}\n') + fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'==== ANSWER ====\n{data["answer"]}\n') return fd.getvalue() @@ -29,7 +32,7 @@ def dump_data(data: Dict[str, Any]) -> str: def write_file(fname: str, data: Dict[str, Any]) -> None: with open(fname, "w") as fd: - fd.write(f'TAGS: {", ".join(data["tags"])}\n') + fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') fd.write(f'==== ANSWER ====\n{data["answer"]}\n') diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 387db7b..7bac123 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -15,11 +15,11 @@ def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None: printed_messages = [] if tags: - printed_messages.append(f"Tags: {', '.join(tags)}") + printed_messages.append(f"Tags: {' '.join(tags)}") if extags: - printed_messages.append(f"Excluding tags: {', '.join(extags)}") + printed_messages.append(f"Excluding tags: {' '.join(extags)}") if otags: - printed_messages.append(f"Output tags: {', '.join(otags)}") + printed_messages.append(f"Output tags: {' '.join(otags)}") if printed_messages: print("\n".join(printed_messages)) @@ -41,7 +41,7 @@ def message_to_chat(message: Dict[str, str], append_message(chat, 'user', message['question']) append_message(chat, 'assistant', message['answer']) if with_tags: - tags = ", ".join(message['tags']) + tags = " ".join(message['tags']) append_message(chat, 'tags', tags) if with_file: append_message(chat, 'file', message['file']) -- 2.36.6 From c5c4a6628f688a6a93dc57e33406f3aca8cce415 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 21:00:30 +0200 Subject: [PATCH 008/170] Allow character ":" in tags. --- chatmastermind/storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index aa6288c..ac59eb5 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -7,7 +7,7 @@ from typing import List, Dict, Any, Optional def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: with open(fname, "r") as fd: - tagline = fd.readline().strip().split(':')[1].strip() + tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() # also support tags separated by ',' (old format) separator = ',' if ',' in tagline else ' ' tags = [t.strip() for t in tagline.split(separator)] -- 2.36.6 From 9b6b13993c2e58d82d88b90c2cccbf10834395b1 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 5 Aug 2023 23:07:39 +0200 Subject: [PATCH 009/170] Output the tag list sorted alphabetically. --- chatmastermind/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 7bac123..bc1dcd2 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -79,5 +79,5 @@ def display_tags_frequency(tags: List[str], dump=False) -> None: if dump: pp(tags) return - for tag in set(tags): + for tag in sorted(set(tags)): print(f"- {tag}: {tags.count(tag)}") -- 2.36.6 From 7a92ebe539738cf659d54f5b4f9a9141d8dc3ef9 Mon Sep 17 00:00:00 2001 From: juk0de <5322305+juk0de@users.noreply.github.com> Date: Thu, 10 Aug 2023 08:26:27 +0200 Subject: [PATCH 010/170] README: Added 'Contributing' section --- README.md | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/README.md b/README.md index dd38a71..617b5c0 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,45 @@ eval "$(register-python-argcomplete cmm)" After adding this line, restart your shell or run `source ` to enable autocompletion for the `cmm` script. +## Contributing + +### Enable commit hooks +``` +pip install pre-commit +pre-commit install +``` +### Execute tests before opening a PR +``` +pytest +``` +### Consider using `pyenv` / `pyenv-virtualenv` +Short installation instructions: +* install `pyenv`: +``` +cd ~ +git clone https://github.com/pyenv/pyenv .pyenv +cd ~/.pyenv && src/configure && make -C src +``` +* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc` +* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"` +* create a new terminal or source the changes (e. g. `source ~/.bashrc`) +* install `virtualenv` +``` +git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv +``` +* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)` +* create a new terminal or source the changes (e. g. `source ~/.bashrc`) +* go back to the `ChatMasterMind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`: +``` +cd +pyenv install 3.11.4 +pyenv virtualenv 3.11.4 py311 +pyenv activate py311 +``` +* see also the [official pyenv documentation](https://github.com/pyenv/pyenv#readme) + ## License This project is licensed under the terms of the WTFPL License. + + -- 2.36.6 From bc9baff0dcdd97b42e1b87ce56adb9705f2645f2 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Thu, 10 Aug 2023 11:29:54 +0200 Subject: [PATCH 011/170] Add official repository URL. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 617b5c0..7db5ea4 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ ChatMastermind is a Python application that automates conversation with AI, stor The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags. +Official repository URL: https://kaizenkodo.no/gitea/kaizenkodo/ChatMastermind.git + ## Requirements - Python 3.6 or higher -- 2.36.6 From df91ca863a6c5e7a9311cd8005fd97035e9e192a Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 11 Aug 2023 11:03:04 +0200 Subject: [PATCH 012/170] Fix the supported python version in README.md and set it to 3.9, also add some classifiers. --- README.md | 2 +- setup.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7db5ea4..4ff5d97 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Official repository URL: https://kaizenkodo.no/gitea/kaizenkodo/ChatMastermind.g ## Requirements -- Python 3.6 or higher +- Python 3.9 or higher - openai - PyYAML - argcomplete diff --git a/setup.py b/setup.py index 252f277..02d9ab1 100644 --- a/setup.py +++ b/setup.py @@ -15,12 +15,18 @@ setup( packages=find_packages(), classifiers=[ "Development Status :: 3 - Alpha", + "Environment :: Console", "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", + "Intended Audience :: End Users/Desktop", + "Intended Audience :: Science/Research", "Operating System :: OS Independent", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Utilities", + "Topic :: Text Processing", ], install_requires=[ "openai", @@ -28,7 +34,7 @@ setup( "argcomplete", "pytest" ], - python_requires=">=3.10", + python_requires=">=3.9", test_suite="tests", entry_points={ "console_scripts": [ -- 2.36.6 From 6406d2f5b5daf35a8cfe550290dabf8196653aba Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 11 Aug 2023 18:12:49 +0200 Subject: [PATCH 013/170] started to implement sub-commands --- chatmastermind/api_client.py | 2 +- chatmastermind/main.py | 242 ++++++++++++++++++++++------------- chatmastermind/storage.py | 16 +-- chatmastermind/utils.py | 4 +- tests/test_main.py | 28 ++-- 5 files changed, 176 insertions(+), 116 deletions(-) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index b9b0d05..8eaf695 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -5,7 +5,7 @@ def openai_api_key(api_key: str) -> None: openai.api_key = api_key -def display_models() -> None: +def print_models() -> None: not_ready = [] for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): if engine['ready']: diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 68fe906..1b512e4 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,32 +7,33 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, process_tags, display_chat, display_source_code, display_tags_frequency -from .storage import save_answers, create_chat, get_tags, get_tags_unique, read_file, dump_data -from .api_client import ai, openai_api_key, display_models +from .utils import terminal_width, process_tags, print_chat_hist, display_source_code, print_tags_frequency +from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data +from .api_client import ai, openai_api_key, print_models from itertools import zip_longest - -def run_print_command(args: argparse.Namespace, config: dict) -> None: - fname = pathlib.Path(args.print) - if fname.suffix == '.yaml': - with open(args.print, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif fname.suffix == '.txt': - data = read_file(fname) - else: - print(f"Unknown file type: {args.print}") - sys.exit(1) - if args.only_source_code: - display_source_code(data['answer']) - else: - print(dump_data(data).strip()) +default_config = '.config.yaml' -def process_and_display_chat(args: argparse.Namespace, +def tags_completer(prefix, parsed_args, **kwargs): + with open(parsed_args.config, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + return get_tags_unique(config, prefix) + + +def read_config(path: str): + with open(path, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + return config + + +def create_question_and_chat(args: argparse.Namespace, config: dict, - dump: bool = False ) -> tuple[list[dict[str, str]], str, list[str]]: + """ + Creates the "SI request", including the question and chat history as determined + by the specified tags. + """ tags = args.tags or [] extags = args.extags or [] otags = args.output_tags or [] @@ -55,25 +56,42 @@ def process_and_display_chat(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat(full_question, tags, extags, config, - args.match_all_tags, args.with_tags, - args.with_file) - display_chat(chat, dump, args.only_source_code) + chat = create_chat_hist(full_question, tags, extags, config, + args.match_all_tags, args.with_tags, + args.with_file) return chat, full_question, tags -def process_and_display_tags(args: argparse.Namespace, - config: dict, - dump: bool = False - ) -> None: - display_tags_frequency(get_tags(config, None), dump) +def tag_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'tag' command. + """ + config = read_config(args.config) + if args.list: + print_tags_frequency(get_tags(config, None), args.dump) -def handle_question(args: argparse.Namespace, - config: dict, - dump: bool = False - ) -> None: - chat, question, tags = process_and_display_chat(args, config, dump) +def model_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'model' command. + """ + if args.list: + print_models() + + +def ask_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'ask' command. + """ + config = read_config(args.config) + if args.max_tokens: + config['openai']['max_tokens'] = args.max_tokens + if args.temperature: + config['openai']['temperature'] = args.temperature + if args.model: + config['openai']['model'] = args.model + chat, question, tags = create_question_and_chat(args, config) + print_chat_hist(chat, args.dump, args.only_source_code) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) save_answers(question, answers, tags, otags, config) @@ -81,77 +99,119 @@ def handle_question(args: argparse.Namespace, print(f"Usage: {usage}") -def tags_completer(prefix, parsed_args, **kwargs): - with open(parsed_args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - return get_tags_unique(config, prefix) +def hist_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'hist' command. + """ + config = read_config(args.config) + chat, q, t = create_question_and_chat(args, config) + print_chat_hist(chat, args.dump, args.only_source_code) + + +def print_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'print' command. + """ + fname = pathlib.Path(args.print) + if fname.suffix == '.yaml': + with open(args.print, 'r') as f: + data = yaml.load(f, Loader=yaml.FullLoader) + elif fname.suffix == '.txt': + data = read_file(fname) + else: + print(f"Unknown file type: {args.print}") + sys.exit(1) + if args.only_source_code: + display_source_code(data['answer']) + else: + print(dump_data(data).strip()) def create_parser() -> argparse.ArgumentParser: - default_config = '.config.yaml' parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - group = parser.add_mutually_exclusive_group(required=True) - group.add_argument('-p', '--print', help='File to print') - group.add_argument('-q', '--question', nargs='*', help='Question to ask') - group.add_argument('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true') - group.add_argument('-d', '--chat', help="Print chat history as readable text", action='store_true') - group.add_argument('-l', '--list-tags', help="List all tags and their frequency", action='store_true') - group.add_argument('-L', '--list-models', help="List all available models", action='store_true') parser.add_argument('-c', '--config', help='Config file name.', default=default_config) - parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - parser.add_argument('-M', '--model', help='Model to use') - parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) - parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') - parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') - parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') - parser.add_argument('-W', '--with-file', - help="Print chat history with filename.", - action='store_true') - parser.add_argument('-a', '--match-all-tags', - help="All given tags must match when selecting chat history entries.", - action='store_true') - tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS') - tags_arg.completer = tags_completer # type: ignore - extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS') - extags_arg.completer = tags_completer # type: ignore - otags_arg = parser.add_argument('-o', '--output-tags', nargs='*', help='List of output tag names, default is input', metavar='OTAGS') - otags_arg.completer = tags_completer # type: ignore - argcomplete.autocomplete(parser) + + # subcommand-parser + cmdparser = parser.add_subparsers(dest='command', + title='commands', + description='supported commands') + cmdparser.required = True + + # a parent parser for all commands that support tag selection + tag_parser = argparse.ArgumentParser(add_help=False) + tag_arg = tag_parser.add_argument('-t', '--tags', nargs='*', + help='List of tag names', metavar='TAGS') + tag_arg.completer = tags_completer # type: ignore + extag_arg = tag_parser.add_argument('-e', '--extags', nargs='*', + help='List of tag names to exclude', metavar='EXTAGS') + extag_arg.completer = tags_completer # type: ignore + otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='*', + help='List of output tag names, default is input', metavar='OTAGS') + otag_arg.completer = tags_completer # type: ignore + tag_parser.add_argument('-a', '--match-all-tags', + help="All given tags must match when selecting chat history entries", + action='store_true') + # enable autocompletion for tags + argcomplete.autocomplete(tag_parser) + + # 'ask' command parser + ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], + help="Ask a question.") + ask_cmd_parser.set_defaults(func=ask_cmd) + ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', required=True) + ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) + ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) + ask_cmd_parser.add_argument('-M', '--model', help='Model to use') + ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) + ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') + + # 'hist' command parser + hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], + help="Print chat history.") + hist_cmd_parser.set_defaults(func=hist_cmd) + hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", + action='store_true') + hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", + action='store_true') + hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", + action='store_true') + hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + action='store_true') + + # 'tag' command parser + tag_cmd_parser = cmdparser.add_parser('tag', + help="Manage tags.") + tag_cmd_parser.set_defaults(func=tag_cmd) + tag_cmd_parser.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') + + # 'model' command parser + model_cmd_parser = cmdparser.add_parser('model', + help="Manage models.") + model_cmd_parser.set_defaults(func=model_cmd) + model_cmd_parser.add_argument('-l', '--list', help="List all available models", + action='store_true') + + # 'print' command parser + print_cmd_parser = cmdparser.add_parser('print', + help="Print files.") + print_cmd_parser.set_defaults(func=print_cmd) + print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) + print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + action='store_true') + return parser def main() -> int: parser = create_parser() args = parser.parse_args() + command = parser.parse_args() - with open(args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) + openai_api_key(read_config(args.config)['openai']['api_key']) - openai_api_key(config['openai']['api_key']) - - if args.max_tokens: - config['openai']['max_tokens'] = args.max_tokens - - if args.temperature: - config['openai']['temperature'] = args.temperature - - if args.model: - config['openai']['model'] = args.model - - if args.print: - run_print_command(args, config) - elif args.question: - handle_question(args, config) - elif args.chat_dump: - process_and_display_chat(args, config, dump=True) - elif args.chat: - process_and_display_chat(args, config) - elif args.list_tags: - process_and_display_tags(args, config) - elif args.list_models: - display_models() + command.func(command) return 0 diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index ac59eb5..4705893 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -63,14 +63,14 @@ def save_answers(question: str, f.write(f'{num}') -def create_chat(question: Optional[str], - tags: Optional[List[str]], - extags: Optional[List[str]], - config: Dict[str, Any], - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> List[Dict[str, str]]: +def create_chat_hist(question: Optional[str], + tags: Optional[List[str]], + extags: Optional[List[str]], + config: Dict[str, Any], + match_all_tags: bool = False, + with_tags: bool = False, + with_file: bool = False + ) -> List[Dict[str, str]]: chat: List[Dict[str, str]] = [] append_message(chat, 'system', config['system'].strip()) for file in sorted(pathlib.Path(config['db']).iterdir()): diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index bc1dcd2..ca92d25 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -57,7 +57,7 @@ def display_source_code(content: str) -> None: pass -def display_chat(chat, dump=False, source_code=False) -> None: +def print_chat_hist(chat, dump=False, source_code=False) -> None: if dump: pp(chat) return @@ -75,7 +75,7 @@ def display_chat(chat, dump=False, source_code=False) -> None: print(f"{message['role'].upper()}: {message['content']}") -def display_tags_frequency(tags: List[str], dump=False) -> None: +def print_tags_frequency(tags: List[str], dump=False) -> None: if dump: pp(tags) return diff --git a/tests/test_main.py b/tests/test_main.py index 48d9ea8..c0aa32c 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -3,9 +3,9 @@ import io import pathlib import argparse from chatmastermind.utils import terminal_width -from chatmastermind.main import create_parser, handle_question +from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai -from chatmastermind.storage import create_chat, save_answers, dump_data +from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock @@ -30,7 +30,7 @@ class TestCreateChat(unittest.TestCase): {'question': 'test_content', 'answer': 'some answer', 'tags': ['test_tag']})) - test_chat = create_chat(self.question, self.tags, None, self.config) + test_chat = create_chat_hist(self.question, self.tags, None, self.config) self.assertEqual(len(test_chat), 4) self.assertEqual(test_chat[0], @@ -52,7 +52,7 @@ class TestCreateChat(unittest.TestCase): {'question': 'test_content', 'answer': 'some answer', 'tags': ['other_tag']})) - test_chat = create_chat(self.question, self.tags, None, self.config) + test_chat = create_chat_hist(self.question, self.tags, None, self.config) self.assertEqual(len(test_chat), 2) self.assertEqual(test_chat[0], @@ -75,7 +75,7 @@ class TestCreateChat(unittest.TestCase): 'tags': ['test_tag2']})), ) - test_chat = create_chat(self.question, [], None, self.config) + test_chat = create_chat_hist(self.question, [], None, self.config) self.assertEqual(len(test_chat), 6) self.assertEqual(test_chat[0], @@ -112,24 +112,24 @@ class TestHandleQuestion(unittest.TestCase): 'setting2': 'value2' } - @patch("chatmastermind.main.create_chat", return_value="test_chat") + @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.process_tags") @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") - def test_handle_question(self, mock_print, mock_pp, mock_ai, - mock_process_tags, mock_create_chat): + def test_ask_cmd(self, mock_print, mock_pp, mock_ai, + mock_process_tags, mock_create_chat_hist): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): - handle_question(self.args, self.config, True) + ask_cmd(self.args, self.config, True) mock_process_tags.assert_called_once_with(self.args.tags, self.args.extags, []) - mock_create_chat.assert_called_once_with(self.question, - self.args.tags, - self.args.extags, - self.config, - False, False, False) + mock_create_chat_hist.assert_called_once_with(self.question, + self.args.tags, + self.args.extags, + self.config, + False, False, False) mock_pp.assert_called_once_with("test_chat") mock_ai.assert_called_with("test_chat", self.config, -- 2.36.6 From f90e7bcd47971c7d0d64406c5210641774df3f40 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 08:13:31 +0200 Subject: [PATCH 014/170] fixed 'hist' command and simplified reading the config file --- chatmastermind/main.py | 38 +++++++++++++++++++++----------------- chatmastermind/utils.py | 5 ++++- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 1b512e4..32622da 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,7 +7,7 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, process_tags, print_chat_hist, display_source_code, print_tags_frequency +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from itertools import zip_longest @@ -27,9 +27,9 @@ def read_config(path: str): return config -def create_question_and_chat(args: argparse.Namespace, - config: dict, - ) -> tuple[list[dict[str, str]], str, list[str]]: +def create_question_with_hist(args: argparse.Namespace, + config: dict, + ) -> tuple[list[dict[str, str]], str, list[str]]: """ Creates the "SI request", including the question and chat history as determined by the specified tags. @@ -39,7 +39,7 @@ def create_question_and_chat(args: argparse.Namespace, otags = args.output_tags or [] if not args.only_source_code: - process_tags(tags, extags, otags) + print_tag_args(tags, extags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -62,16 +62,15 @@ def create_question_and_chat(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace) -> None: +def tag_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'tag' command. """ - config = read_config(args.config) if args.list: print_tags_frequency(get_tags(config, None), args.dump) -def model_cmd(args: argparse.Namespace) -> None: +def model_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'model' command. """ @@ -79,18 +78,17 @@ def model_cmd(args: argparse.Namespace) -> None: print_models() -def ask_cmd(args: argparse.Namespace) -> None: +def ask_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'ask' command. """ - config = read_config(args.config) if args.max_tokens: config['openai']['max_tokens'] = args.max_tokens if args.temperature: config['openai']['temperature'] = args.temperature if args.model: config['openai']['model'] = args.model - chat, question, tags = create_question_and_chat(args, config) + chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, args.dump, args.only_source_code) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) @@ -99,16 +97,21 @@ def ask_cmd(args: argparse.Namespace) -> None: print(f"Usage: {usage}") -def hist_cmd(args: argparse.Namespace) -> None: +def hist_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'hist' command. """ - config = read_config(args.config) - chat, q, t = create_question_and_chat(args, config) + tags = args.tags or [] + extags = args.extags or [] + + chat = create_chat_hist(None, tags, extags, config, + args.match_all_tags, + args.with_tags, + args.with_files) print_chat_hist(chat, args.dump, args.only_source_code) -def print_cmd(args: argparse.Namespace) -> None: +def print_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'print' command. """ @@ -209,9 +212,10 @@ def main() -> int: args = parser.parse_args() command = parser.parse_args() - openai_api_key(read_config(args.config)['openai']['api_key']) + config = read_config(args.config) + openai_api_key(config['openai']['api_key']) - command.func(command) + command.func(command, config) return 0 diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index ca92d25..cdd0e60 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -11,7 +11,10 @@ def pp(*args, **kwargs) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) -def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None: +def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: + """ + Prints the tags specified in the given args. + """ printed_messages = [] if tags: -- 2.36.6 From 5a435c5f8f701e14a7a82b675b581aaa34fe10fc Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 08:20:00 +0200 Subject: [PATCH 015/170] fixed 'tag' and 'hist' commands --- chatmastermind/main.py | 8 ++++---- chatmastermind/utils.py | 5 +---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 32622da..f80c26c 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,7 +67,7 @@ def tag_cmd(args: argparse.Namespace, config: dict) -> None: Handler for the 'tag' command. """ if args.list: - print_tags_frequency(get_tags(config, None), args.dump) + print_tags_frequency(get_tags(config, None)) def model_cmd(args: argparse.Namespace, config: dict) -> None: @@ -115,14 +115,14 @@ def print_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'print' command. """ - fname = pathlib.Path(args.print) + fname = pathlib.Path(args.file) if fname.suffix == '.yaml': - with open(args.print, 'r') as f: + with open(args.file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) elif fname.suffix == '.txt': data = read_file(fname) else: - print(f"Unknown file type: {args.print}") + print(f"Unknown file type: {args.file}") sys.exit(1) if args.only_source_code: display_source_code(data['answer']) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index cdd0e60..78440fa 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -78,9 +78,6 @@ def print_chat_hist(chat, dump=False, source_code=False) -> None: print(f"{message['role'].upper()}: {message['content']}") -def print_tags_frequency(tags: List[str], dump=False) -> None: - if dump: - pp(tags) - return +def print_tags_frequency(tags: List[str]) -> None: for tag in sorted(set(tags)): print(f"- {tag}: {tags.count(tag)}") -- 2.36.6 From 5119b3a8743db64fb4f865a19ebab7b4fbf47e16 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 08:28:07 +0200 Subject: [PATCH 016/170] fixed 'ask' command --- chatmastermind/main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index f80c26c..9c8c3c5 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -57,8 +57,7 @@ def create_question_with_hist(args: argparse.Namespace, full_question = '\n\n'.join(question_parts) chat = create_chat_hist(full_question, tags, extags, config, - args.match_all_tags, args.with_tags, - args.with_file) + args.match_all_tags, False, False) return chat, full_question, tags @@ -89,7 +88,7 @@ def ask_cmd(args: argparse.Namespace, config: dict) -> None: if args.model: config['openai']['model'] = args.model chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, args.dump, args.only_source_code) + print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) save_answers(question, answers, tags, otags, config) @@ -162,12 +161,16 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], help="Ask a question.") ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', required=True) + ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', + required=True) ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) + ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, + default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') + ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', + action='store_true') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], -- 2.36.6 From 93a8b0081af71f5b0ea05fc71b0393150326a052 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 09:50:54 +0200 Subject: [PATCH 017/170] main: cleanup --- chatmastermind/main.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 9c8c3c5..b1db2d6 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -81,12 +81,6 @@ def ask_cmd(args: argparse.Namespace, config: dict) -> None: """ Handler for the 'ask' command. """ - if args.max_tokens: - config['openai']['max_tokens'] = args.max_tokens - if args.temperature: - config['openai']['temperature'] = args.temperature - if args.model: - config['openai']['model'] = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] @@ -137,8 +131,8 @@ def create_parser() -> argparse.ArgumentParser: # subcommand-parser cmdparser = parser.add_subparsers(dest='command', title='commands', - description='supported commands') - cmdparser.required = True + description='supported commands', + required=True) # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) @@ -214,9 +208,16 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = read_config(args.config) + + # modify config according to args openai_api_key(config['openai']['api_key']) + if args.max_tokens: + config['openai']['max_tokens'] = args.max_tokens + if args.temperature: + config['openai']['temperature'] = args.temperature + if args.model: + config['openai']['model'] = args.model command.func(command, config) -- 2.36.6 From 056bf4c6b574177c13e8260765f1508c7b220018 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 09:51:13 +0200 Subject: [PATCH 018/170] fixed almost all tests --- tests/test_main.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index c0aa32c..4434757 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,7 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock -from unittest.mock import patch, MagicMock, Mock +from unittest.mock import patch, MagicMock, Mock, ANY class TestCreateChat(unittest.TestCase): @@ -113,23 +113,28 @@ class TestHandleQuestion(unittest.TestCase): } @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") - @patch("chatmastermind.main.process_tags") + @patch("chatmastermind.main.print_tag_args") + @patch("chatmastermind.utils.print_chat_hist") @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") def test_ask_cmd(self, mock_print, mock_pp, mock_ai, - mock_process_tags, mock_create_chat_hist): + mock_print_tag_args, mock_create_chat_hist, + mock_print_chat_hist): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): - ask_cmd(self.args, self.config, True) - mock_process_tags.assert_called_once_with(self.args.tags, - self.args.extags, - []) + ask_cmd(self.args, self.config) + mock_print_tag_args.assert_called_once_with(self.args.tags, + self.args.extags, + []) mock_create_chat_hist.assert_called_once_with(self.question, self.args.tags, self.args.extags, self.config, False, False, False) + mock_print_chat_hist.assert_called_once_with('test_chat', + False, + self.args.only_source_code) mock_pp.assert_called_once_with("test_chat") mock_ai.assert_called_with("test_chat", self.config, @@ -205,15 +210,15 @@ class TestAI(unittest.TestCase): class TestCreateParser(unittest.TestCase): def test_create_parser(self): - with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group: - mock_group = Mock() - mock_add_mutually_exclusive_group.return_value = mock_group + with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: + mock_cmdparser = Mock() + mock_add_subparsers.return_value = mock_cmdparser parser = create_parser() self.assertIsInstance(parser, argparse.ArgumentParser) - mock_add_mutually_exclusive_group.assert_called_once_with(required=True) - mock_group.add_argument.assert_any_call('-p', '--print', help='File to print') - mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask') - mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true') - mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat history as readable text", action='store_true') + mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) + mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY) + mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY) + mock_cmdparser.add_parser.assert_any_call('tag', help=ANY) + mock_cmdparser.add_parser.assert_any_call('model', help=ANY) + mock_cmdparser.add_parser.assert_any_call('print', help=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) - self.assertEqual(parser.get_default('number'), 1) -- 2.36.6 From bc5e6228a63e8eb657e7f1ee2f23de5a56721465 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 10:21:09 +0200 Subject: [PATCH 019/170] defined 'ConfigType' for configuration file type hinting --- chatmastermind/main.py | 14 +++++++------- chatmastermind/storage.py | 18 +++++++++--------- chatmastermind/utils.py | 4 +++- 3 files changed, 19 insertions(+), 17 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index b1db2d6..843885c 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,7 +7,7 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from itertools import zip_longest @@ -28,7 +28,7 @@ def read_config(path: str): def create_question_with_hist(args: argparse.Namespace, - config: dict, + config: ConfigType, ) -> tuple[list[dict[str, str]], str, list[str]]: """ Creates the "SI request", including the question and chat history as determined @@ -61,7 +61,7 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: dict) -> None: +def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'tag' command. """ @@ -69,7 +69,7 @@ def tag_cmd(args: argparse.Namespace, config: dict) -> None: print_tags_frequency(get_tags(config, None)) -def model_cmd(args: argparse.Namespace, config: dict) -> None: +def model_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'model' command. """ @@ -77,7 +77,7 @@ def model_cmd(args: argparse.Namespace, config: dict) -> None: print_models() -def ask_cmd(args: argparse.Namespace, config: dict) -> None: +def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'ask' command. """ @@ -90,7 +90,7 @@ def ask_cmd(args: argparse.Namespace, config: dict) -> None: print(f"Usage: {usage}") -def hist_cmd(args: argparse.Namespace, config: dict) -> None: +def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'hist' command. """ @@ -104,7 +104,7 @@ def hist_cmd(args: argparse.Namespace, config: dict) -> None: print_chat_hist(chat, args.dump, args.only_source_code) -def print_cmd(args: argparse.Namespace, config: dict) -> None: +def print_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'print' command. """ diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index 4705893..afd1e8d 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,7 +1,7 @@ import yaml import io import pathlib -from .utils import terminal_width, append_message, message_to_chat +from .utils import terminal_width, append_message, message_to_chat, ConfigType from typing import List, Dict, Any, Optional @@ -41,11 +41,11 @@ def save_answers(question: str, answers: list[str], tags: list[str], otags: Optional[list[str]], - config: Dict[str, Any] + config: ConfigType ) -> None: wtags = otags or tags num, inum = 0, 0 - next_fname = pathlib.Path(config['db']) / '.next' + next_fname = pathlib.Path(str(config['db'])) / '.next' try: with open(next_fname, 'r') as f: num = int(f.read()) @@ -66,14 +66,14 @@ def save_answers(question: str, def create_chat_hist(question: Optional[str], tags: Optional[List[str]], extags: Optional[List[str]], - config: Dict[str, Any], + config: ConfigType, match_all_tags: bool = False, with_tags: bool = False, with_file: bool = False ) -> List[Dict[str, str]]: chat: List[Dict[str, str]] = [] - append_message(chat, 'system', config['system'].strip()) - for file in sorted(pathlib.Path(config['db']).iterdir()): + append_message(chat, 'system', str(config['system']).strip()) + for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) @@ -97,9 +97,9 @@ def create_chat_hist(question: Optional[str], return chat -def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: +def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: result = [] - for file in sorted(pathlib.Path(config['db']).iterdir()): + for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) @@ -116,5 +116,5 @@ def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: return result -def get_tags_unique(config: Dict[str, Any], prefix: Optional[str]) -> List[str]: +def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 78440fa..2a58dae 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -1,6 +1,8 @@ import shutil from pprint import PrettyPrinter -from typing import List, Dict +from typing import List, Dict, Union + +ConfigType = Dict[str, Union[str, Dict[str, Union[str, int]]]] def terminal_width() -> int: -- 2.36.6 From e4d055b90033773bba712f8ab49cf9339d7d83ed Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 12 Aug 2023 12:20:49 +0200 Subject: [PATCH 020/170] Fix the max_tokens, temperature, and model setup. --- chatmastermind/main.py | 26 ++++++++++++++++++-------- chatmastermind/utils.py | 11 +++++------ 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 843885c..3ed387c 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -21,7 +21,7 @@ def tags_completer(prefix, parsed_args, **kwargs): return get_tags_unique(config, prefix) -def read_config(path: str): +def read_config(path: str) -> ConfigType: with open(path, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) return config @@ -81,6 +81,15 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'ask' command. """ + if type(config['openai']) is not dict: + raise RuntimeError('Configuration openai is not a dict.') + config_openai = config['openai'] + if args.max_tokens: + config_openai['max_tokens'] = args.max_tokens + if args.temperature: + config_openai['temperature'] = args.temperature + if args.model: + config_openai['model'] = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] @@ -211,13 +220,14 @@ def main() -> int: config = read_config(args.config) # modify config according to args - openai_api_key(config['openai']['api_key']) - if args.max_tokens: - config['openai']['max_tokens'] = args.max_tokens - if args.temperature: - config['openai']['temperature'] = args.temperature - if args.model: - config['openai']['model'] = args.model + if type(config['openai']) is dict: + config_openai = config['openai'] + else: + RuntimeError("Configuration openai is not a dict.") + if type(config_openai['api_key']) is str: + openai_api_key(config_openai['api_key']) + else: + raise RuntimeError("Configuration openai.api_key is not a string.") command.func(command, config) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 2a58dae..fba8296 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -1,8 +1,7 @@ import shutil from pprint import PrettyPrinter -from typing import List, Dict, Union -ConfigType = Dict[str, Union[str, Dict[str, Union[str, int]]]] +ConfigType = dict[str, str | dict[str, str | int | float]] def terminal_width() -> int: @@ -31,15 +30,15 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None print() -def append_message(chat: List[Dict[str, str]], +def append_message(chat: list[dict[str, str]], role: str, content: str ) -> None: chat.append({'role': role, 'content': content.replace("''", "'")}) -def message_to_chat(message: Dict[str, str], - chat: List[Dict[str, str]], +def message_to_chat(message: dict[str, str], + chat: list[dict[str, str]], with_tags: bool = False, with_file: bool = False ) -> None: @@ -80,6 +79,6 @@ def print_chat_hist(chat, dump=False, source_code=False) -> None: print(f"{message['role'].upper()}: {message['content']}") -def print_tags_frequency(tags: List[str]) -> None: +def print_tags_frequency(tags: list[str]) -> None: for tag in sorted(set(tags)): print(f"- {tag}: {tags.count(tag)}") -- 2.36.6 From 4b2f634b79ea491ab0bddec3c8fcd56c7ae26934 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 12 Aug 2023 12:30:07 +0200 Subject: [PATCH 021/170] Remove wrong comment and make it more readable. --- chatmastermind/main.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 3ed387c..3150931 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -219,15 +219,10 @@ def main() -> int: command = parser.parse_args() config = read_config(args.config) - # modify config according to args - if type(config['openai']) is dict: - config_openai = config['openai'] + if type(config['openai']) is dict and type(config['openai']['api_key']) is str: + openai_api_key(config['openai']['api_key']) else: - RuntimeError("Configuration openai is not a dict.") - if type(config_openai['api_key']) is str: - openai_api_key(config_openai['api_key']) - else: - raise RuntimeError("Configuration openai.api_key is not a string.") + raise RuntimeError("Configuration openai.api_key is wrong.") command.func(command, config) -- 2.36.6 From 1fb9144192b8839839a0e6e29b285a783bbd044d Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 12 Aug 2023 12:44:13 +0200 Subject: [PATCH 022/170] Change REDAME.md with the new call semantics. --- README.md | 78 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 4ff5d97..95d60a9 100644 --- a/README.md +++ b/README.md @@ -29,65 +29,99 @@ pip install . ## Usage +The `cmm` script has global options, a list of commands, and options per command: + ```bash -cmm [-h] [-p PRINT | -q QUESTION | -D | -d | -l] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]] [-a] [-w] [-W] +cmm [global options] command [command options] ``` -### Arguments +### Global Options -- `-p`, `--print`: YAML file to print. -- `-q`, `--question`: Question to ask. -- `-D`, `--chat-dump`: Print chat history as a Python structure. -- `-d`, `--chat`: Print chat history as readable text. -- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries. -- `-w`, `--with-tags`: Print chat history with tags. -- `-W`, `--with-tags`: Print chat history with filenames. -- `-l`, `--list-tags`: List all tags and their frequency. - `-c`, `--config`: Config file name (defaults to `.config.yaml`). + +### Commands + +- `ask`: Ask a question. +- `hist`: Print chat history. +- `tag`: Manage tags. +- `model`: Manage models. +- `print`: Print files. + +### Command Options + +#### `ask` Command Options + +- `-q`, `--question`: Question to ask (required). - `-m`, `--max-tokens`: Max tokens to use. - `-T`, `--temperature`: Temperature to use. - `-M`, `--model`: Model to use. - `-n`, `--number`: Number of answers to produce (default is 3). +- `-s`, `--source`: Add content of a file to the query. +- `-S`, `--only-source-code`: Add pure source code to the chat history. - `-t`, `--tags`: List of tag names. - `-e`, `--extags`: List of tag names to exclude. - `-o`, `--output-tags`: List of output tag names (default is the input tags). +- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries. + +#### `hist` Command Options + +- `-d`, `--dump`: Print chat history as Python structure. +- `-w`, `--with-tags`: Print chat history with tags. +- `-W`, `--with-files`: Print chat history with filenames. +- `-S`, `--only-source-code`: Print only source code. +- `-t`, `--tags`: List of tag names. +- `-e`, `--extags`: List of tag names to exclude. +- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries. + +#### `tag` Command Options + +- `-l`, `--list`: List all tags and their frequency. + +#### `model` Command Options + +- `-l`, `--list`: List all available models. + +#### `print` Command Options + +- `-f`, `--file`: File to print (required). +- `-S`, `--only-source-code`: Print only source code. ### Examples -1. Print the contents of a YAML file: +1. Ask a question: ```bash -cmm -p example.yaml +cmm ask -q "What is the meaning of life?" -t philosophy -e religion ``` -2. Ask a question: +2. Display the chat history: ```bash -cmm -q "What is the meaning of life?" -t philosophy -e religion +cmm hist ``` -3. Display the chat history as a Python structure: +3. Filter chat history by tags: ```bash -cmm -D +cmm hist -t tag1 tag2 ``` -4. Display the chat history as readable text: +4. Exclude chat history by tags: ```bash -cmm -d +cmm hist -e tag3 tag4 ``` -5. Filter chat history by tags: +5. List all tags and their frequency: ```bash -cmm -d -t tag1 tag2 +cmm tag -l ``` -6. Exclude chat history by tags: +6. Print the contents of a file: ```bash -cmm -d -e tag3 tag4 +cmm print -f example.yaml ``` ## Configuration -- 2.36.6 From 6ed459be6fe5fa76b631cb6b972b31c26d5cd634 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 12 Aug 2023 13:17:10 +0200 Subject: [PATCH 023/170] Fix tests. --- tests/test_main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 4434757..632124a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -102,6 +102,9 @@ class TestHandleQuestion(unittest.TestCase): source=None, only_source_code=False, number=3, + max_tokens=None, + temperature=None, + model=None, match_all_tags=False, with_tags=False, with_file=False, @@ -109,18 +112,19 @@ class TestHandleQuestion(unittest.TestCase): self.config = { 'db': 'test_files', 'setting1': 'value1', - 'setting2': 'value2' + 'setting2': 'value2', + 'openai': {}, } @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.print_tag_args") - @patch("chatmastermind.utils.print_chat_hist") + @patch("chatmastermind.main.print_chat_hist") @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") def test_ask_cmd(self, mock_print, mock_pp, mock_ai, - mock_print_tag_args, mock_create_chat_hist, - mock_print_chat_hist): + mock_print_chat_hist, mock_print_tag_args, + mock_create_chat_hist): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) @@ -135,7 +139,6 @@ class TestHandleQuestion(unittest.TestCase): mock_print_chat_hist.assert_called_once_with('test_chat', False, self.args.only_source_code) - mock_pp.assert_called_once_with("test_chat") mock_ai.assert_called_with("test_chat", self.config, self.args.number) -- 2.36.6 From f371a6146e00b41a1eb005f94da0842a1772f8ff Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 13:55:39 +0200 Subject: [PATCH 024/170] moved 'read_config' to storage.py and added 'write_config' --- chatmastermind/main.py | 8 +------- chatmastermind/storage.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 3150931..0486ae6 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -8,7 +8,7 @@ import argcomplete import argparse import pathlib from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data +from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, dump_data from .api_client import ai, openai_api_key, print_models from itertools import zip_longest @@ -21,12 +21,6 @@ def tags_completer(prefix, parsed_args, **kwargs): return get_tags_unique(config, prefix) -def read_config(path: str) -> ConfigType: - with open(path, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - return config - - def create_question_with_hist(args: argparse.Namespace, config: ConfigType, ) -> tuple[list[dict[str, str]], str, list[str]]: diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index afd1e8d..d90598b 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -22,6 +22,17 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: "file": fname.name} +def read_config(path: str) -> ConfigType: + with open(path, 'r') as f: + config = yaml.load(f, Loader=yaml.FullLoader) + return config + + +def write_config(path: str, config: ConfigType) -> None: + with open(path, 'w') as f: + yaml.dump(config, f) + + def dump_data(data: Dict[str, Any]) -> str: with io.StringIO() as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') -- 2.36.6 From b6eb7d9af8e2a50a9bf20db7a84c4660adfb04a9 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 12 Aug 2023 13:57:52 +0200 Subject: [PATCH 025/170] Fix autocompletion. --- chatmastermind/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 0486ae6..2e0cce1 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -152,7 +152,6 @@ def create_parser() -> argparse.ArgumentParser: help="All given tags must match when selecting chat history entries", action='store_true') # enable autocompletion for tags - argcomplete.autocomplete(tag_parser) # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], @@ -204,6 +203,7 @@ def create_parser() -> argparse.ArgumentParser: print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') + argcomplete.autocomplete(parser) return parser -- 2.36.6 From f7ba0c000f4dae95380d3fa8e8cc8af996972df1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 14:12:35 +0200 Subject: [PATCH 026/170] renamed 'model' command to 'config' --- chatmastermind/main.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 2e0cce1..cc634fc 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -8,7 +8,7 @@ import argcomplete import argparse import pathlib from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, dump_data +from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .api_client import ai, openai_api_key, print_models from itertools import zip_longest @@ -63,12 +63,20 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_tags_frequency(get_tags(config, None)) -def model_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ - Handler for the 'model' command. + Handler for the 'config' command. """ - if args.list: + if type(config['openai']) is not dict: + raise RuntimeError('Configuration openai is not a dict.') + + if args.list_models: print_models() + elif args.show_model: + print(config['openai']['model']) + elif args.model: + config['openai']['model'] = args.model + write_config(args.config, config) def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: @@ -188,12 +196,16 @@ def create_parser() -> argparse.ArgumentParser: tag_cmd_parser.add_argument('-l', '--list', help="List all tags and their frequency", action='store_true') - # 'model' command parser - model_cmd_parser = cmdparser.add_parser('model', - help="Manage models.") - model_cmd_parser.set_defaults(func=model_cmd) - model_cmd_parser.add_argument('-l', '--list', help="List all available models", - action='store_true') + # 'config' command parser + config_cmd_parser = cmdparser.add_parser('config', + help="Manage configuration") + config_cmd_parser.set_defaults(func=config_cmd) + config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) + config_group.add_argument('-L', '--list-models', help="List all available models", + action='store_true') + config_group.add_argument('-m', '--show-model', help="Show current model", + action='store_true') + config_group.add_argument('-M', '--model', help="Set model in the config file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', -- 2.36.6 From 22bebc16ed1300fb1226794e1d1e80b58d72c085 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 14:14:06 +0200 Subject: [PATCH 027/170] fixed min nr of expected arguments --- chatmastermind/main.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index cc634fc..623b83a 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -147,13 +147,13 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) - tag_arg = tag_parser.add_argument('-t', '--tags', nargs='*', + tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', help='List of tag names', metavar='TAGS') tag_arg.completer = tags_completer # type: ignore - extag_arg = tag_parser.add_argument('-e', '--extags', nargs='*', + extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', help='List of tag names to exclude', metavar='EXTAGS') extag_arg.completer = tags_completer # type: ignore - otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='*', + otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', help='List of output tag names, default is input', metavar='OTAGS') otag_arg.completer = tags_completer # type: ignore tag_parser.add_argument('-a', '--match-all-tags', @@ -172,7 +172,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-M', '--model', help='Model to use') ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query') + ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', action='store_true') -- 2.36.6 From c4a7c07a0c8d2875bc87579e1fab3a042dd3ebe6 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 14:14:51 +0200 Subject: [PATCH 028/170] fixed tests --- tests/test_main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_main.py b/tests/test_main.py index 632124a..0cfd1fd 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -222,6 +222,6 @@ class TestCreateParser(unittest.TestCase): mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY) mock_cmdparser.add_parser.assert_any_call('tag', help=ANY) - mock_cmdparser.add_parser.assert_any_call('model', help=ANY) + mock_cmdparser.add_parser.assert_any_call('config', help=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From 1e15a52e269f3bf44e472f84402958fe1e7b27e0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 12 Aug 2023 18:34:19 +0200 Subject: [PATCH 029/170] updated README and some minor renaming --- README.md | 8 +++++--- chatmastermind/main.py | 8 ++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 95d60a9..d55102a 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,7 @@ cmm [global options] command [command options] - `ask`: Ask a question. - `hist`: Print chat history. - `tag`: Manage tags. -- `model`: Manage models. +- `config`: Manage configuration. - `print`: Print files. ### Command Options @@ -77,9 +77,11 @@ cmm [global options] command [command options] - `-l`, `--list`: List all tags and their frequency. -#### `model` Command Options +#### `config` Command Options -- `-l`, `--list`: List all available models. +- `-l`, `--list-models`: List all available models. +- `-m`, `--print-model`: Print the currently configured model. +- `-M`, `--model`: Set model in the config file. #### `print` Command Options diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 623b83a..0d68779 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -25,7 +25,7 @@ def create_question_with_hist(args: argparse.Namespace, config: ConfigType, ) -> tuple[list[dict[str, str]], str, list[str]]: """ - Creates the "SI request", including the question and chat history as determined + Creates the "AI request", including the question and chat history as determined by the specified tags. """ tags = args.tags or [] @@ -72,7 +72,7 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: if args.list_models: print_models() - elif args.show_model: + elif args.print_model: print(config['openai']['model']) elif args.model: config['openai']['model'] = args.model @@ -201,9 +201,9 @@ def create_parser() -> argparse.ArgumentParser: help="Manage configuration") config_cmd_parser.set_defaults(func=config_cmd) config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) - config_group.add_argument('-L', '--list-models', help="List all available models", + config_group.add_argument('-l', '--list-models', help="List all available models", action='store_true') - config_group.add_argument('-m', '--show-model', help="Show current model", + config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') config_group.add_argument('-M', '--model', help="Set model in the config file") -- 2.36.6 From a5075b14a0ad617b131565a906e8100cbd9fc76f Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 13 Aug 2023 08:41:59 +0200 Subject: [PATCH 030/170] added short aliases for subcommands --- chatmastermind/main.py | 20 +++++++++++++------- tests/test_main.py | 10 +++++----- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 0d68779..15e8208 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -163,7 +163,8 @@ def create_parser() -> argparse.ArgumentParser: # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.") + help="Ask a question.", + aliases=['a']) ask_cmd_parser.set_defaults(func=ask_cmd) ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', required=True) @@ -178,7 +179,8 @@ def create_parser() -> argparse.ArgumentParser: # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], - help="Print chat history.") + help="Print chat history.", + aliases=['h']) hist_cmd_parser.set_defaults(func=hist_cmd) hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", action='store_true') @@ -191,14 +193,17 @@ def create_parser() -> argparse.ArgumentParser: # 'tag' command parser tag_cmd_parser = cmdparser.add_parser('tag', - help="Manage tags.") + help="Manage tags.", + aliases=['t']) tag_cmd_parser.set_defaults(func=tag_cmd) - tag_cmd_parser.add_argument('-l', '--list', help="List all tags and their frequency", - action='store_true') + tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) + tag_group.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') # 'config' command parser config_cmd_parser = cmdparser.add_parser('config', - help="Manage configuration") + help="Manage configuration", + aliases=['c']) config_cmd_parser.set_defaults(func=config_cmd) config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) config_group.add_argument('-l', '--list-models', help="List all available models", @@ -209,7 +214,8 @@ def create_parser() -> argparse.ArgumentParser: # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', - help="Print files.") + help="Print files.", + aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', diff --git a/tests/test_main.py b/tests/test_main.py index 0cfd1fd..3634740 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -219,9 +219,9 @@ class TestCreateParser(unittest.TestCase): parser = create_parser() self.assertIsInstance(parser, argparse.ArgumentParser) mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) - mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY) - mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY) - mock_cmdparser.add_parser.assert_any_call('tag', help=ANY) - mock_cmdparser.add_parser.assert_any_call('config', help=ANY) - mock_cmdparser.add_parser.assert_any_call('print', help=ANY) + mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From ba41794f4e9499b5ed611a25743bf27f704ea6c6 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 15 Aug 2023 09:47:58 +0200 Subject: [PATCH 031/170] mypy: added 'disallow_untyped_defs = True' --- mypy.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy.ini b/mypy.ini index b99c5a5..aecd40e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -5,3 +5,4 @@ strict_optional = True warn_unused_ignores = False warn_redundant_casts = True warn_unused_configs = True +disallow_untyped_defs = True -- 2.36.6 From 4303fb414f6f12702264549da625acfd6a53b5b7 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 15 Aug 2023 23:36:45 +0200 Subject: [PATCH 032/170] added typ hints for all functions in 'main.py', 'utils.py', 'storage.py' and 'api_client.py' --- chatmastermind/api_client.py | 15 +++++++++++++-- chatmastermind/main.py | 18 +++++++++--------- chatmastermind/storage.py | 22 +++++++++++----------- chatmastermind/utils.py | 10 ++++++---- 4 files changed, 39 insertions(+), 26 deletions(-) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index 8eaf695..d3282eb 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -1,11 +1,16 @@ import openai +from .utils import ConfigType, ChatType + def openai_api_key(api_key: str) -> None: openai.api_key = api_key def print_models() -> None: + """ + Print all models supported by the current AI. + """ not_ready = [] for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): if engine['ready']: @@ -16,10 +21,16 @@ def print_models() -> None: print('\nNot ready: ' + ', '.join(not_ready)) -def ai(chat: list[dict[str, str]], - config: dict, +def ai(chat: ChatType, + config: ConfigType, number: int ) -> tuple[list[str], dict[str, int]]: + """ + Make AI request with the given chat history and configuration. + Return AI response and tokens used. + """ + if not isinstance(config['openai'], dict): + raise RuntimeError('Configuration openai is not a dict.') response = openai.ChatCompletion.create( model=config['openai']['model'], messages=chat, diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 15e8208..ec33cb3 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,15 +7,16 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType, ChatType from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .api_client import ai, openai_api_key, print_models from itertools import zip_longest +from typing import Any default_config = '.config.yaml' -def tags_completer(prefix, parsed_args, **kwargs): +def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: with open(parsed_args.config, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) return get_tags_unique(config, prefix) @@ -23,7 +24,7 @@ def tags_completer(prefix, parsed_args, **kwargs): def create_question_with_hist(args: argparse.Namespace, config: ConfigType, - ) -> tuple[list[dict[str, str]], str, list[str]]: + ) -> tuple[ChatType, str, list[str]]: """ Creates the "AI request", including the question and chat history as determined by the specified tags. @@ -67,7 +68,7 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'config' command. """ - if type(config['openai']) is not dict: + if not isinstance(config['openai'], dict): raise RuntimeError('Configuration openai is not a dict.') if args.list_models: @@ -83,15 +84,14 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: """ Handler for the 'ask' command. """ - if type(config['openai']) is not dict: + if not isinstance(config['openai'], dict): raise RuntimeError('Configuration openai is not a dict.') - config_openai = config['openai'] if args.max_tokens: - config_openai['max_tokens'] = args.max_tokens + config['openai']['max_tokens'] = args.max_tokens if args.temperature: - config_openai['temperature'] = args.temperature + config['openai']['temperature'] = args.temperature if args.model: - config_openai['model'] = args.model + config['openai']['model'] = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index d90598b..fa3fb14 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,11 +1,11 @@ import yaml import io import pathlib -from .utils import terminal_width, append_message, message_to_chat, ConfigType -from typing import List, Dict, Any, Optional +from .utils import terminal_width, append_message, message_to_chat, ConfigType, ChatType +from typing import Any, Optional -def read_file(fname: pathlib.Path, tags_only: bool = False) -> Dict[str, Any]: +def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: with open(fname, "r") as fd: tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() # also support tags separated by ',' (old format) @@ -33,7 +33,7 @@ def write_config(path: str, config: ConfigType) -> None: yaml.dump(config, f) -def dump_data(data: Dict[str, Any]) -> str: +def dump_data(data: dict[str, Any]) -> str: with io.StringIO() as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') @@ -41,7 +41,7 @@ def dump_data(data: Dict[str, Any]) -> str: return fd.getvalue() -def write_file(fname: str, data: Dict[str, Any]) -> None: +def write_file(fname: str, data: dict[str, Any]) -> None: with open(fname, "w") as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') fd.write(f'=== QUESTION ===\n{data["question"]}\n') @@ -75,14 +75,14 @@ def save_answers(question: str, def create_chat_hist(question: Optional[str], - tags: Optional[List[str]], - extags: Optional[List[str]], + tags: Optional[list[str]], + extags: Optional[list[str]], config: ConfigType, match_all_tags: bool = False, with_tags: bool = False, with_file: bool = False - ) -> List[Dict[str, str]]: - chat: List[Dict[str, str]] = [] + ) -> ChatType: + chat: ChatType = [] append_message(chat, 'system', str(config['system']).strip()) for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -108,7 +108,7 @@ def create_chat_hist(question: Optional[str], return chat -def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: +def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]: result = [] for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -127,5 +127,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> List[str]: return result -def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> List[str]: +def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> list[str]: return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index fba8296..c6d527c 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -1,14 +1,16 @@ import shutil from pprint import PrettyPrinter +from typing import Any ConfigType = dict[str, str | dict[str, str | int | float]] +ChatType = list[dict[str, str]] def terminal_width() -> int: return shutil.get_terminal_size().columns -def pp(*args, **kwargs) -> None: +def pp(*args: Any, **kwargs: Any) -> None: return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) @@ -30,7 +32,7 @@ def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None print() -def append_message(chat: list[dict[str, str]], +def append_message(chat: ChatType, role: str, content: str ) -> None: @@ -38,7 +40,7 @@ def append_message(chat: list[dict[str, str]], def message_to_chat(message: dict[str, str], - chat: list[dict[str, str]], + chat: ChatType, with_tags: bool = False, with_file: bool = False ) -> None: @@ -61,7 +63,7 @@ def display_source_code(content: str) -> None: pass -def print_chat_hist(chat, dump=False, source_code=False) -> None: +def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: if dump: pp(chat) return -- 2.36.6 From dc13213c4dd118848d395a009ad98bd2277a0a0a Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 08:14:41 +0200 Subject: [PATCH 033/170] configuration is now a TypedDict in its own module --- chatmastermind/api_client.py | 5 +++-- chatmastermind/configuration.py | 23 +++++++++++++++++++++++ chatmastermind/main.py | 25 +++++++++---------------- chatmastermind/storage.py | 15 ++++++++------- chatmastermind/utils.py | 1 - 5 files changed, 43 insertions(+), 26 deletions(-) create mode 100644 chatmastermind/configuration.py diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index d3282eb..d8634bd 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -1,6 +1,7 @@ import openai -from .utils import ConfigType, ChatType +from .utils import ChatType +from .configuration import Config def openai_api_key(api_key: str) -> None: @@ -22,7 +23,7 @@ def print_models() -> None: def ai(chat: ChatType, - config: ConfigType, + config: Config, number: int ) -> tuple[list[str], dict[str, int]]: """ diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py new file mode 100644 index 0000000..2917865 --- /dev/null +++ b/chatmastermind/configuration.py @@ -0,0 +1,23 @@ +from typing import TypedDict + + +class OpenAIConfig(TypedDict): + """ + The OpenAI section of the configuration file. + """ + api_key: str + model: str + temperature: float + max_tokens: int + top_p: float + frequency_penalty: float + presence_penalty: float + + +class Config(TypedDict): + """ + The configuration file structure. + """ + system: str + db: str + openai: OpenAIConfig diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ec33cb3..7c6df33 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,9 +7,10 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ConfigType, ChatType +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data from .api_client import ai, openai_api_key, print_models +from .configuration import Config from itertools import zip_longest from typing import Any @@ -23,7 +24,7 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: def create_question_with_hist(args: argparse.Namespace, - config: ConfigType, + config: Config, ) -> tuple[ChatType, str, list[str]]: """ Creates the "AI request", including the question and chat history as determined @@ -56,7 +57,7 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def tag_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tag' command. """ @@ -64,13 +65,10 @@ def tag_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_tags_frequency(get_tags(config, None)) -def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def config_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'config' command. """ - if not isinstance(config['openai'], dict): - raise RuntimeError('Configuration openai is not a dict.') - if args.list_models: print_models() elif args.print_model: @@ -80,12 +78,10 @@ def config_cmd(args: argparse.Namespace, config: ConfigType) -> None: write_config(args.config, config) -def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. """ - if not isinstance(config['openai'], dict): - raise RuntimeError('Configuration openai is not a dict.') if args.max_tokens: config['openai']['max_tokens'] = args.max_tokens if args.temperature: @@ -101,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: ConfigType) -> None: print(f"Usage: {usage}") -def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. """ @@ -115,7 +111,7 @@ def hist_cmd(args: argparse.Namespace, config: ConfigType) -> None: print_chat_hist(chat, args.dump, args.only_source_code) -def print_cmd(args: argparse.Namespace, config: ConfigType) -> None: +def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. """ @@ -231,10 +227,7 @@ def main() -> int: command = parser.parse_args() config = read_config(args.config) - if type(config['openai']) is dict and type(config['openai']['api_key']) is str: - openai_api_key(config['openai']['api_key']) - else: - raise RuntimeError("Configuration openai.api_key is wrong.") + openai_api_key(config['openai']['api_key']) command.func(command, config) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index fa3fb14..ca8ae32 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,7 +1,8 @@ import yaml import io import pathlib -from .utils import terminal_width, append_message, message_to_chat, ConfigType, ChatType +from .utils import terminal_width, append_message, message_to_chat, ChatType +from .configuration import Config from typing import Any, Optional @@ -22,13 +23,13 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: "file": fname.name} -def read_config(path: str) -> ConfigType: +def read_config(path: str) -> Config: with open(path, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) return config -def write_config(path: str, config: ConfigType) -> None: +def write_config(path: str, config: Config) -> None: with open(path, 'w') as f: yaml.dump(config, f) @@ -52,7 +53,7 @@ def save_answers(question: str, answers: list[str], tags: list[str], otags: Optional[list[str]], - config: ConfigType + config: Config ) -> None: wtags = otags or tags num, inum = 0, 0 @@ -77,7 +78,7 @@ def save_answers(question: str, def create_chat_hist(question: Optional[str], tags: Optional[list[str]], extags: Optional[list[str]], - config: ConfigType, + config: Config, match_all_tags: bool = False, with_tags: bool = False, with_file: bool = False @@ -108,7 +109,7 @@ def create_chat_hist(question: Optional[str], return chat -def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]: +def get_tags(config: Config, prefix: Optional[str]) -> list[str]: result = [] for file in sorted(pathlib.Path(str(config['db'])).iterdir()): if file.suffix == '.yaml': @@ -127,5 +128,5 @@ def get_tags(config: ConfigType, prefix: Optional[str]) -> list[str]: return result -def get_tags_unique(config: ConfigType, prefix: Optional[str]) -> list[str]: +def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index c6d527c..bd80e4f 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -2,7 +2,6 @@ import shutil from pprint import PrettyPrinter from typing import Any -ConfigType = dict[str, str | dict[str, str | int | float]] ChatType = list[dict[str, str]] -- 2.36.6 From ee8deed320c1c260f51c0c6a6c26f7f419264845 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 08:39:15 +0200 Subject: [PATCH 034/170] configuration: added validation --- chatmastermind/configuration.py | 42 ++++++++++++++++++++++++++++++++- chatmastermind/storage.py | 5 +++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 2917865..bc58574 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,4 +1,5 @@ -from typing import TypedDict +import pathlib +from typing import TypedDict, Any class OpenAIConfig(TypedDict): @@ -14,6 +15,25 @@ class OpenAIConfig(TypedDict): presence_penalty: float +def openai_config_valid(conf: dict[str, str | float | int]) -> bool: + """ + Checks if the given Open AI configuration dict is complete + and contains valid types and values. + """ + try: + str(conf['api_key']) + str(conf['model']) + int(conf['max_tokens']) + float(conf['temperature']) + float(conf['top_p']) + float(conf['frequency_penalty']) + float(conf['presence_penalty']) + return True + except Exception as e: + print(f"OpenAI configuration is invalid: {e}") + return False + + class Config(TypedDict): """ The configuration file structure. @@ -21,3 +41,23 @@ class Config(TypedDict): system: str db: str openai: OpenAIConfig + + +def config_valid(conf: dict[str, Any]) -> bool: + """ + Checks if the given configuration dict is complete + and contains valid types and values. + """ + try: + str(conf['system']) + pathlib.Path(str(conf['db'])) + return True + except Exception as e: + print(f"Configuration is invalid: {e}") + return False + if 'openai' in conf: + return openai_config_valid(conf['openai']) + else: + # required as long as we only support OpenAI + print("Section 'openai' is missing in the configuration!") + return False diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index ca8ae32..a4648b0 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,8 +1,9 @@ import yaml +import sys import io import pathlib from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config +from .configuration import Config, config_valid from typing import Any, Optional @@ -26,6 +27,8 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: def read_config(path: str) -> Config: with open(path, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) + if not config_valid(config): + sys.exit(1) return config -- 2.36.6 From e8343fde013303da6b84026da39cc0032a7385c7 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 11:15:14 +0200 Subject: [PATCH 035/170] test_main: added type annotations and a helper class / function --- tests/test_main.py | 88 +++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/tests/test_main.py b/tests/test_main.py index 3634740..4a70cbb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,25 +5,46 @@ import argparse from chatmastermind.utils import terminal_width from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai +from chatmastermind.configuration import Config, OpenAIConfig from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY -class TestCreateChat(unittest.TestCase): +class CmmTestCase(unittest.TestCase): + """ + Base class for all cmm testcases. + """ + def dummy_config(self, db: str) -> Config: + """ + Creates a dummy configuration. + """ + return Config( + system='dummy_system', + db=db, + openai=OpenAIConfig( + api_key='dummy_key', + model='dummy_model', + max_tokens=4000, + temperature=1.0, + top_p=1, + frequency_penalty=0, + presence_penalty=0 + ) + ) - def setUp(self): - self.config = { - 'system': 'System text', - 'db': 'test_files' - } + +class TestCreateChat(CmmTestCase): + + def setUp(self) -> None: + self.config = self.dummy_config(db='test_files') self.question = "test question" self.tags = ['test_tag'] @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( @@ -45,7 +66,7 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( @@ -63,7 +84,7 @@ class TestCreateChat(unittest.TestCase): @patch('os.listdir') @patch('pathlib.Path.iterdir') @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock, iterdir_mock, listdir_mock): + def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] open_mock.side_effect = ( @@ -90,9 +111,9 @@ class TestCreateChat(unittest.TestCase): {'role': 'assistant', 'content': 'some answer2'}) -class TestHandleQuestion(unittest.TestCase): +class TestHandleQuestion(CmmTestCase): - def setUp(self): + def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], @@ -109,12 +130,7 @@ class TestHandleQuestion(unittest.TestCase): with_tags=False, with_file=False, ) - self.config = { - 'db': 'test_files', - 'setting1': 'value1', - 'setting2': 'value2', - 'openai': {}, - } + self.config = self.dummy_config(db='test_files') @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") @patch("chatmastermind.main.print_tag_args") @@ -122,9 +138,9 @@ class TestHandleQuestion(unittest.TestCase): @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) @patch("chatmastermind.utils.pp") @patch("builtins.print") - def test_ask_cmd(self, mock_print, mock_pp, mock_ai, - mock_print_chat_hist, mock_print_tag_args, - mock_create_chat_hist): + def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, + mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, + mock_create_chat_hist: MagicMock) -> None: open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) @@ -155,15 +171,15 @@ class TestHandleQuestion(unittest.TestCase): open_mock.assert_has_calls(open_expected_calls, any_order=True) -class TestSaveAnswers(unittest.TestCase): +class TestSaveAnswers(CmmTestCase): @mock.patch('builtins.open') @mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock, open_mock): + def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: question = "Test question?" answers = ["Answer 1", "Answer 2"] tags = ["tag1", "tag2"] otags = ["otag1", "otag2"] - config = {'db': 'test_db'} + config = self.dummy_config(db='test_db') with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ mock.patch('chatmastermind.storage.yaml.dump'), \ @@ -179,10 +195,10 @@ class TestSaveAnswers(unittest.TestCase): open_mock.assert_has_calls(open_calls, any_order=True) -class TestAI(unittest.TestCase): +class TestAI(CmmTestCase): @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock): + def test_ai(self, mock_create: MagicMock) -> None: mock_create.return_value = { 'choices': [ {'message': {'content': 'response_text_1'}}, @@ -191,28 +207,20 @@ class TestAI(unittest.TestCase): 'usage': {'tokens': 10} } - number = 2 chat = [{"role": "system", "content": "hello ai"}] - config = { - "openai": { - "model": "text-davinci-002", - "temperature": 0.5, - "max_tokens": 150, - "top_p": 1, - "n": number, - "frequency_penalty": 0, - "presence_penalty": 0 - } - } + config = self.dummy_config(db='dummy') + config['openai']['model'] = "text-davinci-002" + config['openai']['max_tokens'] = 150 + config['openai']['temperature'] = 0.5 - result = ai(chat, config, number) + result = ai(chat, config, 2) expected_result = (['response_text_1', 'response_text_2'], {'tokens': 10}) self.assertEqual(result, expected_result) -class TestCreateParser(unittest.TestCase): - def test_create_parser(self): +class TestCreateParser(CmmTestCase): + def test_create_parser(self) -> None: with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: mock_cmdparser = Mock() mock_add_subparsers.return_value = mock_cmdparser -- 2.36.6 From 380b7c1b672937c6769b89477f584b9ea088491b Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Wed, 16 Aug 2023 12:24:03 +0200 Subject: [PATCH 036/170] Python 3.9 compatibility. --- chatmastermind/configuration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index bc58574..9cb7885 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,5 +1,5 @@ import pathlib -from typing import TypedDict, Any +from typing import TypedDict, Any, Union class OpenAIConfig(TypedDict): @@ -15,7 +15,7 @@ class OpenAIConfig(TypedDict): presence_penalty: float -def openai_config_valid(conf: dict[str, str | float | int]) -> bool: +def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool: """ Checks if the given Open AI configuration dict is complete and contains valid types and values. -- 2.36.6 From a5c91adc4138bc5163d6665a0e35a7c42b835da9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 23:22:20 +0200 Subject: [PATCH 037/170] configuration: minor improvements / fixes Could not extend the subclass of 'TypedDict' the way I wanted, so I switched to 'dataclass'. --- chatmastermind/api_client.py | 14 +++--- chatmastermind/configuration.py | 83 +++++++++++++++++---------------- chatmastermind/main.py | 18 +++---- chatmastermind/storage.py | 24 ++-------- tests/test_main.py | 36 +++++++------- 5 files changed, 80 insertions(+), 95 deletions(-) diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py index d8634bd..2c4a094 100644 --- a/chatmastermind/api_client.py +++ b/chatmastermind/api_client.py @@ -30,17 +30,15 @@ def ai(chat: ChatType, Make AI request with the given chat history and configuration. Return AI response and tokens used. """ - if not isinstance(config['openai'], dict): - raise RuntimeError('Configuration openai is not a dict.') response = openai.ChatCompletion.create( - model=config['openai']['model'], + model=config.openai.model, messages=chat, - temperature=config['openai']['temperature'], - max_tokens=config['openai']['max_tokens'], - top_p=config['openai']['top_p'], + temperature=config.openai.temperature, + max_tokens=config.openai.max_tokens, + top_p=config.openai.top_p, n=number, - frequency_penalty=config['openai']['frequency_penalty'], - presence_penalty=config['openai']['presence_penalty']) + frequency_penalty=config.openai.frequency_penalty, + presence_penalty=config.openai.presence_penalty) result = [] for choice in response['choices']: # type: ignore result.append(choice['message']['content'].strip()) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 9cb7885..0037916 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,8 +1,13 @@ -import pathlib -from typing import TypedDict, Any, Union +import yaml +from typing import Type, TypeVar, Any +from dataclasses import dataclass, asdict + +ConfigInst = TypeVar('ConfigInst', bound='Config') +OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') -class OpenAIConfig(TypedDict): +@dataclass +class OpenAIConfig(): """ The OpenAI section of the configuration file. """ @@ -14,27 +19,24 @@ class OpenAIConfig(TypedDict): frequency_penalty: float presence_penalty: float - -def openai_config_valid(conf: dict[str, Union[str, float, int]]) -> bool: - """ - Checks if the given Open AI configuration dict is complete - and contains valid types and values. - """ - try: - str(conf['api_key']) - str(conf['model']) - int(conf['max_tokens']) - float(conf['temperature']) - float(conf['top_p']) - float(conf['frequency_penalty']) - float(conf['presence_penalty']) - return True - except Exception as e: - print(f"OpenAI configuration is invalid: {e}") - return False + @classmethod + def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: + """ + Create OpenAIConfig from a dict. + """ + return cls( + api_key=str(source['api_key']), + model=str(source['model']), + max_tokens=int(source['max_tokens']), + temperature=float(source['temperature']), + top_p=float(source['top_p']), + frequency_penalty=float(source['frequency_penalty']), + presence_penalty=float(source['presence_penalty']) + ) -class Config(TypedDict): +@dataclass +class Config(): """ The configuration file structure. """ @@ -42,22 +44,23 @@ class Config(TypedDict): db: str openai: OpenAIConfig + @classmethod + def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: + """ + Create OpenAIConfig from a dict. + """ + return cls( + system=str(source['system']), + db=str(source['db']), + openai=OpenAIConfig.from_dict(source['openai']) + ) -def config_valid(conf: dict[str, Any]) -> bool: - """ - Checks if the given configuration dict is complete - and contains valid types and values. - """ - try: - str(conf['system']) - pathlib.Path(str(conf['db'])) - return True - except Exception as e: - print(f"Configuration is invalid: {e}") - return False - if 'openai' in conf: - return openai_config_valid(conf['openai']) - else: - # required as long as we only support OpenAI - print("Section 'openai' is missing in the configuration!") - return False + @classmethod + def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: + with open(path, 'r') as f: + source = yaml.load(f, Loader=yaml.FullLoader) + return cls.from_dict(source) + + def to_file(self, path: str) -> None: + with open(path, 'w') as f: + yaml.dump(asdict(self), f) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 7c6df33..7866179 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -8,7 +8,7 @@ import argcomplete import argparse import pathlib from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, read_config, write_config, dump_data +from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config from itertools import zip_longest @@ -72,10 +72,10 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None: if args.list_models: print_models() elif args.print_model: - print(config['openai']['model']) + print(config.openai.model) elif args.model: - config['openai']['model'] = args.model - write_config(args.config, config) + config.openai.model = args.model + config.to_file(args.config) def ask_cmd(args: argparse.Namespace, config: Config) -> None: @@ -83,11 +83,11 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'ask' command. """ if args.max_tokens: - config['openai']['max_tokens'] = args.max_tokens + config.openai.max_tokens = args.max_tokens if args.temperature: - config['openai']['temperature'] = args.temperature + config.openai.temperature = args.temperature if args.model: - config['openai']['model'] = args.model + config.openai.model = args.model chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.only_source_code) otags = args.output_tags or [] @@ -225,9 +225,9 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = read_config(args.config) + config = Config.from_file(args.config) - openai_api_key(config['openai']['api_key']) + openai_api_key(config.openai.api_key) command.func(command, config) diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index a4648b0..8b9ed97 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -1,9 +1,8 @@ import yaml -import sys import io import pathlib from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config, config_valid +from .configuration import Config from typing import Any, Optional @@ -24,19 +23,6 @@ def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: "file": fname.name} -def read_config(path: str) -> Config: - with open(path, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - if not config_valid(config): - sys.exit(1) - return config - - -def write_config(path: str, config: Config) -> None: - with open(path, 'w') as f: - yaml.dump(config, f) - - def dump_data(data: dict[str, Any]) -> str: with io.StringIO() as fd: fd.write(f'TAGS: {" ".join(data["tags"])}\n') @@ -60,7 +46,7 @@ def save_answers(question: str, ) -> None: wtags = otags or tags num, inum = 0, 0 - next_fname = pathlib.Path(str(config['db'])) / '.next' + next_fname = pathlib.Path(str(config.db)) / '.next' try: with open(next_fname, 'r') as f: num = int(f.read()) @@ -87,8 +73,8 @@ def create_chat_hist(question: Optional[str], with_file: bool = False ) -> ChatType: chat: ChatType = [] - append_message(chat, 'system', str(config['system']).strip()) - for file in sorted(pathlib.Path(str(config['db'])).iterdir()): + append_message(chat, 'system', str(config.system).strip()) + for file in sorted(pathlib.Path(str(config.db)).iterdir()): if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) @@ -114,7 +100,7 @@ def create_chat_hist(question: Optional[str], def get_tags(config: Config, prefix: Optional[str]) -> list[str]: result = [] - for file in sorted(pathlib.Path(str(config['db'])).iterdir()): + for file in sorted(pathlib.Path(str(config.db)).iterdir()): if file.suffix == '.yaml': with open(file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) diff --git a/tests/test_main.py b/tests/test_main.py index 4a70cbb..db5fcdb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -5,7 +5,7 @@ import argparse from chatmastermind.utils import terminal_width from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai -from chatmastermind.configuration import Config, OpenAIConfig +from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -19,18 +19,16 @@ class CmmTestCase(unittest.TestCase): """ Creates a dummy configuration. """ - return Config( - system='dummy_system', - db=db, - openai=OpenAIConfig( - api_key='dummy_key', - model='dummy_model', - max_tokens=4000, - temperature=1.0, - top_p=1, - frequency_penalty=0, - presence_penalty=0 - ) + return Config.from_dict( + {'system': 'dummy_system', + 'db': db, + 'openai': {'api_key': 'dummy_key', + 'model': 'dummy_model', + 'max_tokens': 4000, + 'temperature': 1.0, + 'top_p': 1, + 'frequency_penalty': 0, + 'presence_penalty': 0}} ) @@ -55,7 +53,7 @@ class TestCreateChat(CmmTestCase): self.assertEqual(len(test_chat), 4) self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config['system']}) + {'role': 'system', 'content': self.config.system}) self.assertEqual(test_chat[1], {'role': 'user', 'content': 'test_content'}) self.assertEqual(test_chat[2], @@ -77,7 +75,7 @@ class TestCreateChat(CmmTestCase): self.assertEqual(len(test_chat), 2) self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config['system']}) + {'role': 'system', 'content': self.config.system}) self.assertEqual(test_chat[1], {'role': 'user', 'content': self.question}) @@ -100,7 +98,7 @@ class TestCreateChat(CmmTestCase): self.assertEqual(len(test_chat), 6) self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config['system']}) + {'role': 'system', 'content': self.config.system}) self.assertEqual(test_chat[1], {'role': 'user', 'content': 'test_content'}) self.assertEqual(test_chat[2], @@ -209,9 +207,9 @@ class TestAI(CmmTestCase): chat = [{"role": "system", "content": "hello ai"}] config = self.dummy_config(db='dummy') - config['openai']['model'] = "text-davinci-002" - config['openai']['max_tokens'] = 150 - config['openai']['temperature'] = 0.5 + config.openai.model = "text-davinci-002" + config.openai.max_tokens = 150 + config.openai.temperature = 0.5 result = ai(chat, config, 2) expected_result = (['response_text_1', 'response_text_2'], -- 2.36.6 From b13a68836a4ed49f3777ad4c8cf7038a776bcb3e Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 17:07:01 +0200 Subject: [PATCH 038/170] added new module 'tags.py' with classes 'Tag' and 'TagLine' --- chatmastermind/tags.py | 130 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 chatmastermind/tags.py diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py new file mode 100644 index 0000000..28583a2 --- /dev/null +++ b/chatmastermind/tags.py @@ -0,0 +1,130 @@ +""" +Module implementing tag related functions and classes. +""" +from typing import Type, TypeVar, Optional + +TagInst = TypeVar('TagInst', bound='Tag') +TagLineInst = TypeVar('TagLineInst', bound='TagLine') + + +class TagError(Exception): + pass + + +class Tag(str): + """ + A single tag. A string that can contain anything but the default separator (' '). + """ + # default separator + default_separator = ' ' + # alternative separators (e. g. for backwards compatibility) + alternative_separators = [','] + + def __new__(cls: Type[TagInst], string: str) -> TagInst: + """ + Make sure the tag string does not contain the default separator. + """ + if cls.default_separator in string: + raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'") + instance = super().__new__(cls, string) + return instance + + +class TagLine(str): + """ + A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, + separated by the defaut separator (' '). Any operations on a TagLine will sort + the tags. + """ + # the prefix + prefix = 'TAGS:' + + def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: + """ + Make sure the tagline string starts with the prefix. + """ + if not string.startswith(cls.prefix): + raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst: + """ + Create a new TagLine from a set of tags. + """ + return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + + def tags(self) -> set[Tag]: + """ + Returns all tags contained in this line as a set. + """ + tagstr = self[len(self.prefix):].strip() + separator = Tag.default_separator + # look for alternative separators and use the first one found + # -> we don't support different separators in the same TagLine + for s in Tag.alternative_separators: + if s in tagstr: + separator = s + break + return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + + def merge(self, taglines: set['TagLine']) -> 'TagLine': + """ + Merges the tags of all given taglines into the current one + and returns a new TagLine. + """ + merged_tags = self.tags() + for tl in taglines: + merged_tags |= tl.tags() + return self.from_set(set(sorted(merged_tags))) + + def delete_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Deletes the given tags and returns a new TagLine. + """ + return self.from_set(self.tags().difference(tags)) + + def add_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Adds the given tags and returns a new TagLine. + """ + return self.from_set(set(sorted(self.tags() | tags))) + + def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + """ + Renames the given tags and returns a new TagLine. The first + tuple element is the old name, the second one is the new name. + """ + new_tags = self.tags() + for t in tags: + if t[0] in new_tags: + new_tags.remove(t[0]) + new_tags.add(t[1]) + return self.from_set(set(sorted(new_tags))) + + def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the current TagLine matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + tag_set = self.tags() + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing -- 2.36.6 From ef46f5efc942551b0ccbd37b9807eb983bcdb628 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 17 Aug 2023 08:28:15 +0200 Subject: [PATCH 039/170] added testcases for Tag and TagLine classes --- tests/test_main.py | 114 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index db5fcdb..eb13dc5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,6 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data +from chatmastermind.tags import Tag, TagLine, TagError from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -231,3 +232,116 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From 604e5ccf73e2d3aafc45a48128317f5462bd5348 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 12:11:56 +0200 Subject: [PATCH 040/170] tags.py: converted most TagLine functions to module functions --- chatmastermind/tags.py | 99 ++++++++++++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 28 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 28583a2..bfe5fd5 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -30,6 +30,67 @@ class Tag(str): return instance +def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]: + """ + Deletes the given tags and returns a new set. + """ + return tags.difference(tags_delete) + + +def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]: + """ + Adds the given tags and returns a new set. + """ + return set(sorted(tags | tags_add)) + + +def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]: + """ + Merges the tags in 'tags_merge' into the current one and returns a new set. + """ + for ts in tags_merge: + tags |= ts + return tags + + +def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]: + """ + Renames the given tags and returns a new set. The first tuple element + is the old name, the second one is the new name. + """ + for t in tags_rename: + if t[0] in tags: + tags.remove(t[0]) + tags.add(t[1]) + return set(sorted(tags)) + + +def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the given set 'tags' matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tags for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing + + class TagLine(str): """ A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, @@ -71,37 +132,29 @@ class TagLine(str): def merge(self, taglines: set['TagLine']) -> 'TagLine': """ - Merges the tags of all given taglines into the current one - and returns a new TagLine. + Merges the tags of all given taglines into the current one and returns a new TagLine. """ - merged_tags = self.tags() - for tl in taglines: - merged_tags |= tl.tags() - return self.from_set(set(sorted(merged_tags))) + tags_merge = [tl.tags() for tl in taglines] + return self.from_set(merge_tags(self.tags(), tags_merge)) - def delete_tags(self, tags: set[Tag]) -> 'TagLine': + def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine': """ Deletes the given tags and returns a new TagLine. """ - return self.from_set(self.tags().difference(tags)) + return self.from_set(delete_tags(self.tags(), tags_delete)) - def add_tags(self, tags: set[Tag]) -> 'TagLine': + def add_tags(self, tags_add: set[Tag]) -> 'TagLine': """ Adds the given tags and returns a new TagLine. """ - return self.from_set(set(sorted(self.tags() | tags))) + return self.from_set(add_tags(self.tags(), tags_add)) - def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine': """ Renames the given tags and returns a new TagLine. The first tuple element is the old name, the second one is the new name. """ - new_tags = self.tags() - for t in tags: - if t[0] in new_tags: - new_tags.remove(t[0]) - new_tags.add(t[1]) - return self.from_set(set(sorted(new_tags))) + return self.from_set(rename_tags(self.tags(), tags_rename)) def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], tags_not: Optional[set[Tag]]) -> bool: @@ -117,14 +170,4 @@ class TagLine(str): 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag exclusion is still done if 'tags_not' is not 'None'). """ - tag_set = self.tags() - required_tags_present = False - excluded_tags_missing = False - if ((tags_or is None and tags_and is None) - or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 - or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 - required_tags_present = True - if ((tags_not is None) - or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 - excluded_tags_missing = True - return required_tags_present and excluded_tags_missing + return match_tags(self.tags(), tags_or, tags_and, tags_not) -- 2.36.6 From 173a46a9b52b1ff1f30d5f9acb27538daaa9379a Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:07:50 +0200 Subject: [PATCH 041/170] added new module 'message.py' --- chatmastermind/message.py | 430 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 chatmastermind/message.py diff --git a/chatmastermind/message.py b/chatmastermind/message.py new file mode 100644 index 0000000..157cd46 --- /dev/null +++ b/chatmastermind/message.py @@ -0,0 +1,430 @@ +""" +Module implementing message related functions and classes. +""" +import pathlib +import yaml +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from dataclasses import dataclass, asdict, field +from .tags import Tag, TagLine, TagError, match_tags + +QuestionInst = TypeVar('QuestionInst', bound='Question') +AnswerInst = TypeVar('AnswerInst', bound='Answer') +MessageInst = TypeVar('MessageInst', bound='Message') +AILineInst = TypeVar('AILineInst', bound='AILine') +ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') +YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] + + +class MessageError(Exception): + pass + + +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + +def source_code(text: str, include_delims: bool = False) -> list[str]: + """ + Extract all source code sections from the given text, i. e. all lines + surrounded by lines tarting with '```'. If 'include_delims' is True, + the surrounding lines are included, otherwise they are omitted. The + result list contains every source code section as a single string. + The order in the list represents the order of the sections in the text. + """ + code_sections: list[str] = [] + code_lines: list[str] = [] + in_code_block = False + + for line in text.split('\n'): + if line.strip().startswith('```'): + if include_delims: + code_lines.append(line) + if in_code_block: + code_sections.append('\n'.join(code_lines) + '\n') + code_lines.clear() + in_code_block = not in_code_block + elif in_code_block: + code_lines.append(line) + + return code_sections + + +@dataclass(kw_only=True) +class MessageFilter: + """ + Various filters for a Message. + """ + tags_or: Optional[set[Tag]] = None + tags_and: Optional[set[Tag]] = None + tags_not: Optional[set[Tag]] = None + ai: Optional[str] = None + model: Optional[str] = None + question_contains: Optional[str] = None + answer_contains: Optional[str] = None + answer_state: Optional[Literal['available', 'missing']] = None + ai_state: Optional[Literal['available', 'missing']] = None + model_state: Optional[Literal['available', 'missing']] = None + + +class AILine(str): + """ + A line that represents the AI name in a '.txt' file.. + """ + prefix: Final[str] = 'AI:' + + def __new__(cls: Type[AILineInst], string: str) -> AILineInst: + if not string.startswith(cls.prefix): + raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def ai(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst: + return cls(' '.join([cls.prefix, ai])) + + +class ModelLine(str): + """ + A line that represents the model name in a '.txt' file.. + """ + prefix: Final[str] = 'MODEL:' + + def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: + if not string.startswith(cls.prefix): + raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def model(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst: + return cls(' '.join([cls.prefix, model])) + + +class Question(str): + """ + A single question with a defined header. + """ + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' + + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + """ + Make sure the question string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +class Answer(str): + """ + A single answer with a defined header. + """ + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' + + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + """ + Make sure the answer string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +@dataclass +class Message(): + """ + Single message. Consists of a question and optionally an answer, a set of tags + and a file path. + """ + question: Question + answer: Optional[Answer] = None + # metadata, ignored when comparing messages + tags: Optional[set[Tag]] = field(default=None, compare=False) + ai: Optional[str] = field(default=None, compare=False) + model: Optional[str] = field(default=None, compare=False) + file_path: Optional[pathlib.Path] = field(default=None, compare=False) + # class variables + file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] + tags_yaml_key: ClassVar[str] = 'tags' + file_yaml_key: ClassVar[str] = 'file_path' + ai_yaml_key: ClassVar[str] = 'ai' + model_yaml_key: ClassVar[str] = 'model' + + def __hash__(self) -> int: + """ + The hash value is computed based on immutable members. + """ + return hash((self.question, self.answer)) + + @classmethod + def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: + """ + Create a Message from the given dict. + """ + return cls(question=data[Question.yaml_key], + answer=data.get(Answer.yaml_key, None), + tags=set(data.get(cls.tags_yaml_key, [])), + ai=data.get(cls.ai_yaml_key, None), + model=data.get(cls.model_yaml_key, None), + file_path=data.get(cls.file_yaml_key, None)) + + @classmethod + def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + """ + Return only the tags from the given Message file. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + if file_path.suffix == '.txt': + with open(file_path, "r") as fd: + tags = TagLine(fd.readline()).tags() + else: # '.yaml' + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + tags = set(sorted(data[cls.tags_yaml_key])) + return tags + + @classmethod + def from_file(cls: Type[MessageInst], file_path: pathlib.Path, + mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]: + """ + Create a Message from the given file. Returns 'None' if the message does + not fulfill the filter requirements. For TXT files, the tags are matched + before building the whole message. The other filters are applied afterwards. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + + if file_path.suffix == '.txt': + message = cls.__from_file_txt(file_path, + mfilter.tags_or if mfilter else None, + mfilter.tags_and if mfilter else None, + mfilter.tags_not if mfilter else None) + else: + message = cls.__from_file_yaml(file_path) + if message and (not mfilter or (mfilter and message.match(mfilter))): + return message + else: + return None + + @classmethod + def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 + tags_or: Optional[set[Tag]] = None, + tags_and: Optional[set[Tag]] = None, + tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]: + """ + Create a Message from the given TXT file. Expects the following file structures: + For '.txt': + * TagLine [Optional] + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header [Optional] + * Answer [Optional] + + Returns 'None' if the message does not fulfill the tag requirements. + """ + tags: set[Tag] = set() + question: Question + answer: Optional[Answer] = None + ai: Optional[str] = None + model: Optional[str] = None + with open(file_path, "r") as fd: + # TagLine (Optional) + try: + pos = fd.tell() + tags = TagLine(fd.readline()).tags() + except TagError: + fd.seek(pos) + if tags_or or tags_and or tags_not: + # match with an empty set if the file has no tags + if not match_tags(tags, tags_or, tags_and, tags_not): + return None + # AILine (Optional) + try: + pos = fd.tell() + ai = AILine(fd.readline()).ai() + except TagError: + fd.seek(pos) + # ModelLine (Optional) + try: + pos = fd.tell() + model = ModelLine(fd.readline()).model() + except TagError: + fd.seek(pos) + # Question and Answer + text = fd.read().strip().split('\n') + question_idx = text.index(Question.txt_header) + 1 + try: + answer_idx = text.index(Answer.txt_header) + question = Question.from_list(text[question_idx:answer_idx]) + answer = Answer.from_list(text[answer_idx + 1:]) + except ValueError: + question = Question.from_list(text[question_idx:]) + return cls(question, answer, tags, ai, model, file_path) + + @classmethod + def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst: + """ + Create a Message from the given YAML file. Expects the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string [Optional] + * Message.tags_yaml_key: list of strings [Optional] + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + data[cls.file_yaml_key] = file_path + return cls.from_dict(data) + + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 + """ + Write a Message to the given file. Type is determined based on the suffix. + Currently supported suffixes: ['.txt', '.yaml'] + """ + if file_path: + self.file_path = file_path + if not self.file_path: + raise MessageError("Got no valid path to write message") + if self.file_path.suffix not in self.file_suffixes: + raise MessageError(f"File type '{self.file_path.suffix}' is not supported") + # TXT + if self.file_path.suffix == '.txt': + return self.__to_file_txt(self.file_path) + elif self.file_path.suffix == '.yaml': + return self.__to_file_yaml(self.file_path) + + def __to_file_txt(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in TXT format. + Creates the following file structures: + * TagLine + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header + * Answer + """ + with open(file_path, "w") as fd: + if self.tags: + fd.write(f'{TagLine.from_set(self.tags)}\n') + if self.ai: + fd.write(f'{AILine.from_ai(self.ai)}\n') + if self.model: + fd.write(f'{ModelLine.from_model(self.model)}\n') + fd.write(f'{Question.txt_header}\n{self.question}\n') + if self.answer: + fd.write(f'{Answer.txt_header}\n{self.answer}\n') + + def __to_file_yaml(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in YAML format. + Creates the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string + * Message.tags_yaml_key: list of strings + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "w") as fd: + data: YamlDict = {Question.yaml_key: str(self.question)} + if self.answer: + data[Answer.yaml_key] = str(self.answer) + if self.ai: + data[self.ai_yaml_key] = self.ai + if self.model: + data[self.model_yaml_key] = self.model + if self.tags: + data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) + yaml.dump(data, fd, sort_keys=False) + + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 + """ + Matches the current Message to the given filter atttributes. + Return True if all attributes match, else False. + """ + mytags = self.tags or set() + if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 + or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 + or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 + or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 + or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 + or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 + or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 + or (mfilter.model_state == 'available' and not self.model) # noqa: W503 + or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503 + or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503 + or (mfilter.model_state == 'missing' and self.model)): # noqa: W503 + return False + return True + + def msg_id(self) -> str: + """ + Returns an ID that is unique throughout all messages in the same (DB) directory. + Currently this is the file name. The ID is also used for sorting messages. + """ + if self.file_path: + return self.file_path.name + else: + raise MessageError("Can't create file ID without a file path") + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From dfc12619319626757c6e776431a9581b32e4d984 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:08:22 +0200 Subject: [PATCH 042/170] added testcases for messages.py --- tests/test_main.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index eb13dc5..8ce06cb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,6 +8,7 @@ from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.tags import Tag, TagLine, TagError +from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -345,3 +346,79 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") -- 2.36.6 From 879831d7f50f6dd39a3571933453c0e8406ab3f9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:04:41 +0200 Subject: [PATCH 043/170] configuration: added 'as_dict()' as an instance function --- chatmastermind/configuration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0037916..5ae32d6 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -63,4 +63,7 @@ class Config(): def to_file(self, path: str) -> None: with open(path, 'w') as f: - yaml.dump(asdict(self), f) + yaml.dump(asdict(self), f, sort_keys=False) + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From 580c86e948bd5ac1b83209e8dbeafb4ebc6d7385 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:30:24 +0200 Subject: [PATCH 044/170] tags: TagLine constructor now supports multiline taglines and multiple spaces --- chatmastermind/tags.py | 20 +++++++++++--------- tests/test_main.py | 9 +++++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bfe5fd5..544270c 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -1,7 +1,7 @@ """ Module implementing tag related functions and classes. """ -from typing import Type, TypeVar, Optional +from typing import Type, TypeVar, Optional, Final TagInst = TypeVar('TagInst', bound='Tag') TagLineInst = TypeVar('TagLineInst', bound='TagLine') @@ -16,9 +16,9 @@ class Tag(str): A single tag. A string that can contain anything but the default separator (' '). """ # default separator - default_separator = ' ' + default_separator: Final[str] = ' ' # alternative separators (e. g. for backwards compatibility) - alternative_separators = [','] + alternative_separators: Final[list[str]] = [','] def __new__(cls: Type[TagInst], string: str) -> TagInst: """ @@ -93,19 +93,21 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s class TagLine(str): """ - A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, - separated by the defaut separator (' '). Any operations on a TagLine will sort - the tags. + A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by + a list of tags, separated by the defaut separator (' '). Any operations on a + TagLine will sort the tags. """ # the prefix - prefix = 'TAGS:' + prefix: Final[str] = 'TAGS:' def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: """ - Make sure the tagline string starts with the prefix. + Make sure the tagline string starts with the prefix. Also replace newlines + and multiple spaces with ' ', in order to support multiline TagLines. """ if not string.startswith(cls.prefix): raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + string = ' '.join(string.split()) instance = super().__new__(cls, string) return instance @@ -114,7 +116,7 @@ class TagLine(str): """ Create a new TagLine from a set of tags. """ - return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) def tags(self) -> set[Tag]: """ diff --git a/tests/test_main.py b/tests/test_main.py index 8ce06cb..25cdc37 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -256,6 +256,10 @@ class TestTagLine(CmmTestCase): tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_invalid_tagline(self) -> None: with self.assertRaises(TagError): TagLine('tag1 tag2') @@ -273,6 +277,11 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From 0d6a6dd6043651ef33b4072b8298ca19a5dd507d Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 21 Aug 2023 08:29:48 +0200 Subject: [PATCH 045/170] gitignore: added vim session file --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4ade1df..89bf5fb 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json .config.yaml db -noweb \ No newline at end of file +noweb +Session.vim -- 2.36.6 From aa89270876c622fbf7205133f3af99283c7ef472 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 08:46:03 +0200 Subject: [PATCH 046/170] tests: splitted 'test_main.py' into 3 modules --- tests/test_main.py | 200 ------------------------------------------ tests/test_message.py | 78 ++++++++++++++++ tests/test_tags.py | 124 ++++++++++++++++++++++++++ 3 files changed, 202 insertions(+), 200 deletions(-) create mode 100644 tests/test_message.py create mode 100644 tests/test_tags.py diff --git a/tests/test_main.py b/tests/test_main.py index 25cdc37..db5fcdb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,8 +7,6 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from chatmastermind.tags import Tag, TagLine, TagError -from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -233,201 +231,3 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) - - -class TestTag(CmmTestCase): - def test_valid_tag(self) -> None: - tag = Tag('mytag') - self.assertEqual(tag, 'mytag') - - def test_invalid_tag(self) -> None: - with self.assertRaises(TagError): - Tag('tag with space') - - def test_default_separator(self) -> None: - self.assertEqual(Tag.default_separator, ' ') - - def test_alternative_separators(self) -> None: - self.assertEqual(Tag.alternative_separators, [',']) - - -class TestTagLine(CmmTestCase): - def test_valid_tagline(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_valid_tagline_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_invalid_tagline(self) -> None: - with self.assertRaises(TagError): - TagLine('tag1 tag2') - - def test_prefix(self) -> None: - self.assertEqual(TagLine.prefix, 'TAGS:') - - def test_from_set(self) -> None: - tags = {Tag('tag1'), Tag('tag2')} - tagline = TagLine.from_set(tags) - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_tags_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_merge(self) -> None: - tagline1 = TagLine('TAGS: tag1 tag2') - tagline2 = TagLine('TAGS: tag2 tag3') - merged_tagline = tagline1.merge({tagline2}) - self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') - - def test_delete_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag2') - - def test_add_tags(self) -> None: - tagline = TagLine('TAGS: tag1') - new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') - - def test_rename_tags(self) -> None: - tagline = TagLine('TAGS: old1 old2') - new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) - self.assertEqual(new_tagline, 'TAGS: new1 new2') - - def test_match_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - - # Test case 1: Match any tag in 'tags_or' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and: set[Tag] = set() - tags_not: set[Tag] = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 2: Match all tags in 'tags_and' - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = {Tag('tag5')} - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 5: No matching tags in 'tags_or' - tags_or = {Tag('tag4'), Tag('tag5')} - tags_and = set() - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 6: Not all tags in 'tags_and' are present - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 7: Some tags in 'tags_not' are present - tags_or = {Tag('tag1')} - tags_and = set() - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 8: 'tags_or' and 'tags_and' are None, match all tags - tags_not = set() - self.assertTrue(tagline.match_tags(None, None, tags_not)) - - # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(None, None, tags_not)) - - -class SourceCodeTestCase(CmmTestCase): - def test_source_code_with_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - print(x + y) - ``` - """ - expected_result = [ - " ```python\n print(\"Hello, World!\")\n ```\n", - " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" - ] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_without_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - print(x + y) - ``` - """ - expected_result = [ - " print(\"Hello, World!\")\n", - " x = 10\n y = 20\n print(x + y)\n" - ] - result = source_code(text, include_delims=False) - self.assertEqual(result, expected_result) - - def test_source_code_with_single_code_block(self) -> None: - text = "```python\nprint(\"Hello, World!\")\n```" - expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_with_no_code_blocks(self) -> None: - text = "Some text without any code blocks" - expected_result: list[str] = [] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - -class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") - - def test_question_without_prefix(self) -> None: - question = Question("What is your favorite color?") - self.assertIsInstance(question, Question) - self.assertEqual(question, "What is your favorite color?") - - -class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") - - def test_answer_without_prefix(self) -> None: - answer = Answer("No") - self.assertIsInstance(answer, Answer) - self.assertEqual(answer, "No") diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..220fef2 --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,78 @@ +from .test_main import CmmTestCase +from chatmastermind.message import source_code, MessageError, Question, Answer + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") diff --git a/tests/test_tags.py b/tests/test_tags.py new file mode 100644 index 0000000..9ac9746 --- /dev/null +++ b/tests/test_tags.py @@ -0,0 +1,124 @@ +from .test_main import CmmTestCase +from chatmastermind.tags import Tag, TagLine, TagError + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From fc1b8006a0298bac1392756275a08f65e6be4db4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 19:59:38 +0200 Subject: [PATCH 047/170] tests: added testcases for Message.from/to_file() and others --- tests/test_message.py | 545 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 544 insertions(+), 1 deletion(-) diff --git a/tests/test_message.py b/tests/test_message.py index 220fef2..0e326b4 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,5 +1,9 @@ +import pathlib +import tempfile +from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, MessageError, Question, Answer +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.tags import Tag, TagLine class SourceCodeTestCase(CmmTestCase): @@ -76,3 +80,542 @@ class AnswerTestCase(CmmTestCase): answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") + + +class MessageToFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_txt_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""" + self.assertEqual(content, expected_content) + + def test_to_file_txt_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.txt_header} +This is a question. +""" + self.assertEqual(content, expected_content) + + def test_to_file_unsupported_file_type(self) -> None: + unsupported_file_path = pathlib.Path("example.doc") + with self.assertRaises(MessageError) as cm: + self.message_complete.to_file(unsupported_file_path) + self.assertEqual(str(cm.exception), "File type '.doc' is not supported") + + def test_to_file_no_file_path(self) -> None: + """ + Provoke an exception using an empty path. + """ + with self.assertRaises(MessageError) as cm: + # clear the internal file_path + self.message_complete.file_path = None + self.message_complete.to_file(None) + self.assertEqual(str(cm.exception), "Got no valid path to write message") + # reset the internal file_path + self.message_complete.file_path = self.file_path + + +class MessageToFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_multiline = Message(Question('This is a\nmultiline question.'), + Answer('This is a\nmultiline answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_yaml_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: This is a question. +{Answer.yaml_key}: This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_multiline(self) -> None: + self.message_multiline.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: |- + This is a + multiline question. +{Answer.yaml_key}: |- + This is a + multiline answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"{Question.yaml_key}: This is a question.\n" + self.assertEqual(content, expected_content) + + +class MessageFromFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_min.close() + self.file_path.unlink() + self.file_path_min.unlink() + + def test_from_file_txt_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_min(self) -> None: + """ + Read a message with only required values. + """ + message = Message.from_file(self.file_path_min) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_txt_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.txt") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_txt_question_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_txt_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_txt_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class MessageFromFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + self.file_min.close() + self.file_path_min.unlink() + + def test_from_file_yaml_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_min(self) -> None: + """ + Read a message with only the required values. + """ + message = Message.from_file(self.file_path_min) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.yaml") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_yaml_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_yaml_question_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_yaml_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_yaml_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class TagsFromFileTestCase(CmmTestCase): + def setUp(self) -> None: + self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt = pathlib.Path(self.file_txt.name) + with open(self.file_path_txt, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml = pathlib.Path(self.file_yaml.name) + with open(self.file_path_yaml, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + + def tearDown(self) -> None: + self.file_txt.close() + self.file_path_txt.unlink() + self.file_yaml.close() + self.file_path_yaml.unlink() + + def test_tags_from_file_txt(self) -> None: + tags = Message.tags_from_file(self.file_path_txt) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_from_file_yaml(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + +class MessageIDTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message = Message(Question('This is a question.'), + file_path=self.file_path) + self.message_no_file_path = Message(Question('This is a question.')) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_msg_id_txt(self) -> None: + self.assertEqual(self.message.msg_id(), self.file_path.name) + + def test_msg_id_txt_exception(self) -> None: + with self.assertRaises(MessageError): + self.message_no_file_path.msg_id() + + +class MessageHashTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a new question.'), + file_path=pathlib.Path('/tmp/foo/bla')) + self.message3 = Message(Question('This is a question.'), + Answer('This is an answer.'), + file_path=pathlib.Path('/tmp/foo/bla')) + # message4 is a copy of message1, because only question and + # answer are used for hashing and comparison + self.message4 = Message(Question('This is a question.'), + tags={Tag('tag1'), Tag('tag2')}, + ai='Blabla', + file_path=pathlib.Path('foobla')) + + def test_set_hashing(self) -> None: + msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4} + self.assertEqual(len(msgs), 3) + for msg in [self.message1, self.message2, self.message3]: + self.assertIn(msg, msgs) -- 2.36.6 From 7f91a2b567c721f96e11ffc0156a90d3f59a5032 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 26 Aug 2023 12:50:47 +0200 Subject: [PATCH 048/170] Added tags filtering (prefix and contained string) to TagLine and Message --- chatmastermind/message.py | 71 ++++++++++++++++++++++-- chatmastermind/tags.py | 12 +++- tests/test_message.py | 113 +++++++++++++++++++++++++++++++++++++- tests/test_tags.py | 22 +++++++- 4 files changed, 204 insertions(+), 14 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 157cd46..902aaa2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -219,21 +219,57 @@ class Message(): file_path=data.get(cls.file_yaml_key, None)) @classmethod - def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + def tags_from_file(cls: Type[MessageInst], + file_path: pathlib.Path, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: """ - Return only the tags from the given Message file. + Return only the tags from the given Message file, + optionally filtered based on prefix or contained string. """ + tags: set[Tag] = set() if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") if file_path.suffix not in cls.file_suffixes: raise MessageError(f"File type '{file_path.suffix}' is not supported") + # for TXT, it's enough to read the TagLine if file_path.suffix == '.txt': with open(file_path, "r") as fd: - tags = TagLine(fd.readline()).tags() + try: + tags = TagLine(fd.readline()).tags(prefix, contain) + except TagError: + pass # message without tags else: # '.yaml' - with open(file_path, "r") as fd: - data = yaml.load(fd, Loader=yaml.FullLoader) - tags = set(sorted(data[cls.tags_yaml_key])) + try: + message = cls.from_file(file_path) + if message: + msg_tags = message.filter_tags(prefix=prefix, contain=contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + if msg_tags: + tags = msg_tags + return tags + + @classmethod + def tags_from_dir(cls: Type[MessageInst], + path: pathlib.Path, + glob: Optional[str] = None, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: + + """ + Return only the tags from message files in the given directory. + The files can be filtered using 'glob', the tags by using 'prefix' + and 'contain'. + """ + tags: set[Tag] = set() + file_iter = path.glob(glob) if glob else path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + tags |= cls.tags_from_file(file_path, prefix, contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") return tags @classmethod @@ -395,6 +431,29 @@ class Message(): data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) yaml.dump(data, fd, sort_keys=False) + def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Filter tags based on their prefix (i. e. the tag starts with a given string) + or some contained string. + """ + res_tags = self.tags + if res_tags: + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() + + def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: + """ + Returns all tags as a string with the TagLine prefix. Optionally filtered + using 'Message.filter_tags()'. + """ + if self.tags: + return str(TagLine.from_set(self.filter_tags(prefix, contain))) + else: + return str(TagLine.from_set(set())) + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 """ Matches the current Message to the given filter atttributes. diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 544270c..c438db9 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -118,9 +118,10 @@ class TagLine(str): """ return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) - def tags(self) -> set[Tag]: + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ - Returns all tags contained in this line as a set. + Returns all tags contained in this line as a set, optionally + filtered based on prefix or contained string. """ tagstr = self[len(self.prefix):].strip() separator = Tag.default_separator @@ -130,7 +131,12 @@ class TagLine(str): if s in tagstr: separator = s break - return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() def merge(self, taglines: set['TagLine']) -> 'TagLine': """ diff --git a/tests/test_message.py b/tests/test_message.py index 0e326b4..7b8aee9 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -543,11 +543,19 @@ class TagsFromFileTestCase(CmmTestCase): self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) with open(self.file_path_txt, "w") as fd: - fd.write(f"""{TagLine.prefix} tag1 tag2 + fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 {Question.txt_header} This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) + with open(self.file_path_txt_no_tags, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -560,6 +568,16 @@ This is an answer. {Message.tags_yaml_key}: - tag1 - tag2 + - ptag3 +""") + self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) + with open(self.file_path_yaml_no_tags, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. """) def tearDown(self) -> None: @@ -570,11 +588,90 @@ This is an answer. def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_txt_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_no_tags) + self.assertSetEqual(tags, set()) def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_yaml_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml_no_tags) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, contain='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, contain='R') + self.assertSetEqual(tags, set()) + + +class TagsFromDirTestCase(CmmTestCase): + def setUp(self) -> None: + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_dir_no_tags = tempfile.TemporaryDirectory() + self.tag_sets = [ + {Tag('atag1'), Tag('atag2')}, + {Tag('btag3'), Tag('btag4')}, + {Tag('ctag5'), Tag('ctag6')} + ] + self.files = [ + pathlib.Path(self.temp_dir.name, 'file1.txt'), + pathlib.Path(self.temp_dir.name, 'file2.yaml'), + pathlib.Path(self.temp_dir.name, 'file3.txt') + ] + self.files_no_tags = [ + pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), + pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), + pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') + ] + for file, tags in zip(self.files, self.tag_sets): + message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags) + message.to_file(file) + for file in self.files_no_tags: + message = Message(Question('This is a question.'), + Answer('This is an answer.')) + message.to_file(file) + + def tearDown(self) -> None: + self.temp_dir.cleanup() + + def test_tags_from_dir(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) + expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2] + self.assertEqual(all_tags, expected_tags) + + def test_tags_from_dir_prefix(self) -> None: + atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a') + expected_tags = self.tag_sets[0] + self.assertEqual(atags, expected_tags) + + def test_tags_from_dir_no_tags(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name)) + self.assertSetEqual(all_tags, set()) class MessageIDTestCase(CmmTestCase): @@ -619,3 +716,13 @@ class MessageHashTestCase(CmmTestCase): self.assertEqual(len(msgs), 3) for msg in [self.message1, self.message2, self.message3]: self.assertIn(msg, msgs) + + +class MessageTagsStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_tags_str(self) -> None: + self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') diff --git a/tests/test_tags.py b/tests/test_tags.py index 9ac9746..bd2b685 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -40,15 +40,33 @@ class TestTagLine(CmmTestCase): self.assertEqual(tagline, 'TAGS: tag1 tag2') def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') + tagline = TagLine('TAGS: atag1 btag2') tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_prefix(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(prefix='a') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(prefix='s') + self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_contain(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(contain='t') + self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(contain='1') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(contain='R') + self.assertSetEqual(tags, set()) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From 169f1bb4585c495d8ce34856bf044af6da4bcc50 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 27 Aug 2023 18:07:38 +0200 Subject: [PATCH 049/170] fixed handling empty tags in TXT file --- chatmastermind/tags.py | 2 ++ tests/test_message.py | 13 +++++++++++++ tests/test_tags.py | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index c438db9..bb45a08 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -124,6 +124,8 @@ class TagLine(str): filtered based on prefix or contained string. """ tagstr = self[len(self.prefix):].strip() + if tagstr == '': + return set() # no tags, only prefix separator = Tag.default_separator # look for alternative separators and use the first one found # -> we don't support different separators in the same TagLine diff --git a/tests/test_message.py b/tests/test_message.py index 7b8aee9..9cfb30a 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -556,6 +556,15 @@ This is an answer. This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) + with open(self.file_path_txt_tags_empty, "w") as fd: + fd.write(f"""TAGS: +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -594,6 +603,10 @@ This is an answer. tags = Message.tags_from_file(self.file_path_txt_no_tags) self.assertSetEqual(tags, set()) + def test_tags_from_file_txt_tags_empty(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_tags_empty) + self.assertSetEqual(tags, set()) + def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) diff --git a/tests/test_tags.py b/tests/test_tags.py index bd2b685..eeab199 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -44,6 +44,10 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) + def test_tags_empty(self) -> None: + tagline = TagLine('TAGS:') + self.assertSetEqual(tagline.tags(), set()) + def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() -- 2.36.6 From 73d2a9ea3b866d7780b101ed255d7a4a198969fc Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 29 Aug 2023 11:35:18 +0200 Subject: [PATCH 050/170] fixed test case file cleanup --- tests/test_message.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_message.py b/tests/test_message.py index 9cfb30a..83a73ea 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -594,6 +594,12 @@ This is an answer. self.file_path_txt.unlink() self.file_yaml.close() self.file_path_yaml.unlink() + self.file_txt_no_tags.close + self.file_path_txt_no_tags.unlink() + self.file_txt_tags_empty.close + self.file_path_txt_tags_empty.unlink() + self.file_yaml_no_tags.close() + self.file_path_yaml_no_tags.unlink() def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) @@ -671,6 +677,7 @@ class TagsFromDirTestCase(CmmTestCase): def tearDown(self) -> None: self.temp_dir.cleanup() + self.temp_dir_no_tags.cleanup() def test_tags_from_dir(self) -> None: all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) -- 2.36.6 From 8e1cdee3bfca4c6b26c5a086feb9ac3671395c1c Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 30 Aug 2023 08:20:25 +0200 Subject: [PATCH 051/170] fixed Message.filter_tags --- chatmastermind/message.py | 15 ++++++++------- tests/test_message.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 902aaa2..820d104 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -436,13 +436,14 @@ class Message(): Filter tags based on their prefix (i. e. the tag starts with a given string) or some contained string. """ - res_tags = self.tags - if res_tags: - if prefix and len(prefix) > 0: - res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} - if contain and len(contain) > 0: - res_tags -= {tag for tag in res_tags if contain not in tag} - return res_tags or set() + if not self.tags: + return set() + res_tags = self.tags.copy() + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: """ diff --git a/tests/test_message.py b/tests/test_message.py index 83a73ea..2a9d0ff 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -746,3 +746,18 @@ class MessageTagsStrTestCase(CmmTestCase): def test_tags_str(self) -> None: self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') + + +class MessageFilterTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_filter_tags(self) -> None: + tags_all = self.message.filter_tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.message.filter_tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.message.filter_tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) -- 2.36.6 From b83cbb719bc7ca617a6ab4c5e05bd94a8a5ef0d8 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 09:19:38 +0200 Subject: [PATCH 052/170] added 'message_in()' function and test --- chatmastermind/message.py | 16 +++++++++++++++- tests/test_message.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 820d104..3eca26e 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,7 +3,7 @@ Module implementing message related functions and classes. """ import pathlib import yaml -from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags @@ -57,6 +57,20 @@ def source_code(text: str, include_delims: bool = False) -> list[str]: return code_sections +def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool: + """ + Searches the given message list for a message with the same file + name as the given one (i. e. it compares Message.file_path.name). + If the given message has no file_path, False is returned. + """ + if not message.file_path: + return False + for m in messages: + if m.file_path and m.file_path.name == message.file_path.name: + return True + return False + + @dataclass(kw_only=True) class MessageFilter: """ diff --git a/tests/test_message.py b/tests/test_message.py index 2a9d0ff..0d7953e 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -2,7 +2,7 @@ import pathlib import tempfile from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine @@ -761,3 +761,17 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_pref, {Tag('atag1')}) tags_cont = self.message.filter_tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + + +class MessageInTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/bla/foo')) + + def test_message_in(self) -> None: + self.assertTrue(message_in(self.message1, [self.message1])) + self.assertFalse(message_in(self.message1, [self.message2])) -- 2.36.6 From 214a6919db1051437f2b0f05b1ce8ababd05a8b0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 15:47:29 +0200 Subject: [PATCH 053/170] tags: some clarification and new tests --- chatmastermind/tags.py | 3 ++- tests/test_tags.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bb45a08..5ea1a3a 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -77,7 +77,8 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' or all of the tags in 'tags_and' but it must never contain any of the tags in 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag - exclusion is still done if 'tags_not' is not 'None'). + exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()), + they match no tags. """ required_tags_present = False excluded_tags_missing = False diff --git a/tests/test_tags.py b/tests/test_tags.py index eeab199..aa89a06 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -144,3 +144,20 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + # Test case 10: 'tags_or' and 'tags_and' are empty, match no tags + self.assertFalse(tagline.match_tags(set(), set(), None)) + + # Test case 11: 'tags_or' is empty, match no tags + self.assertFalse(tagline.match_tags(set(), None, None)) + + # Test case 12: 'tags_and' is empty, match no tags + self.assertFalse(tagline.match_tags(None, set(), None)) + + # Test case 13: 'tags_or' is empty, match 'tags_and' + tags_and = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(None, tags_and, None)) + + # Test case 14: 'tags_and' is empty, match 'tags_or' + tags_or = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(tags_or, None, None)) -- 2.36.6 From 9f4897a5b8e94bb347b03da63e089b7f18eb6e77 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 24 Aug 2023 16:49:54 +0200 Subject: [PATCH 054/170] added new module 'chat.py' --- chatmastermind/chat.py | 278 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 chatmastermind/chat.py diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py new file mode 100644 index 0000000..c5d8bf3 --- /dev/null +++ b/chatmastermind/chat.py @@ -0,0 +1,278 @@ +""" +Module implementing various chat classes and functions for managing a chat history. +""" +import shutil +import pathlib +from pprint import PrettyPrinter +from pydoc import pager +from dataclasses import dataclass +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable +from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .tags import Tag + +ChatInst = TypeVar('ChatInst', bound='Chat') +ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') + + +class ChatError(Exception): + pass + + +def terminal_width() -> int: + return shutil.get_terminal_size().columns + + +def pp(*args: Any, **kwargs: Any) -> None: + return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) + + +def print_paged(text: str) -> None: + pager(text) + + +def read_dir(dir_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> list[Message]: + """ + Reads the messages from the given folder. + Parameters: + * 'dir_path': source directory + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages: list[Message] = [] + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + message = Message.from_file(file_path, mfilter) + if message: + messages.append(message) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + return messages + + +def write_dir(dir_path: pathlib.Path, + messages: list[Message], + file_suffix: str, + next_fid: Callable[[], int]) -> None: + """ + Write all messages to the given directory. If a message has no file_path, + a new one will be created. If message.file_path exists, it will be modified + to point to the given directory. + Parameters: + * 'dir_path': destination directory + * 'messages': list of messages to write + * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] + * 'next_fid': callable that returns the next file ID + """ + for message in messages: + file_path = message.file_path + # message has no file_path: create one + if not file_path: + fid = next_fid() + fname = f"{fid:04d}{file_suffix}" + file_path = dir_path / fname + # file_path does not point to given directory: modify it + elif not file_path.parent.samefile(dir_path): + file_path = dir_path / file_path.name + message.to_file(file_path) + + +@dataclass +class Chat: + """ + A class containing a complete chat history. + """ + + messages: list[Message] + + def filter(self, mfilter: MessageFilter) -> None: + """ + Use 'Message.match(mfilter) to remove all messages that + don't fulfill the filter requirements. + """ + self.messages = [m for m in self.messages if m.match(mfilter)] + + def sort(self, reverse: bool = False) -> None: + """ + Sort the messages according to 'Message.msg_id()'. + """ + try: + # the message may not have an ID if it doesn't have a file_path + self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) + except MessageError: + pass + + def clear(self) -> None: + """ + Delete all messages. + """ + self.messages = [] + + def add_msgs(self, msgs: list[Message]) -> None: + """ + Add new messages and sort them if possible. + """ + self.messages += msgs + self.sort() + + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Get the tags of all messages, optionally filtered by prefix or substring. + """ + tags: set[Tag] = set() + for m in self.messages: + tags |= m.filter_tags(prefix, contain) + return tags + + def print(self, dump: bool = False, source_code_only: bool = False, + with_tags: bool = False, with_file: bool = False, + paged: bool = True) -> None: + if dump: + pp(self) + return + output: list[str] = [] + for message in self.messages: + if source_code_only: + output.extend(source_code(message.question, include_delims=True)) + continue + output.append('-' * terminal_width()) + output.append(Question.txt_header) + output.append(message.question) + if message.answer: + output.append(Answer.txt_header) + output.append(message.answer) + if with_tags: + output.append(message.tags_str()) + if with_file: + output.append('FILE: ' + str(message.file_path)) + if paged: + print_paged('\n'.join(output)) + else: + print(*output, sep='\n') + + +@dataclass +class ChatDB(Chat): + """ + A 'Chat' class that is bound to a given directory structure. Supports reading + and writing messages from / to that structure. Such a structure consists of + two directories: a 'cache directory', where all messages are temporarily + stored, and a 'DB' directory, where selected messages can be stored + persistently. + """ + + default_file_suffix: ClassVar[str] = '.txt' + + cache_path: pathlib.Path + db_path: pathlib.Path + # a MessageFilter that all messages must match (if given) + mfilter: Optional[MessageFilter] = None + file_suffix: str = default_file_suffix + # the glob pattern for all messages + glob: Optional[str] = None + + def __post_init__(self) -> None: + # contains the latest message ID + self.next_fname = self.db_path / '.next' + # make all paths absolute + self.cache_path = self.cache_path.absolute() + self.db_path = self.db_path.absolute() + + @classmethod + def from_dir(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a 'ChatDB' instance from the given directory structure. + Reads all messages from 'db_path' into the local message list. + Parameters: + * 'cache_path': path to the directory for temporary messages + * 'db_path': path to the directory for persistent messages + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages = read_dir(db_path, glob, mfilter) + return cls(messages, cache_path, db_path, mfilter, + cls.default_file_suffix, glob) + + @classmethod + def from_messages(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + messages: list[Message], + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a ChatDB instance from the given message list. + """ + return cls(messages, cache_path, db_path, mfilter) + + def get_next_fid(self) -> int: + try: + with open(self.next_fname, 'r') as f: + next_fid = int(f.read()) + 1 + self.set_next_fid(next_fid) + return next_fid + except Exception: + self.set_next_fid(1) + return 1 + + def set_next_fid(self, fid: int) -> None: + with open(self.next_fname, 'w') as f: + f.write(f'{fid}') + + def read_db(self) -> None: + """ + Reads new messages from the DB directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.db_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def read_cache(self) -> None: + """ + Reads new messages from the cache directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.cache_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def write_db(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the DB directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point + to the DB directory. + """ + write_dir(self.db_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) + + def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the cache directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point to + the cache directory. + """ + write_dir(self.cache_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) -- 2.36.6 From 93290da5b5badac83c1319bd5643475894b77697 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 28 Aug 2023 14:24:24 +0200 Subject: [PATCH 055/170] added tests for 'chat.py' --- tests/test_chat.py | 297 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 tests/test_chat.py diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..2d0ffa0 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,297 @@ +import pathlib +import tempfile +import time +from io import StringIO +from unittest.mock import patch +from chatmastermind.tags import TagLine +from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter +from chatmastermind.chat import Chat, ChatDB, terminal_width +from .test_main import CmmTestCase + + +class TestChat(CmmTestCase): + def setUp(self) -> None: + self.chat = Chat([]) + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('atag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('0002.txt')) + + def test_filter(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.filter(MessageFilter(answer_contains='Answer 1')) + + self.assertEqual(len(self.chat.messages), 1) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + + def test_sort(self) -> None: + self.chat.add_msgs([self.message2, self.message1]) + self.chat.sort() + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + self.chat.sort(reverse=True) + self.assertEqual(self.chat.messages[0].question, 'Question 2') + self.assertEqual(self.chat.messages[1].question, 'Question 1') + + def test_clear(self) -> None: + self.chat.add_msgs([self.message1]) + self.chat.clear() + self.assertEqual(len(self.chat.messages), 0) + + def test_add_msgs(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.assertEqual(len(self.chat.messages), 2) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + + def test_tags(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_all = self.chat.tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.chat.tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.chat.tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + @patch('sys.stdout', new_callable=StringIO) + def test_print(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False, with_tags=True, with_file=True) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{TagLine.prefix} atag1 +FILE: 0001.txt +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +{TagLine.prefix} btag2 +FILE: 0002.txt +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + +class TestChatDB(CmmTestCase): + def setUp(self) -> None: + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('tag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('tag2')}, + file_path=pathlib.Path('0002.yaml')) + self.message3 = Message(Question('Question 3'), + Answer('Answer 3'), + {Tag('tag3')}, + file_path=pathlib.Path('0003.txt')) + self.message4 = Message(Question('Question 4'), + Answer('Answer 4'), + {Tag('tag4')}, + file_path=pathlib.Path('0004.yaml')) + + self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) + self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) + self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) + self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + + def tearDown(self) -> None: + self.db_path.cleanup() + self.cache_path.cleanup() + pass + + def test_chat_db_from_dir(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + # check that the files are sorted + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, + pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_from_dir_glob(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + glob='*.txt') + self.assertEqual(len(chat_db.messages), 2) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + + def test_chat_db_filter(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(answer_contains='Answer 2')) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[0].answer, 'Answer 2') + + def test_chat_db_from_messges(self) -> None: + chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + messages=[self.message1, self.message2, + self.message3, self.message4]) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + + def test_chat_db_fids(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.get_next_fid(), 1) + self.assertEqual(chat_db.get_next_fid(), 2) + self.assertEqual(chat_db.get_next_fid(), 3) + with open(chat_db.next_fname, 'r') as f: + self.assertEqual(f.read(), '3') + + def test_chat_db_write(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + # check that Message.file_path has been correctly updated + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + + # check the timestamp of the files in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} + # overwrite the messages in the db directory + time.sleep(0.05) + chat_db.write_db() + # check if the written files are in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + # check if all files in the DB dir have actually been overwritten + for file in db_dir_files: + self.assertGreater(file.stat().st_mtime, old_timestamps[file]) + # check that Message.file_path has been correctly updated (again) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_read(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + + # create 2 new files in the DB directory + new_message1 = Message(Question('Question 5'), + Answer('Answer 5'), + {Tag('tag5')}) + new_message2 = Message(Question('Question 6'), + Answer('Answer 6'), + {Tag('tag6')}) + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 6) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # create 2 new files in the cache directory + new_message3 = Message(Question('Question 7'), + Answer('Answer 5'), + {Tag('tag7')}) + new_message4 = Message(Question('Question 8'), + Answer('Answer 6'), + {Tag('tag8')}) + new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) + # read and check them + chat_db.read_cache() + self.assertEqual(len(chat_db.messages), 8) + # check that the new message have the cache dir path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) + # an the old ones keep their path (since they have not been replaced) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now overwrite two messages in the DB directory + new_message1.question = Question('New Question 1') + new_message2.question = Question('New Question 2') + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read from the DB dir and check if the modified messages have been updated + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + self.assertEqual(chat_db.messages[4].question, 'New Question 1') + self.assertEqual(chat_db.messages[5].question, 'New Question 2') + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now write the messages from the cache to the DB directory + new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + # check that they now have the DB path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) -- 2.36.6 From 7f612bfc1745711334dac3f427d2cd63b988eda1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 08:57:54 +0200 Subject: [PATCH 056/170] added tokens() function to Message and Chat --- chatmastermind/chat.py | 7 +++++++ chatmastermind/message.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c5d8bf3..4a458df 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -129,6 +129,13 @@ class Chat: tags |= m.filter_tags(prefix, contain) return tags + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by all messages in this chat. + If unknown, 0 is returned. + """ + return sum(m.tokens() for m in self.messages) + def print(self, dump: bool = False, source_code_only: bool = False, with_tags: bool = False, with_file: bool = False, paged: bool = True) -> None: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 3eca26e..675ab3a 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -132,6 +132,7 @@ class Question(str): """ A single question with a defined header. """ + tokens: int = 0 # tokens used by this question txt_header: ClassVar[str] = '=== QUESTION ===' yaml_key: ClassVar[str] = 'question' @@ -165,6 +166,7 @@ class Answer(str): """ A single answer with a defined header. """ + tokens: int = 0 # tokens used by this answer txt_header: ClassVar[str] = '=== ANSWER ===' yaml_key: ClassVar[str] = 'answer' @@ -502,3 +504,13 @@ class Message(): def as_dict(self) -> dict[str, Any]: return asdict(self) + + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by this message. + If unknown, 0 is returned. + """ + if self.answer: + return self.question.tokens + self.answer.tokens + else: + return self.question.tokens -- 2.36.6 From d93598a74fa7490f79158b219ac5a22f2310ccb1 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:07:58 +0200 Subject: [PATCH 057/170] configuration: added AIConfig class --- chatmastermind/configuration.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 5ae32d6..0780604 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -7,7 +7,15 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') @dataclass -class OpenAIConfig(): +class AIConfig: + """ + The base class of all AI configurations. + """ + name: str + + +@dataclass +class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ @@ -25,6 +33,7 @@ class OpenAIConfig(): Create OpenAIConfig from a dict. """ return cls( + name='OpenAI', api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -36,7 +45,7 @@ class OpenAIConfig(): @dataclass -class Config(): +class Config: """ The configuration file structure. """ @@ -47,7 +56,7 @@ class Config(): @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create OpenAIConfig from a dict. + Create Config from a dict. """ return cls( system=str(source['system']), -- 2.36.6 From ddfe29b9511e7e77c4df45e3e2ac55b8d10c5a36 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:35:32 +0200 Subject: [PATCH 058/170] chat: added tags_frequency() function and test --- chatmastermind/chat.py | 11 ++++++++++- tests/test_chat.py | 9 +++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4a458df..759467d 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -127,7 +127,16 @@ class Chat: tags: set[Tag] = set() for m in self.messages: tags |= m.filter_tags(prefix, contain) - return tags + return set(sorted(tags)) + + def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: + """ + Get the frequency of all tags of all messages, optionally filtered by prefix or substring. + """ + tags: list[Tag] = [] + for m in self.messages: + tags += [tag for tag in m.filter_tags(prefix, contain)] + return {tag: tags.count(tag) for tag in sorted(tags)} def tokens(self) -> int: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index 2d0ffa0..5f1fcb6 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -14,7 +14,7 @@ class TestChat(CmmTestCase): self.chat = Chat([]) self.message1 = Message(Question('Question 1'), Answer('Answer 1'), - {Tag('atag1')}, + {Tag('atag1'), Tag('btag2')}, file_path=pathlib.Path('0001.txt')) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), @@ -57,6 +57,11 @@ class TestChat(CmmTestCase): tags_cont = self.chat.tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + def test_tags_frequency(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_freq = self.chat.tags_frequency() + self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) @@ -83,7 +88,7 @@ Answer 2 Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 +{TagLine.prefix} atag1 btag2 FILE: 0001.txt {'-'*terminal_width()} {Question.txt_header} -- 2.36.6 From d80c3962bd9451ef45681ca8df6ef5780dc55d5f Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:44:27 +0200 Subject: [PATCH 059/170] chat: fixed handling of unsupported files in DB and chache dir --- chatmastermind/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 759467d..11f1d74 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path, messages: list[Message] = [] file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() for file_path in sorted(file_iter): - if file_path.is_file(): + if file_path.is_file() and file_path.suffix in Message.file_suffixes: try: message = Message.from_file(file_path, mfilter) if message: -- 2.36.6 From ba56caf01309c50a46ff679f77fef3c2037c2a0a Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:18:41 +0200 Subject: [PATCH 060/170] chat: improved history printing --- chatmastermind/chat.py | 15 ++++++--------- tests/test_chat.py | 10 +++++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 11f1d74..e4e8ab6 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -145,27 +145,24 @@ class Chat: """ return sum(m.tokens() for m in self.messages) - def print(self, dump: bool = False, source_code_only: bool = False, - with_tags: bool = False, with_file: bool = False, + def print(self, source_code_only: bool = False, + with_tags: bool = False, with_files: bool = False, paged: bool = True) -> None: - if dump: - pp(self) - return output: list[str] = [] for message in self.messages: if source_code_only: output.extend(source_code(message.question, include_delims=True)) continue output.append('-' * terminal_width()) + if with_tags: + output.append(message.tags_str()) + if with_files: + output.append('FILE: ' + str(message.file_path)) output.append(Question.txt_header) output.append(message.question) if message.answer: output.append(Answer.txt_header) output.append(message.answer) - if with_tags: - output.append(message.tags_str()) - if with_file: - output.append('FILE: ' + str(message.file_path)) if paged: print_paged('\n'.join(output)) else: diff --git a/tests/test_chat.py b/tests/test_chat.py index 5f1fcb6..8e1ad0d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -82,21 +82,21 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) - self.chat.print(paged=False, with_tags=True, with_file=True) + self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} +{TagLine.prefix} atag1 btag2 +FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 btag2 -FILE: 0001.txt {'-'*terminal_width()} +{TagLine.prefix} btag2 +FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 -{TagLine.prefix} btag2 -FILE: 0002.txt """ self.assertEqual(mock_stdout.getvalue(), expected_output) -- 2.36.6 From f9d749cdd8f3f921b275c89302fedc8f844caa4a Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 09:19:47 +0200 Subject: [PATCH 061/170] chat: added clear_cache() function and test --- chatmastermind/chat.py | 20 +++++++++++++++++++ tests/test_chat.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index e4e8ab6..9fc0a27 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -82,6 +82,17 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) +def clear_dir(dir_path: pathlib.Path, + glob: Optional[str] = None) -> None: + """ + Deletes all Message files in the given directory. + """ + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in file_iter: + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + file_path.unlink(missing_ok=True) + + @dataclass class Chat: """ @@ -289,3 +300,12 @@ class ChatDB(Chat): msgs if msgs else self.messages, self.file_suffix, self.get_next_fid) + + def clear_cache(self) -> None: + """ + Deletes all Message files from the cache dir and removes those messages from + the internal list. + """ + clear_dir(self.cache_path, self.glob) + # only keep messages from DB dir (or those that have not yet been written) + self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e1ad0d..9e74061 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -300,3 +300,48 @@ class TestChatDB(CmmTestCase): # check that they now have the DB path self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + + def test_chat_db_clear(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + + # now rewrite them to the DB dir and check for modified paths + chat_db.write_db() + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + + # add a new message with empty file_path + message_empty = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You don't belong here!")) + # and one for the cache dir + message_cache = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You're a creep!"), + file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + chat_db.add_msgs([message_empty, message_cache]) + + # clear the cache and check the cache dir + chat_db.clear_cache() + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 0) + # make sure that the DB messages (and the new message) are still there + self.assertEqual(len(chat_db.messages), 5) + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + # but not the message with the cache dir path + self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) -- 2.36.6 From fa292fb73a97e167ba79d894af62d3cee40202d0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 16:00:24 +0200 Subject: [PATCH 062/170] message: improved robustness of Question and Answer content checks and tests --- chatmastermind/message.py | 48 +++++++++++++++++++++------------------ tests/test_message.py | 29 ++++++++++++++++++----- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 675ab3a..384fb96 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -128,29 +128,29 @@ class ModelLine(str): return cls(' '.join([cls.prefix, model])) -class Question(str): +class Answer(str): """ - A single question with a defined header. + A single answer with a defined header. """ - tokens: int = 0 # tokens used by this question - txt_header: ClassVar[str] = '=== QUESTION ===' - yaml_key: ClassVar[str] = 'question' + tokens: int = 0 # tokens used by this answer + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: """ - Make sure the question string does not contain the header. + Make sure the answer string does not contain the header as a whole line. """ - if cls.txt_header in string: - raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if cls.txt_header in string.split('\n'): + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance @@ -162,29 +162,33 @@ class Question(str): return source_code(self, include_delims) -class Answer(str): +class Question(str): """ - A single answer with a defined header. + A single question with a defined header. """ - tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' - yaml_key: ClassVar[str] = 'answer' + tokens: int = 0 # tokens used by this question + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: """ - Make sure the answer string does not contain the header. + Make sure the question string does not contain the header as a whole line + (also not that from 'Answer', so it's always clear where the answer starts). """ - if cls.txt_header in string: - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + string_lines = string.split('\n') + if cls.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if Answer.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance diff --git a/tests/test_message.py b/tests/test_message.py index 0d7953e..e01de66 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase): class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: + def test_question_with_header(self) -> None: with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") + Question(f"{Question.txt_header}\nWhat is your name?") - def test_question_without_prefix(self) -> None: + def test_question_with_answer_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Answer.txt_header}\nBob") + + def test_question_with_legal_header(self) -> None: + """ + If the header is just a part of a line, it's fine. + """ + question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + self.assertIsInstance(question, Question) + self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + + def test_question_without_header(self) -> None: question = Question("What is your favorite color?") self.assertIsInstance(question, Question) self.assertEqual(question, "What is your favorite color?") class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: + def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") + Answer(f"{Answer.txt_header}\nno") - def test_answer_without_prefix(self) -> None: + def test_answer_with_legal_header(self) -> None: + answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + + def test_answer_without_header(self) -> None: answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") -- 2.36.6 From 4b0f40bccdf5a1f10caf037cd41b726830ecef90 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:00:08 +0200 Subject: [PATCH 063/170] message: fixed Answer header for TXT format --- chatmastermind/message.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 384fb96..87de8e2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -96,7 +96,7 @@ class AILine(str): def __new__(cls: Type[AILineInst], string: str) -> AILineInst: if not string.startswith(cls.prefix): - raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -116,7 +116,7 @@ class ModelLine(str): def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: if not string.startswith(cls.prefix): - raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -133,7 +133,7 @@ class Answer(str): A single answer with a defined header. """ tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' + txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: @@ -355,17 +355,20 @@ class Message(): try: pos = fd.tell() ai = AILine(fd.readline()).ai() - except TagError: + except MessageError: fd.seek(pos) # ModelLine (Optional) try: pos = fd.tell() model = ModelLine(fd.readline()).model() - except TagError: + except MessageError: fd.seek(pos) # Question and Answer text = fd.read().strip().split('\n') - question_idx = text.index(Question.txt_header) + 1 + try: + question_idx = text.index(Question.txt_header) + 1 + except ValueError: + raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") try: answer_idx = text.index(Answer.txt_header) question = Question.from_list(text[question_idx:answer_idx]) -- 2.36.6 From 44cd1fab4587ce9dc2b1b0f7f5a2a66d023a1ef0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:19:14 +0200 Subject: [PATCH 064/170] message: added rename_tags() function and test --- chatmastermind/message.py | 10 +++++++++- tests/test_message.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 87de8e2..0fb949c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,7 @@ import pathlib import yaml from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field -from .tags import Tag, TagLine, TagError, match_tags +from .tags import Tag, TagLine, TagError, match_tags, rename_tags QuestionInst = TypeVar('QuestionInst', bound='Question') AnswerInst = TypeVar('AnswerInst', bound='Answer') @@ -499,6 +499,14 @@ class Message(): return False return True + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None: + """ + Renames the given tags. The first tuple element is the old name, + the second one is the new name. + """ + if self.tags: + self.tags = rename_tags(self.tags, tags_rename) + def msg_id(self) -> str: """ Returns an ID that is unique throughout all messages in the same (DB) directory. diff --git a/tests/test_message.py b/tests/test_message.py index e01de66..e860538 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -792,3 +792,15 @@ class MessageInTestCase(CmmTestCase): def test_message_in(self) -> None: self.assertTrue(message_in(self.message1, [self.message1])) self.assertFalse(message_in(self.message1, [self.message2])) + + +class MessageRenameTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_rename_tags(self) -> None: + self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) + self.assertIsNotNone(self.message.tags) + self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -- 2.36.6 From 6e2d5009c15768e1c66f396e00b0b3d68391432d Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 3 Sep 2023 10:18:16 +0200 Subject: [PATCH 065/170] chat: new possibilites for adding messages and better tests --- chatmastermind/chat.py | 75 ++++++++++++++++++++++++---- tests/test_chat.py | 109 ++++++++++++++++++++++++++++++++--------- 2 files changed, 153 insertions(+), 31 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 9fc0a27..7e6df8f 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path, return messages +def make_file_path(dir_path: pathlib.Path, + file_suffix: str, + next_fid: Callable[[], int]) -> pathlib.Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + return dir_path / f"{next_fid():04d}{file_suffix}" + + def write_dir(dir_path: pathlib.Path, messages: list[Message], file_suffix: str, @@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path, file_path = message.file_path # message has no file_path: create one if not file_path: - fid = next_fid() - fname = f"{fid:04d}{file_suffix}" - file_path = dir_path / fname + file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name @@ -124,11 +132,11 @@ class Chat: """ self.messages = [] - def add_msgs(self, msgs: list[Message]) -> None: + def add_messages(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ - self.messages += msgs + self.messages += messages self.sort() def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: @@ -279,25 +287,25 @@ class ChatDB(Chat): self.messages += new_messages self.sort() - def write_db(self, msgs: Optional[list[Message]] = None) -> None: + def write_db(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. """ write_dir(self.db_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) - def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + def write_cache(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. """ write_dir(self.cache_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) @@ -309,3 +317,52 @@ class ChatDB(Chat): clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the DB directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the cache directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def write_messages(self, messages: Optional[list[Message]] = None) -> None: + """ + Write either the given messages or the internal ones to their current file_path. + If messages are given, they all must have a valid file_path. When writing the + internal messages, the ones with a valid file_path are written, the others + are ignored. + """ + if messages and any(m.file_path is None for m in messages): + raise ChatError("Can't write files without a valid file_path") + msgs = iter(messages if messages else self.messages) + while (m := next(msgs, None)): + m.to_file() diff --git a/tests/test_chat.py b/tests/test_chat.py index 9e74061..a1c020e 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -5,7 +5,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, terminal_width +from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from .test_main import CmmTestCase @@ -22,14 +22,14 @@ class TestChat(CmmTestCase): file_path=pathlib.Path('0002.txt')) def test_filter(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.assertEqual(len(self.chat.messages), 1) self.assertEqual(self.chat.messages[0].question, 'Question 1') def test_sort(self) -> None: - self.chat.add_msgs([self.message2, self.message1]) + self.chat.add_messages([self.message2, self.message1]) self.chat.sort() self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') @@ -38,18 +38,18 @@ class TestChat(CmmTestCase): self.assertEqual(self.chat.messages[1].question, 'Question 1') def test_clear(self) -> None: - self.chat.add_msgs([self.message1]) + self.chat.add_messages([self.message1]) self.chat.clear() self.assertEqual(len(self.chat.messages), 0) - def test_add_msgs(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) self.assertEqual(len(self.chat.messages), 2) self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') def test_tags(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_all = self.chat.tags() self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) tags_pref = self.chat.tags(prefix='a') @@ -58,13 +58,13 @@ class TestChat(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) def test_tags_frequency(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) expected_output = f"""{'-'*terminal_width()} {Question.txt_header} @@ -81,7 +81,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} {TagLine.prefix} atag1 btag2 @@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase): self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: + """ + List all Message files in the given TemporaryDirectory. + """ + # exclude '.next' + return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) def tearDown(self) -> None: self.db_path.cleanup() @@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase): def test_chat_db_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.get_next_fid(), 1) - self.assertEqual(chat_db.get_next_fid(), 2) - self.assertEqual(chat_db.get_next_fid(), 3) + self.assertEqual(chat_db.get_next_fid(), 5) + self.assertEqual(chat_db.get_next_fid(), 6) + self.assertEqual(chat_db.get_next_fid(), 7) with open(chat_db.next_fname, 'r') as f: - self.assertEqual(f.read(), '3') + self.assertEqual(f.read(), '7') def test_chat_db_write(self) -> None: # create a new ChatDB instance @@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) @@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase): self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) # check the timestamp of the files in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} # overwrite the messages in the db directory time.sleep(0.05) chat_db.write_db() # check if the written files are in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) # now rewrite them to the DB dir and check for modified paths chat_db.write_db() - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -333,15 +344,69 @@ class TestChatDB(CmmTestCase): message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), file_path=pathlib.Path(self.cache_path.name, '0005.txt')) - chat_db.add_msgs([message_empty, message_cache]) + chat_db.add_messages([message_empty, message_cache]) # clear the cache and check the cache dir chat_db.clear_cache() - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) # make sure that the DB messages (and the new message) are still there self.assertEqual(len(chat_db.messages), 5) - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + # check if the file_path has been correctly set + self.assertIsNotNone(message1.file_path) + self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + # check if the file_path has been correctly set + self.assertIsNotNone(message2.file_path) + self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 5) + + with self.assertRaises(ChatError): + chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) + + def test_chat_db_write_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + # try to write a message without a valid file_path + message = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.write_messages([message]) + + # write a message with a valid file_path + message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' + chat_db.write_messages([message]) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) -- 2.36.6 From 63040b368895fb8065c0c03a15b0f40beb561339 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 08:49:43 +0200 Subject: [PATCH 066/170] message / chat: output improvements --- chatmastermind/chat.py | 16 ++++------------ chatmastermind/message.py | 24 ++++++++++++++++++++++++ tests/test_chat.py | 16 ++++++++++++---- tests/test_message.py | 24 ++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 7e6df8f..c631dab 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -7,7 +7,7 @@ from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, ClassVar, Any, Callable -from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -170,18 +170,10 @@ class Chat: output: list[str] = [] for message in self.messages: if source_code_only: - output.extend(source_code(message.question, include_delims=True)) + output.append(message.to_str(source_code_only=True)) continue - output.append('-' * terminal_width()) - if with_tags: - output.append(message.tags_str()) - if with_files: - output.append('FILE: ' + str(message.file_path)) - output.append(Question.txt_header) - output.append(message.question) - if message.answer: - output.append(Answer.txt_header) - output.append(message.answer) + output.append(message.to_str(with_tags, with_files)) + output.append('\n' + ('-' * terminal_width()) + '\n') if paged: print_paged('\n'.join(output)) else: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 0fb949c..35de3b9 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -392,6 +392,30 @@ class Message(): data[cls.file_yaml_key] = file_path return cls.from_dict(data) + def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: + """ + Return the current Message as a string. + """ + output: list[str] = [] + if source_code_only: + # use the source code from answer only + if self.answer: + output.extend(self.answer.source_code(include_delims=True)) + return '\n'.join(output) if len(output) > 0 else '' + if with_tags: + output.append(self.tags_str()) + if with_file: + output.append('FILE: ' + str(self.file_path)) + output.append(Question.txt_header) + output.append(self.question) + if self.answer: + output.append(Answer.txt_header) + output.append(self.answer) + return '\n'.join(output) + + def __str__(self) -> str: + return self.to_str(False, False, False) + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ Write a Message to the given file. Type is determined based on the suffix. diff --git a/tests/test_chat.py b/tests/test_chat.py index a1c020e..f8302eb 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -66,16 +66,20 @@ class TestChat(CmmTestCase): def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) - expected_output = f"""{'-'*terminal_width()} -{Question.txt_header} + expected_output = f"""{Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) @@ -83,20 +87,24 @@ Answer 2 def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) - expected_output = f"""{'-'*terminal_width()} -{TagLine.prefix} atag1 btag2 + expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {TagLine.prefix} btag2 FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) diff --git a/tests/test_message.py b/tests/test_message.py index e860538..a49c893 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -804,3 +804,27 @@ class MessageRenameTagsTestCase(CmmTestCase): self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) self.assertIsNotNone(self.message.tags) self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] + + +class MessageToStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_to_str(self) -> None: + expected_output = f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(), expected_output) + + def test_to_str_with_tags_and_file(self) -> None: + expected_output = f"""{TagLine.prefix} atag1 btag2 +FILE: /tmp/foo/bla +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) -- 2.36.6 From 7e25a08d6e8d1fd7e3a3ba782f4b8d20e67c8ef0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 08:16:55 +0200 Subject: [PATCH 067/170] chat: added functions for finding and deleting messages --- chatmastermind/chat.py | 52 ++++++++++++++++++++++++++++++++---------- tests/test_chat.py | 22 ++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c631dab..4e8fb20 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -2,7 +2,7 @@ Module implementing various chat classes and functions for managing a chat history. """ import shutil -import pathlib +from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass @@ -30,7 +30,7 @@ def print_paged(text: str) -> None: pager(text) -def read_dir(dir_path: pathlib.Path, +def read_dir(dir_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ @@ -55,9 +55,9 @@ def read_dir(dir_path: pathlib.Path, return messages -def make_file_path(dir_path: pathlib.Path, +def make_file_path(dir_path: Path, file_suffix: str, - next_fid: Callable[[], int]) -> pathlib.Path: + next_fid: Callable[[], int]) -> Path: """ Create a file_path for the given directory using the given file_suffix and ID generator function. @@ -65,7 +65,7 @@ def make_file_path(dir_path: pathlib.Path, return dir_path / f"{next_fid():04d}{file_suffix}" -def write_dir(dir_path: pathlib.Path, +def write_dir(dir_path: Path, messages: list[Message], file_suffix: str, next_fid: Callable[[], int]) -> None: @@ -90,7 +90,7 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) -def clear_dir(dir_path: pathlib.Path, +def clear_dir(dir_path: Path, glob: Optional[str] = None) -> None: """ Deletes all Message files in the given directory. @@ -139,6 +139,34 @@ class Chat: self.messages += messages self.sort() + def latest_message(self) -> Optional[Message]: + """ + Returns the last added message (according to the file ID). + """ + if len(self.messages) > 0: + self.sort() + return self.messages[-1] + else: + return None + + def find_messages(self, msg_names: list[str]) -> list[Message]: + """ + Search and return the messages with the given names. Names can either be filenames + (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the + caller should check the result if he requires all messages). + """ + return [m for m in self.messages + if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + + def remove_messages(self, msg_names: list[str]) -> None: + """ + Remove the messages with the given names. Names can either be filenames + (incl. the suffix) or full paths. + """ + self.messages = [m for m in self.messages + if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + self.sort() + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Get the tags of all messages, optionally filtered by prefix or substring. @@ -192,8 +220,8 @@ class ChatDB(Chat): default_file_suffix: ClassVar[str] = '.txt' - cache_path: pathlib.Path - db_path: pathlib.Path + cache_path: Path + db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix @@ -209,8 +237,8 @@ class ChatDB(Chat): @classmethod def from_dir(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ @@ -230,8 +258,8 @@ class ChatDB(Chat): @classmethod def from_messages(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, messages: list[Message], mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index f8302eb..d81a97a 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -62,6 +62,28 @@ class TestChat(CmmTestCase): tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + def test_find_remove_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + msgs = self.chat.find_messages(['0001.txt']) + self.assertListEqual(msgs, [self.message1]) + msgs = self.chat.find_messages(['0001.txt', '0002.txt']) + self.assertListEqual(msgs, [self.message1, self.message2]) + # add new Message with full path + message3 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('/foo/bla/0003.txt')) + self.chat.add_messages([message3]) + # find new Message by full path + msgs = self.chat.find_messages(['/foo/bla/0003.txt']) + self.assertListEqual(msgs, [message3]) + # find Message with full path only by filename + msgs = self.chat.find_messages(['0003.txt']) + self.assertListEqual(msgs, [message3]) + # remove last message + self.chat.remove_messages(['0003.txt']) + self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) -- 2.36.6 From eb0d97ddc8cad58626d85d6a32eb10085e850128 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:46:23 +0200 Subject: [PATCH 068/170] cmm: the 'tags' command now uses the new 'ChatDB' --- chatmastermind/main.py | 34 +++++++++++++++++++++------------- chatmastermind/utils.py | 5 ----- tests/test_main.py | 2 +- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 7866179..3f31aee 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,10 +7,11 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType +from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config +from .chat import ChatDB from itertools import zip_longest from typing import Any @@ -57,12 +58,17 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: Config) -> None: +def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ - Handler for the 'tag' command. + Handler for the 'tags' command. """ + chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), + db_path=pathlib.Path(config.db)) if args.list: - print_tags_frequency(get_tags(config, None)) + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming def config_cmd(args: argparse.Namespace, config: Config) -> None: @@ -187,14 +193,16 @@ def create_parser() -> argparse.ArgumentParser: hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') - # 'tag' command parser - tag_cmd_parser = cmdparser.add_parser('tag', - help="Manage tags.", - aliases=['t']) - tag_cmd_parser.set_defaults(func=tag_cmd) - tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) - tag_group.add_argument('-l', '--list', help="List all tags and their frequency", - action='store_true') + # 'tags' command parser + tags_cmd_parser = cmdparser.add_parser('tags', + help="Manage tags.", + aliases=['t']) + tags_cmd_parser.set_defaults(func=tags_cmd) + tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) + tags_group.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') + tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") + tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") # 'config' command parser config_cmd_parser = cmdparser.add_parser('config', diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index bd80e4f..e6eeb97 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -78,8 +78,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals print(message['content']) else: print(f"{message['role'].upper()}: {message['content']}") - - -def print_tags_frequency(tags: list[str]) -> None: - for tag in sorted(set(tags)): - print(f"- {tag}: {tags.count(tag)}") diff --git a/tests/test_main.py b/tests/test_main.py index db5fcdb..23c3d00 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -227,7 +227,7 @@ class TestCreateParser(CmmTestCase): mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From b0504aedbef6e167469c703174d57164fc637595 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:21:49 +0200 Subject: [PATCH 069/170] cmm: the 'hist' command now uses the new 'ChatDB' --- chatmastermind/main.py | 60 +++++++++++++++++++++++------------------- tests/test_main.py | 15 ++++++----- 2 files changed, 42 insertions(+), 33 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 3f31aee..08c5e3e 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -12,6 +12,7 @@ from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB +from .message import MessageFilter from itertools import zip_longest from typing import Any @@ -32,11 +33,11 @@ def create_question_with_hist(args: argparse.Namespace, by the specified tags. """ tags = args.tags or [] - extags = args.extags or [] + etags = args.etags or [] otags = args.output_tags or [] - if not args.only_source_code: - print_tag_args(tags, extags, otags) + if not args.source_code_only: + print_tag_args(tags, etags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -53,8 +54,10 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, extags, config, - args.match_all_tags, False, False) + chat = create_chat_hist(full_question, tags, etags, config, + match_all_tags=True if args.atags else False, # FIXME + with_tags=False, + with_file=False) return chat, full_question, tags @@ -95,7 +98,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: if args.model: config.openai.model = args.model chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.only_source_code) + print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) save_answers(question, answers, tags, otags, config) @@ -107,14 +110,18 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. """ - tags = args.tags or [] - extags = args.extags or [] - chat = create_chat_hist(None, tags, extags, config, - args.match_all_tags, - args.with_tags, - args.with_files) - print_chat_hist(chat, args.dump, args.only_source_code) + mfilter = MessageFilter(tags_or=args.tags, + tags_and=args.atags, + tags_not=args.etags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) def print_cmd(args: argparse.Namespace, config: Config) -> None: @@ -130,7 +137,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: else: print(f"Unknown file type: {args.file}") sys.exit(1) - if args.only_source_code: + if args.source_code_only: display_source_code(data['answer']) else: print(dump_data(data).strip()) @@ -150,18 +157,17 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names', metavar='TAGS') + help='List of tag names (one must match)', metavar='TAGS') tag_arg.completer = tags_completer # type: ignore - extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', - help='List of tag names to exclude', metavar='EXTAGS') - extag_arg.completer = tags_completer # type: ignore + atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', + help='List of tag names (all must match)', metavar='TAGS') + atag_arg.completer = tags_completer # type: ignore + etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', + help='List of tag names to exclude', metavar='ETAGS') + etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', help='List of output tag names, default is input', metavar='OTAGS') otag_arg.completer = tags_completer # type: ignore - tag_parser.add_argument('-a', '--match-all-tags', - help="All given tags must match when selecting chat history entries", - action='store_true') - # enable autocompletion for tags # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], @@ -176,7 +182,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', + ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') # 'hist' command parser @@ -184,14 +190,14 @@ def create_parser() -> argparse.ArgumentParser: help="Print chat history.", aliases=['h']) hist_cmd_parser.set_defaults(func=hist_cmd) - hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", - action='store_true') hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", action='store_true') - hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') + hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') + hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', @@ -222,7 +228,7 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') argcomplete.autocomplete(parser) diff --git a/tests/test_main.py b/tests/test_main.py index 23c3d00..bb9aa2a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase): self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], - extags=['extag1'], + atags=None, + etags=['etag1'], output_tags=None, question=[self.question], source=None, - only_source_code=False, + source_code_only=False, number=3, max_tokens=None, temperature=None, @@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase): with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.extags, + self.args.etags, []) mock_create_chat_hist.assert_called_once_with(self.question, self.args.tags, - self.args.extags, + self.args.etags, self.config, - False, False, False) + match_all_tags=False, + with_tags=False, + with_file=False) mock_print_chat_hist.assert_called_once_with('test_chat', False, - self.args.only_source_code) + self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, self.args.number) -- 2.36.6 From f93a57c00da39ff658a261970501d0c8c5140ec2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:42:59 +0200 Subject: [PATCH 070/170] cmm: tags completion now uses 'Message.tags_from_dir' (fixes tag completion for me) --- chatmastermind/main.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 08c5e3e..b3bd1b8 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,13 +6,13 @@ import yaml import sys import argcomplete import argparse -import pathlib +from pathlib import Path from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data +from .storage import save_answers, create_chat_hist, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import MessageFilter +from .message import Message, MessageFilter from itertools import zip_longest from typing import Any @@ -20,9 +20,8 @@ default_config = '.config.yaml' def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: - with open(parsed_args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) - return get_tags_unique(config, prefix) + config = Config.from_file(parsed_args.config) + return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) def create_question_with_hist(args: argparse.Namespace, @@ -65,8 +64,8 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. """ - chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), - db_path=pathlib.Path(config.db)) + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) if args.list: tags_freq = chat.tags_frequency(args.prefix, args.contain) for tag, freq in tags_freq.items(): @@ -128,7 +127,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. """ - fname = pathlib.Path(args.file) + fname = Path(args.file) if fname.suffix == '.yaml': with open(args.file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) -- 2.36.6 From bf1cbff6a2c11411097cca53229ac6b1c6ecae06 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:07:02 +0200 Subject: [PATCH 071/170] cmm: the 'print' command now uses 'Message.from_file()' --- chatmastermind/main.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index b3bd1b8..951d3cf 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -2,17 +2,16 @@ # -*- coding: utf-8 -*- # vim: set fileencoding=utf-8 : -import yaml import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType +from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter +from .message import Message, MessageFilter, MessageError from itertools import zip_longest from typing import Any @@ -128,18 +127,13 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'print' command. """ fname = Path(args.file) - if fname.suffix == '.yaml': - with open(args.file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif fname.suffix == '.txt': - data = read_file(fname) - else: - print(f"Unknown file type: {args.file}") + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") sys.exit(1) - if args.source_code_only: - display_source_code(data['answer']) - else: - print(dump_data(data).strip()) def create_parser() -> argparse.ArgumentParser: @@ -223,11 +217,11 @@ def create_parser() -> argparse.ArgumentParser: # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', - help="Print files.", + help="Print message files.", aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', action='store_true') argcomplete.autocomplete(parser) -- 2.36.6 From aa322de71866ed513f56b10ba5564b5a482c888b Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:00:15 +0200 Subject: [PATCH 072/170] added new module 'ai.py' --- chatmastermind/ai.py | 63 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 chatmastermind/ai.py diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py new file mode 100644 index 0000000..4a8b914 --- /dev/null +++ b/chatmastermind/ai.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import Protocol, Optional, Union +from .configuration import AIConfig +from .tags import Tag +from .message import Message +from .chat import Chat + + +class AIError(Exception): + pass + + +@dataclass +class Tokens: + prompt: int = 0 + completion: int = 0 + total: int = 0 + + +@dataclass +class AIResponse: + """ + The response to an AI request. Consists of one or more messages + (each containing the question and a single answer) and the nr. + of used tokens. + """ + messages: list[Message] + tokens: Optional[Tokens] = None + + +class AI(Protocol): + """ + The base class for AI clients. + """ + + name: str + config: AIConfig + + def request(self, + question: Message, + context: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + context (i. e. chat history). The nr. of requested answers + corresponds to the nr. of messages in the 'AIResponse'. + """ + raise NotImplementedError + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def tokens(self, data: Union[Message, Chat]) -> int: + """ + Computes the nr. of AI language tokens for the given message + or chat. Note that the computation may not be 100% accurate + and is not implemented for all AIs. + """ + raise NotImplementedError -- 2.36.6 From b7e3ca7ca77d65c069f05b4ca005385793026c68 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 10:18:09 +0200 Subject: [PATCH 073/170] added new module 'openai.py' --- chatmastermind/ais/openai.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 chatmastermind/ais/openai.py diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py new file mode 100644 index 0000000..74438b8 --- /dev/null +++ b/chatmastermind/ais/openai.py @@ -0,0 +1,96 @@ +""" +Implements the OpenAI client classes and functions. +""" +import openai +from typing import Optional, Union +from ..tags import Tag +from ..message import Message, Answer +from ..chat import Chat +from ..ai import AI, AIResponse, Tokens +from ..configuration import OpenAIConfig + +ChatType = list[dict[str, str]] + + +class OpenAI(AI): + """ + The OpenAI AI client. + """ + + def __init__(self, name: str, config: OpenAIConfig) -> None: + self.name = name + self.config = config + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + chat history. The nr. of requested answers corresponds to the + nr. of messages in the 'AIResponse'. + """ + # FIXME: use real 'system' message (store in OpenAIConfig) + oai_chat = self.openai_chat(chat, "system", question) + response = openai.ChatCompletion.create( + model=self.config.model, + messages=oai_chat, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, + n=num_answers, + frequency_penalty=self.config.frequency_penalty, + presence_penalty=self.config.presence_penalty) + answers: list[Message] = [] + for choice in response['choices']: # type: ignore + answers.append(Message(question=question.question, + answer=Answer(choice['message']['content']), + tags=otags, + ai=self.name, + model=self.config.model)) + return AIResponse(answers, Tokens(response['usage']['prompt'], + response['usage']['completion'], + response['usage']['total'])) + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def print_models(self) -> None: + """ + Print all models supported by the current AI. + """ + not_ready = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + print(engine['id']) + else: + not_ready.append(engine['id']) + if len(not_ready) > 0: + print('\nNot ready: ' + ', '.join(not_ready)) + + def openai_chat(self, chat: Chat, system: str, + question: Optional[Message] = None) -> ChatType: + """ + Create a chat history with system message in OpenAI format. + Optionally append a new question. + """ + oai_chat: ChatType = [] + + def append(role: str, content: str) -> None: + oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + + append('system', system) + for message in chat.messages: + if message.answer: + append('user', message.question) + append('assistant', message.answer) + if question: + append('user', question.question) + return oai_chat + + def tokens(self, data: Union[Message, Chat]) -> int: + raise NotImplementedError -- 2.36.6 From eb2fcba99d6918edf89b42ff8f2c171e49532c4a Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 5 Sep 2023 23:24:20 +0200 Subject: [PATCH 074/170] added new module 'ai_factory' --- chatmastermind/ai_factory.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 chatmastermind/ai_factory.py diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py new file mode 100644 index 0000000..c90366b --- /dev/null +++ b/chatmastermind/ai_factory.py @@ -0,0 +1,20 @@ +""" +Creates different AI instances, based on the given configuration. +""" + +import argparse +from .configuration import Config +from .ai import AI, AIError +from .ais.openai import OpenAI + + +def create_ai(args: argparse.Namespace, config: Config) -> AI: + """ + Creates an AI subclass instance from the given args and configuration. + """ + if args.ai == 'openai': + # FIXME: create actual 'OpenAIConfig' and set values from 'args' + # FIXME: use actual name from config + return OpenAI("openai", config.openai) + else: + raise AIError(f"AI '{args.ai}' is not supported") -- 2.36.6 From ba5aa1fbc73013cee81c7bb27b0a970866b6bf25 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:35:53 +0200 Subject: [PATCH 075/170] cmm: added 'question' command --- chatmastermind/main.py | 103 +++++++++++++++++++++++++++++++++-------- tests/test_main.py | 18 +++---- 2 files changed, 93 insertions(+), 28 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 951d3cf..b10b97b 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -11,7 +11,9 @@ from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter, MessageError +from .message import Message, MessageFilter, MessageError, Question +from .ai_factory import create_ai +from .ai import AI, AIResponse from itertools import zip_longest from typing import Any @@ -30,12 +32,12 @@ def create_question_with_hist(args: argparse.Namespace, Creates the "AI request", including the question and chat history as determined by the specified tags. """ - tags = args.tags or [] - etags = args.etags or [] + tags = args.or_tags or [] + xtags = args.exclude_tags or [] otags = args.output_tags or [] if not args.source_code_only: - print_tag_args(tags, etags, otags) + print_tag_args(tags, xtags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -52,8 +54,8 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, etags, config, - match_all_tags=True if args.atags else False, # FIXME + chat = create_chat_hist(full_question, tags, xtags, config, + match_all_tags=True if args.and_tags else False, # FIXME with_tags=False, with_file=False) return chat, full_question, tags @@ -85,6 +87,47 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None: config.to_file(args.config) +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. + """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = Message(question=Question(args.question), + tags=args.ouput_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass + + def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. @@ -98,7 +141,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] - answers, usage = ai(chat, config, args.number) + answers, usage = ai(chat, config, args.num_answers) save_answers(question, answers, tags, otags, config) print("-" * terminal_width()) print(f"Usage: {usage}") @@ -109,9 +152,9 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'hist' command. """ - mfilter = MessageFilter(tags_or=args.tags, - tags_and=args.atags, - tags_not=args.etags, + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, question_contains=args.question, answer_contains=args.answer) chat = ChatDB.from_dir(Path('.'), @@ -139,7 +182,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-c', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -149,19 +192,41 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) - tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names (one must match)', metavar='TAGS') + tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', + help='List of tag names (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore - atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', - help='List of tag names (all must match)', metavar='TAGS') + atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', + help='List of tag names (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore - etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', - help='List of tag names to exclude', metavar='ETAGS') + etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', + help='List of tag names to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OTAGS') + help='List of output tag names, default is input', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # 'question' command parser + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + help="ask, create and process questions.", + aliases=['q']) + question_cmd_parser.set_defaults(func=question_cmd) + question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) + question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') + question_group.add_argument('-c', '--create', nargs='+', help='Create a question') + question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') + question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') + question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', + action='store_true') + question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) + question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) + question_cmd_parser.add_argument('-A', '--AI', help='AI to use') + question_cmd_parser.add_argument('-M', '--model', help='Model to use') + question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, + default=1) + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', + action='store_true') + # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], help="Ask a question.", @@ -172,7 +237,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, + ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', diff --git a/tests/test_main.py b/tests/test_main.py index bb9aa2a..ce9121a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -114,14 +114,14 @@ class TestHandleQuestion(CmmTestCase): def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( - tags=['tag1'], - atags=None, - etags=['etag1'], + or_tags=['tag1'], + and_tags=None, + exclude_tags=['xtag1'], output_tags=None, question=[self.question], source=None, source_code_only=False, - number=3, + num_answers=3, max_tokens=None, temperature=None, model=None, @@ -143,12 +143,12 @@ class TestHandleQuestion(CmmTestCase): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.etags, + mock_print_tag_args.assert_called_once_with(self.args.or_tags, + self.args.exclude_tags, []) mock_create_chat_hist.assert_called_once_with(self.question, - self.args.tags, - self.args.etags, + self.args.or_tags, + self.args.exclude_tags, self.config, match_all_tags=False, with_tags=False, @@ -158,7 +158,7 @@ class TestHandleQuestion(CmmTestCase): self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, - self.args.number) + self.args.num_answers) expected_calls = [] for num, answer in enumerate(mock_ai.return_value[0], start=1): title = f'-- ANSWER {num} ' -- 2.36.6 From 893917e455f87b7059657a5d2f01854f096ac5bc Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 22:12:05 +0200 Subject: [PATCH 076/170] test_main: temporarily disabled all testcases --- tests/test_chat.py | 6 +- tests/test_main.py | 468 +++++++++++++++++++++--------------------- tests/test_message.py | 34 +-- tests/test_tags.py | 6 +- 4 files changed, 257 insertions(+), 257 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index d81a97a..8e4aa8c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,3 +1,4 @@ +import unittest import pathlib import tempfile import time @@ -6,10 +7,9 @@ from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError -from .test_main import CmmTestCase -class TestChat(CmmTestCase): +class TestChat(unittest.TestCase): def setUp(self) -> None: self.chat = Chat([]) self.message1 = Message(Question('Question 1'), @@ -131,7 +131,7 @@ Answer 2 self.assertEqual(mock_stdout.getvalue(), expected_output) -class TestChatDB(CmmTestCase): +class TestChatDB(unittest.TestCase): def setUp(self) -> None: self.db_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory() diff --git a/tests/test_main.py b/tests/test_main.py index ce9121a..91e6462 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,236 +1,236 @@ -import unittest -import io -import pathlib -import argparse -from chatmastermind.utils import terminal_width -from chatmastermind.main import create_parser, ask_cmd -from chatmastermind.api_client import ai -from chatmastermind.configuration import Config -from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from unittest import mock -from unittest.mock import patch, MagicMock, Mock, ANY +# import unittest +# import io +# import pathlib +# import argparse +# from chatmastermind.utils import terminal_width +# from chatmastermind.main import create_parser, ask_cmd +# from chatmastermind.api_client import ai +# from chatmastermind.configuration import Config +# from chatmastermind.storage import create_chat_hist, save_answers, dump_data +# from unittest import mock +# from unittest.mock import patch, MagicMock, Mock, ANY -class CmmTestCase(unittest.TestCase): - """ - Base class for all cmm testcases. - """ - def dummy_config(self, db: str) -> Config: - """ - Creates a dummy configuration. - """ - return Config.from_dict( - {'system': 'dummy_system', - 'db': db, - 'openai': {'api_key': 'dummy_key', - 'model': 'dummy_model', - 'max_tokens': 4000, - 'temperature': 1.0, - 'top_p': 1, - 'frequency_penalty': 0, - 'presence_penalty': 0}} - ) - - -class TestCreateChat(CmmTestCase): - - def setUp(self) -> None: - self.config = self.dummy_config(db='test_files') - self.question = "test question" - self.tags = ['test_tag'] - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['test_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 4) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['other_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 2) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.side_effect = ( - io.StringIO(dump_data({'question': 'test_content', - 'answer': 'some answer', - 'tags': ['test_tag']})), - io.StringIO(dump_data({'question': 'test_content2', - 'answer': 'some answer2', - 'tags': ['test_tag2']})), - ) - - test_chat = create_chat_hist(self.question, [], None, self.config) - - self.assertEqual(len(test_chat), 6) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': 'test_content2'}) - self.assertEqual(test_chat[4], - {'role': 'assistant', 'content': 'some answer2'}) - - -class TestHandleQuestion(CmmTestCase): - - def setUp(self) -> None: - self.question = "test question" - self.args = argparse.Namespace( - or_tags=['tag1'], - and_tags=None, - exclude_tags=['xtag1'], - output_tags=None, - question=[self.question], - source=None, - source_code_only=False, - num_answers=3, - max_tokens=None, - temperature=None, - model=None, - match_all_tags=False, - with_tags=False, - with_file=False, - ) - self.config = self.dummy_config(db='test_files') - - @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") - @patch("chatmastermind.main.print_tag_args") - @patch("chatmastermind.main.print_chat_hist") - @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) - @patch("chatmastermind.utils.pp") - @patch("builtins.print") - def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, - mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, - mock_create_chat_hist: MagicMock) -> None: - open_mock = MagicMock() - with patch("chatmastermind.storage.open", open_mock): - ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.or_tags, - self.args.exclude_tags, - []) - mock_create_chat_hist.assert_called_once_with(self.question, - self.args.or_tags, - self.args.exclude_tags, - self.config, - match_all_tags=False, - with_tags=False, - with_file=False) - mock_print_chat_hist.assert_called_once_with('test_chat', - False, - self.args.source_code_only) - mock_ai.assert_called_with("test_chat", - self.config, - self.args.num_answers) - expected_calls = [] - for num, answer in enumerate(mock_ai.return_value[0], start=1): - title = f'-- ANSWER {num} ' - title_end = '-' * (terminal_width() - len(title)) - expected_calls.append(((f'{title}{title_end}',),)) - expected_calls.append(((answer,),)) - expected_calls.append((("-" * terminal_width(),),)) - expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - self.assertEqual(mock_print.call_args_list, expected_calls) - open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) - open_mock.assert_has_calls(open_expected_calls, any_order=True) - - -class TestSaveAnswers(CmmTestCase): - @mock.patch('builtins.open') - @mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: - question = "Test question?" - answers = ["Answer 1", "Answer 2"] - tags = ["tag1", "tag2"] - otags = ["otag1", "otag2"] - config = self.dummy_config(db='test_db') - - with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ - mock.patch('chatmastermind.storage.yaml.dump'), \ - mock.patch('io.StringIO') as stringio_mock: - stringio_instance = stringio_mock.return_value - stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] - save_answers(question, answers, tags, otags, config) - - open_calls = [ - mock.call(pathlib.Path('test_db/.next'), 'r'), - mock.call(pathlib.Path('test_db/.next'), 'w'), - ] - open_mock.assert_has_calls(open_calls, any_order=True) - - -class TestAI(CmmTestCase): - - @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock) -> None: - mock_create.return_value = { - 'choices': [ - {'message': {'content': 'response_text_1'}}, - {'message': {'content': 'response_text_2'}} - ], - 'usage': {'tokens': 10} - } - - chat = [{"role": "system", "content": "hello ai"}] - config = self.dummy_config(db='dummy') - config.openai.model = "text-davinci-002" - config.openai.max_tokens = 150 - config.openai.temperature = 0.5 - - result = ai(chat, config, 2) - expected_result = (['response_text_1', 'response_text_2'], - {'tokens': 10}) - self.assertEqual(result, expected_result) - - -class TestCreateParser(CmmTestCase): - def test_create_parser(self) -> None: - with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: - mock_cmdparser = Mock() - mock_add_subparsers.return_value = mock_cmdparser - parser = create_parser() - self.assertIsInstance(parser, argparse.ArgumentParser) - mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) - mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) - self.assertTrue('.config.yaml' in parser.get_default('config')) +# class CmmTestCase(unittest.TestCase): +# """ +# Base class for all cmm testcases. +# """ +# def dummy_config(self, db: str) -> Config: +# """ +# Creates a dummy configuration. +# """ +# return Config.from_dict( +# {'system': 'dummy_system', +# 'db': db, +# 'openai': {'api_key': 'dummy_key', +# 'model': 'dummy_model', +# 'max_tokens': 4000, +# 'temperature': 1.0, +# 'top_p': 1, +# 'frequency_penalty': 0, +# 'presence_penalty': 0}} +# ) +# +# +# class TestCreateChat(CmmTestCase): +# +# def setUp(self) -> None: +# self.config = self.dummy_config(db='test_files') +# self.question = "test question" +# self.tags = ['test_tag'] +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['test_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 4) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# self.assertEqual(test_chat[3], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['other_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 2) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.side_effect = ( +# io.StringIO(dump_data({'question': 'test_content', +# 'answer': 'some answer', +# 'tags': ['test_tag']})), +# io.StringIO(dump_data({'question': 'test_content2', +# 'answer': 'some answer2', +# 'tags': ['test_tag2']})), +# ) +# +# test_chat = create_chat_hist(self.question, [], None, self.config) +# +# self.assertEqual(len(test_chat), 6) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# self.assertEqual(test_chat[3], +# {'role': 'user', 'content': 'test_content2'}) +# self.assertEqual(test_chat[4], +# {'role': 'assistant', 'content': 'some answer2'}) +# +# +# class TestHandleQuestion(CmmTestCase): +# +# def setUp(self) -> None: +# self.question = "test question" +# self.args = argparse.Namespace( +# or_tags=['tag1'], +# and_tags=None, +# exclude_tags=['xtag1'], +# output_tags=None, +# question=[self.question], +# source=None, +# source_code_only=False, +# num_answers=3, +# max_tokens=None, +# temperature=None, +# model=None, +# match_all_tags=False, +# with_tags=False, +# with_file=False, +# ) +# self.config = self.dummy_config(db='test_files') +# +# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") +# @patch("chatmastermind.main.print_tag_args") +# @patch("chatmastermind.main.print_chat_hist") +# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) +# @patch("chatmastermind.utils.pp") +# @patch("builtins.print") +# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, +# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, +# mock_create_chat_hist: MagicMock) -> None: +# open_mock = MagicMock() +# with patch("chatmastermind.storage.open", open_mock): +# ask_cmd(self.args, self.config) +# mock_print_tag_args.assert_called_once_with(self.args.or_tags, +# self.args.exclude_tags, +# []) +# mock_create_chat_hist.assert_called_once_with(self.question, +# self.args.or_tags, +# self.args.exclude_tags, +# self.config, +# match_all_tags=False, +# with_tags=False, +# with_file=False) +# mock_print_chat_hist.assert_called_once_with('test_chat', +# False, +# self.args.source_code_only) +# mock_ai.assert_called_with("test_chat", +# self.config, +# self.args.num_answers) +# expected_calls = [] +# for num, answer in enumerate(mock_ai.return_value[0], start=1): +# title = f'-- ANSWER {num} ' +# title_end = '-' * (terminal_width() - len(title)) +# expected_calls.append(((f'{title}{title_end}',),)) +# expected_calls.append(((answer,),)) +# expected_calls.append((("-" * terminal_width(),),)) +# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) +# self.assertEqual(mock_print.call_args_list, expected_calls) +# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) +# open_mock.assert_has_calls(open_expected_calls, any_order=True) +# +# +# class TestSaveAnswers(CmmTestCase): +# @mock.patch('builtins.open') +# @mock.patch('chatmastermind.storage.print') +# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: +# question = "Test question?" +# answers = ["Answer 1", "Answer 2"] +# tags = ["tag1", "tag2"] +# otags = ["otag1", "otag2"] +# config = self.dummy_config(db='test_db') +# +# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ +# mock.patch('chatmastermind.storage.yaml.dump'), \ +# mock.patch('io.StringIO') as stringio_mock: +# stringio_instance = stringio_mock.return_value +# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] +# save_answers(question, answers, tags, otags, config) +# +# open_calls = [ +# mock.call(pathlib.Path('test_db/.next'), 'r'), +# mock.call(pathlib.Path('test_db/.next'), 'w'), +# ] +# open_mock.assert_has_calls(open_calls, any_order=True) +# +# +# class TestAI(CmmTestCase): +# +# @patch("openai.ChatCompletion.create") +# def test_ai(self, mock_create: MagicMock) -> None: +# mock_create.return_value = { +# 'choices': [ +# {'message': {'content': 'response_text_1'}}, +# {'message': {'content': 'response_text_2'}} +# ], +# 'usage': {'tokens': 10} +# } +# +# chat = [{"role": "system", "content": "hello ai"}] +# config = self.dummy_config(db='dummy') +# config.openai.model = "text-davinci-002" +# config.openai.max_tokens = 150 +# config.openai.temperature = 0.5 +# +# result = ai(chat, config, 2) +# expected_result = (['response_text_1', 'response_text_2'], +# {'tokens': 10}) +# self.assertEqual(result, expected_result) +# +# +# class TestCreateParser(CmmTestCase): +# def test_create_parser(self) -> None: +# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: +# mock_cmdparser = Mock() +# mock_add_subparsers.return_value = mock_cmdparser +# parser = create_parser() +# self.assertIsInstance(parser, argparse.ArgumentParser) +# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) +# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) +# self.assertTrue('.config.yaml' in parser.get_default('config')) diff --git a/tests/test_message.py b/tests/test_message.py index a49c893..57d5982 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,12 +1,12 @@ +import unittest import pathlib import tempfile from typing import cast -from .test_main import CmmTestCase from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine -class SourceCodeTestCase(CmmTestCase): +class SourceCodeTestCase(unittest.TestCase): def test_source_code_with_include_delims(self) -> None: text = """ Some text before the code block @@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase): self.assertEqual(result, expected_result) -class QuestionTestCase(CmmTestCase): +class QuestionTestCase(unittest.TestCase): def test_question_with_header(self) -> None: with self.assertRaises(MessageError): Question(f"{Question.txt_header}\nWhat is your name?") @@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase): self.assertEqual(question, "What is your favorite color?") -class AnswerTestCase(CmmTestCase): +class AnswerTestCase(unittest.TestCase): def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): Answer(f"{Answer.txt_header}\nno") @@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase): self.assertEqual(answer, "No") -class MessageToFileTxtTestCase(CmmTestCase): +class MessageToFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -160,7 +160,7 @@ This is a question. self.message_complete.file_path = self.file_path -class MessageToFileYamlTestCase(CmmTestCase): +class MessageToFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase): self.assertEqual(content, expected_content) -class MessageFromFileTxtTestCase(CmmTestCase): +class MessageFromFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -388,7 +388,7 @@ This is a question. self.assertIsNone(message) -class MessageFromFileYamlTestCase(CmmTestCase): +class MessageFromFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase): self.assertIsNone(message) -class TagsFromFileTestCase(CmmTestCase): +class TagsFromFileTestCase(unittest.TestCase): def setUp(self) -> None: self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) @@ -663,7 +663,7 @@ This is an answer. self.assertSetEqual(tags, set()) -class TagsFromDirTestCase(CmmTestCase): +class TagsFromDirTestCase(unittest.TestCase): def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory() @@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase): self.assertSetEqual(all_tags, set()) -class MessageIDTestCase(CmmTestCase): +class MessageIDTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase): self.message_no_file_path.msg_id() -class MessageHashTestCase(CmmTestCase): +class MessageHashTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase): self.assertIn(msg, msgs) -class MessageTagsStrTestCase(CmmTestCase): +class MessageTagsStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase): self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') -class MessageFilterTagsTestCase(CmmTestCase): +class MessageFilterTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) -class MessageInTestCase(CmmTestCase): +class MessageInTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase): self.assertFalse(message_in(self.message1, [self.message2])) -class MessageRenameTagsTestCase(CmmTestCase): +class MessageRenameTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase): self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -class MessageToStrTestCase(CmmTestCase): +class MessageToStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), Answer('This is an answer.'), diff --git a/tests/test_tags.py b/tests/test_tags.py index aa89a06..edd3c05 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -1,8 +1,8 @@ -from .test_main import CmmTestCase +import unittest from chatmastermind.tags import Tag, TagLine, TagError -class TestTag(CmmTestCase): +class TestTag(unittest.TestCase): def test_valid_tag(self) -> None: tag = Tag('mytag') self.assertEqual(tag, 'mytag') @@ -18,7 +18,7 @@ class TestTag(CmmTestCase): self.assertEqual(Tag.alternative_separators, [',']) -class TestTagLine(CmmTestCase): +class TestTagLine(unittest.TestCase): def test_valid_tagline(self) -> None: tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') -- 2.36.6 From 6e447018d5ac72b0427bad026e4dbbac3b9a949b Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Thu, 7 Sep 2023 18:11:32 +0200 Subject: [PATCH 077/170] Fix tags_completter. --- chatmastermind/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 7866179..32e4ccd 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -18,8 +18,7 @@ default_config = '.config.yaml' def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: - with open(parsed_args.config, 'r') as f: - config = yaml.load(f, Loader=yaml.FullLoader) + config = Config.from_file(parsed_args.config) return get_tags_unique(config, prefix) -- 2.36.6 From 74a26b8c2f42ec59916e3e744c9db23d40ee6fa4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:23:29 +0200 Subject: [PATCH 078/170] setup: added 'ais' subfolder --- chatmastermind/ais/__init__.py | 0 setup.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 chatmastermind/ais/__init__.py diff --git a/chatmastermind/ais/__init__.py b/chatmastermind/ais/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 02d9ab1..8484629 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages(), + packages=find_packages() + ["chatmastermind.ais"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -32,7 +32,7 @@ setup( "openai", "PyYAML", "argcomplete", - "pytest" + "pytest", ], python_requires=">=3.9", test_suite="tests", -- 2.36.6 From 2df9dd64274a9f0c3214281d7e032ea8e131432a Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:43:23 +0200 Subject: [PATCH 079/170] cmm: removed all the old code and modules --- chatmastermind/api_client.py | 45 ------- chatmastermind/main.py | 104 ++------------- chatmastermind/storage.py | 121 ------------------ chatmastermind/utils.py | 80 ------------ tests/test_main.py | 236 ----------------------------------- 5 files changed, 12 insertions(+), 574 deletions(-) delete mode 100644 chatmastermind/api_client.py delete mode 100644 chatmastermind/storage.py delete mode 100644 chatmastermind/utils.py delete mode 100644 tests/test_main.py diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py deleted file mode 100644 index 2c4a094..0000000 --- a/chatmastermind/api_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import openai - -from .utils import ChatType -from .configuration import Config - - -def openai_api_key(api_key: str) -> None: - openai.api_key = api_key - - -def print_models() -> None: - """ - Print all models supported by the current AI. - """ - not_ready = [] - for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): - if engine['ready']: - print(engine['id']) - else: - not_ready.append(engine['id']) - if len(not_ready) > 0: - print('\nNot ready: ' + ', '.join(not_ready)) - - -def ai(chat: ChatType, - config: Config, - number: int - ) -> tuple[list[str], dict[str, int]]: - """ - Make AI request with the given chat history and configuration. - Return AI response and tokens used. - """ - response = openai.ChatCompletion.create( - model=config.openai.model, - messages=chat, - temperature=config.openai.temperature, - max_tokens=config.openai.max_tokens, - top_p=config.openai.top_p, - n=number, - frequency_penalty=config.openai.frequency_penalty, - presence_penalty=config.openai.presence_penalty) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/main.py b/chatmastermind/main.py index b10b97b..857bb5a 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,61 +6,19 @@ import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType -from .storage import save_answers, create_chat_hist -from .api_client import ai, openai_api_key, print_models -from .configuration import Config +from .configuration import Config, default_config_path from .chat import ChatDB from .message import Message, MessageFilter, MessageError, Question from .ai_factory import create_ai from .ai import AI, AIResponse -from itertools import zip_longest from typing import Any -default_config = '.config.yaml' - def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) -def create_question_with_hist(args: argparse.Namespace, - config: Config, - ) -> tuple[ChatType, str, list[str]]: - """ - Creates the "AI request", including the question and chat history as determined - by the specified tags. - """ - tags = args.or_tags or [] - xtags = args.exclude_tags or [] - otags = args.output_tags or [] - - if not args.source_code_only: - print_tag_args(tags, xtags, otags) - - question_parts = [] - question_list = args.question if args.question is not None else [] - source_list = args.source if args.source is not None else [] - - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: - question_parts.append(question) - elif source is not None: - with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") - - full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, xtags, config, - match_all_tags=True if args.and_tags else False, # FIXME - with_tags=False, - with_file=False) - return chat, full_question, tags - - def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. @@ -74,17 +32,12 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: # TODO: add renaming -def config_cmd(args: argparse.Namespace, config: Config) -> None: +def config_cmd(args: argparse.Namespace) -> None: """ Handler for the 'config' command. """ - if args.list_models: - print_models() - elif args.print_model: - print(config.openai.model) - elif args.model: - config.openai.model = args.model - config.to_file(args.config) + if args.create: + Config.create_default(Path(args.create)) def question_cmd(args: argparse.Namespace, config: Config) -> None: @@ -95,6 +48,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: db_path=Path(config.db)) # if it's a new question, create and store it immediately if args.ask or args.create: + # FIXME: add sources to the question message = Message(question=Question(args.question), tags=args.ouput_tags, # FIXME ai=args.ai, @@ -128,25 +82,6 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: pass -def ask_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'ask' command. - """ - if args.max_tokens: - config.openai.max_tokens = args.max_tokens - if args.temperature: - config.openai.temperature = args.temperature - if args.model: - config.openai.model = args.model - chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.source_code_only) - otags = args.output_tags or [] - answers, usage = ai(chat, config, args.num_answers) - save_answers(question, answers, tags, otags, config) - print("-" * terminal_width()) - print(f"Usage: {usage}") - - def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. @@ -182,7 +117,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-C', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -227,22 +162,6 @@ def create_parser() -> argparse.ArgumentParser: question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') - # 'ask' command parser - ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.", - aliases=['a']) - ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', - required=True) - ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') - # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], help="Print chat history.", @@ -278,7 +197,7 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') - config_group.add_argument('-M', '--model', help="Set model in the config file") + config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', @@ -297,11 +216,12 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = Config.from_file(args.config) - openai_api_key(config.openai.api_key) - - command.func(command, config) + if command.func == config_cmd: + command.func(command) + else: + config = Config.from_file(args.config) + command.func(command, config) return 0 diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py deleted file mode 100644 index 8b9ed97..0000000 --- a/chatmastermind/storage.py +++ /dev/null @@ -1,121 +0,0 @@ -import yaml -import io -import pathlib -from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config -from typing import Any, Optional - - -def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: - with open(fname, "r") as fd: - tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() - # also support tags separated by ',' (old format) - separator = ',' if ',' in tagline else ' ' - tags = [t.strip() for t in tagline.split(separator)] - if tags_only: - return {"tags": tags} - text = fd.read().strip().split('\n') - question_idx = text.index("=== QUESTION ===") + 1 - answer_idx = text.index("==== ANSWER ====") - question = "\n".join(text[question_idx:answer_idx]).strip() - answer = "\n".join(text[answer_idx + 1:]).strip() - return {"question": question, "answer": answer, "tags": tags, - "file": fname.name} - - -def dump_data(data: dict[str, Any]) -> str: - with io.StringIO() as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - return fd.getvalue() - - -def write_file(fname: str, data: dict[str, Any]) -> None: - with open(fname, "w") as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - - -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]], - config: Config - ) -> None: - wtags = otags or tags - num, inum = 0, 0 - next_fname = pathlib.Path(str(config.db)) / '.next' - try: - with open(next_fname, 'r') as f: - num = int(f.read()) - except Exception: - pass - for answer in answers: - num += 1 - inum += 1 - title = f'-- ANSWER {inum} ' - title_end = '-' * (terminal_width() - len(title)) - print(f'{title}{title_end}') - print(answer) - write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) - with open(next_fname, 'w') as f: - f.write(f'{num}') - - -def create_chat_hist(question: Optional[str], - tags: Optional[list[str]], - extags: Optional[list[str]], - config: Config, - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> ChatType: - chat: ChatType = [] - append_message(chat, 'system', str(config.system).strip()) - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data['file'] = file.name - elif file.suffix == '.txt': - data = read_file(file) - else: - continue - data_tags = set(data.get('tags', [])) - tags_match: bool - if match_all_tags: - tags_match = not tags or set(tags).issubset(data_tags) - else: - tags_match = not tags or bool(data_tags.intersection(tags)) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat, with_tags, with_file) - if question: - append_message(chat, 'user', question) - return chat - - -def get_tags(config: Config, prefix: Optional[str]) -> list[str]: - result = [] - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif file.suffix == '.txt': - data = read_file(file, tags_only=True) - else: - continue - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return result - - -def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: - return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py deleted file mode 100644 index e6eeb97..0000000 --- a/chatmastermind/utils.py +++ /dev/null @@ -1,80 +0,0 @@ -import shutil -from pprint import PrettyPrinter -from typing import Any - -ChatType = list[dict[str, str]] - - -def terminal_width() -> int: - return shutil.get_terminal_size().columns - - -def pp(*args: Any, **kwargs: Any) -> None: - return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) - - -def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: - """ - Prints the tags specified in the given args. - """ - printed_messages = [] - - if tags: - printed_messages.append(f"Tags: {' '.join(tags)}") - if extags: - printed_messages.append(f"Excluding tags: {' '.join(extags)}") - if otags: - printed_messages.append(f"Output tags: {' '.join(otags)}") - - if printed_messages: - print("\n".join(printed_messages)) - print() - - -def append_message(chat: ChatType, - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: dict[str, str], - chat: ChatType, - with_tags: bool = False, - with_file: bool = False - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - if with_tags: - tags = " ".join(message['tags']) - append_message(chat, 'tags', tags) - if with_file: - append_message(chat, 'file', message['file']) - - -def display_source_code(content: str) -> None: - try: - content_start = content.index('```') - content_end = content.rindex('```') - if content_start + 3 < content_end: - print(content[content_start + 3:content_end].strip()) - except ValueError: - pass - - -def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: - if dump: - pp(chat) - return - for message in chat: - text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 - if source_code: - display_source_code(message['content']) - continue - if message['role'] == 'user': - print('-' * terminal_width()) - if text_too_long: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 91e6462..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,236 +0,0 @@ -# import unittest -# import io -# import pathlib -# import argparse -# from chatmastermind.utils import terminal_width -# from chatmastermind.main import create_parser, ask_cmd -# from chatmastermind.api_client import ai -# from chatmastermind.configuration import Config -# from chatmastermind.storage import create_chat_hist, save_answers, dump_data -# from unittest import mock -# from unittest.mock import patch, MagicMock, Mock, ANY - - -# class CmmTestCase(unittest.TestCase): -# """ -# Base class for all cmm testcases. -# """ -# def dummy_config(self, db: str) -> Config: -# """ -# Creates a dummy configuration. -# """ -# return Config.from_dict( -# {'system': 'dummy_system', -# 'db': db, -# 'openai': {'api_key': 'dummy_key', -# 'model': 'dummy_model', -# 'max_tokens': 4000, -# 'temperature': 1.0, -# 'top_p': 1, -# 'frequency_penalty': 0, -# 'presence_penalty': 0}} -# ) -# -# -# class TestCreateChat(CmmTestCase): -# -# def setUp(self) -> None: -# self.config = self.dummy_config(db='test_files') -# self.question = "test question" -# self.tags = ['test_tag'] -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['test_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 4) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['other_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 2) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.side_effect = ( -# io.StringIO(dump_data({'question': 'test_content', -# 'answer': 'some answer', -# 'tags': ['test_tag']})), -# io.StringIO(dump_data({'question': 'test_content2', -# 'answer': 'some answer2', -# 'tags': ['test_tag2']})), -# ) -# -# test_chat = create_chat_hist(self.question, [], None, self.config) -# -# self.assertEqual(len(test_chat), 6) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': 'test_content2'}) -# self.assertEqual(test_chat[4], -# {'role': 'assistant', 'content': 'some answer2'}) -# -# -# class TestHandleQuestion(CmmTestCase): -# -# def setUp(self) -> None: -# self.question = "test question" -# self.args = argparse.Namespace( -# or_tags=['tag1'], -# and_tags=None, -# exclude_tags=['xtag1'], -# output_tags=None, -# question=[self.question], -# source=None, -# source_code_only=False, -# num_answers=3, -# max_tokens=None, -# temperature=None, -# model=None, -# match_all_tags=False, -# with_tags=False, -# with_file=False, -# ) -# self.config = self.dummy_config(db='test_files') -# -# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") -# @patch("chatmastermind.main.print_tag_args") -# @patch("chatmastermind.main.print_chat_hist") -# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) -# @patch("chatmastermind.utils.pp") -# @patch("builtins.print") -# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, -# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, -# mock_create_chat_hist: MagicMock) -> None: -# open_mock = MagicMock() -# with patch("chatmastermind.storage.open", open_mock): -# ask_cmd(self.args, self.config) -# mock_print_tag_args.assert_called_once_with(self.args.or_tags, -# self.args.exclude_tags, -# []) -# mock_create_chat_hist.assert_called_once_with(self.question, -# self.args.or_tags, -# self.args.exclude_tags, -# self.config, -# match_all_tags=False, -# with_tags=False, -# with_file=False) -# mock_print_chat_hist.assert_called_once_with('test_chat', -# False, -# self.args.source_code_only) -# mock_ai.assert_called_with("test_chat", -# self.config, -# self.args.num_answers) -# expected_calls = [] -# for num, answer in enumerate(mock_ai.return_value[0], start=1): -# title = f'-- ANSWER {num} ' -# title_end = '-' * (terminal_width() - len(title)) -# expected_calls.append(((f'{title}{title_end}',),)) -# expected_calls.append(((answer,),)) -# expected_calls.append((("-" * terminal_width(),),)) -# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) -# self.assertEqual(mock_print.call_args_list, expected_calls) -# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) -# open_mock.assert_has_calls(open_expected_calls, any_order=True) -# -# -# class TestSaveAnswers(CmmTestCase): -# @mock.patch('builtins.open') -# @mock.patch('chatmastermind.storage.print') -# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: -# question = "Test question?" -# answers = ["Answer 1", "Answer 2"] -# tags = ["tag1", "tag2"] -# otags = ["otag1", "otag2"] -# config = self.dummy_config(db='test_db') -# -# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ -# mock.patch('chatmastermind.storage.yaml.dump'), \ -# mock.patch('io.StringIO') as stringio_mock: -# stringio_instance = stringio_mock.return_value -# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] -# save_answers(question, answers, tags, otags, config) -# -# open_calls = [ -# mock.call(pathlib.Path('test_db/.next'), 'r'), -# mock.call(pathlib.Path('test_db/.next'), 'w'), -# ] -# open_mock.assert_has_calls(open_calls, any_order=True) -# -# -# class TestAI(CmmTestCase): -# -# @patch("openai.ChatCompletion.create") -# def test_ai(self, mock_create: MagicMock) -> None: -# mock_create.return_value = { -# 'choices': [ -# {'message': {'content': 'response_text_1'}}, -# {'message': {'content': 'response_text_2'}} -# ], -# 'usage': {'tokens': 10} -# } -# -# chat = [{"role": "system", "content": "hello ai"}] -# config = self.dummy_config(db='dummy') -# config.openai.model = "text-davinci-002" -# config.openai.max_tokens = 150 -# config.openai.temperature = 0.5 -# -# result = ai(chat, config, 2) -# expected_result = (['response_text_1', 'response_text_2'], -# {'tokens': 10}) -# self.assertEqual(result, expected_result) -# -# -# class TestCreateParser(CmmTestCase): -# def test_create_parser(self) -> None: -# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: -# mock_cmdparser = Mock() -# mock_add_subparsers.return_value = mock_cmdparser -# parser = create_parser() -# self.assertIsInstance(parser, argparse.ArgumentParser) -# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) -# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) -# self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From ed567afbeac07b7426e230991e74d1f50a32bb97 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Fri, 8 Sep 2023 15:54:29 +0200 Subject: [PATCH 080/170] Make it possible to print just question or answer on printing files. --- chatmastermind/main.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 32e4ccd..c30ea4e 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -125,6 +125,10 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: sys.exit(1) if args.only_source_code: display_source_code(data['answer']) + elif args.answer: + print(data['answer'].strip()) + elif args.question: + print(data['question'].strip()) else: print(dump_data(data).strip()) @@ -213,8 +217,10 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - print_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', - action='store_true') + print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() + print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') + print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') + print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') argcomplete.autocomplete(parser) return parser -- 2.36.6 From b1a23394fc741f5038a4ac0b9d6772448d077f9d Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 13:31:01 +0200 Subject: [PATCH 081/170] cmm: splitted commands into separate modules (and more cleanup) --- chatmastermind/commands/config.py | 11 +++ chatmastermind/commands/hist.py | 23 +++++ chatmastermind/commands/print.py | 19 ++++ chatmastermind/commands/question.py | 57 ++++++++++++ chatmastermind/commands/tags.py | 17 ++++ chatmastermind/main.py | 131 +++++----------------------- setup.py | 2 +- tests/test_ai_factory.py | 48 ++++++++++ 8 files changed, 196 insertions(+), 112 deletions(-) create mode 100644 chatmastermind/commands/config.py create mode 100644 chatmastermind/commands/hist.py create mode 100644 chatmastermind/commands/print.py create mode 100644 chatmastermind/commands/question.py create mode 100644 chatmastermind/commands/tags.py create mode 100644 tests/test_ai_factory.py diff --git a/chatmastermind/commands/config.py b/chatmastermind/commands/config.py new file mode 100644 index 0000000..262164c --- /dev/null +++ b/chatmastermind/commands/config.py @@ -0,0 +1,11 @@ +import argparse +from pathlib import Path +from ..configuration import Config + + +def config_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'config' command. + """ + if args.create: + Config.create_default(Path(args.create)) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py new file mode 100644 index 0000000..88ed3be --- /dev/null +++ b/chatmastermind/commands/hist.py @@ -0,0 +1,23 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import MessageFilter + + +def hist_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'hist' command. + """ + + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py new file mode 100644 index 0000000..51e76f8 --- /dev/null +++ b/chatmastermind/commands/print.py @@ -0,0 +1,19 @@ +import sys +import argparse +from pathlib import Path +from ..configuration import Config +from ..message import Message, MessageError + + +def print_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'print' command. + """ + fname = Path(args.file) + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") + sys.exit(1) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py new file mode 100644 index 0000000..9c56ced --- /dev/null +++ b/chatmastermind/commands/question.py @@ -0,0 +1,57 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import Message, Question +from ..ai_factory import create_ai +from ..ai import AI, AIResponse + + +def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: + """ + Creates (and writes) a new message from the given arguments. + """ + # FIXME: add sources to the question + message = Message(question=Question(args.question), + tags=args.output_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + return message + + +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. + """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = create_message(chat, args) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass diff --git a/chatmastermind/commands/tags.py b/chatmastermind/commands/tags.py new file mode 100644 index 0000000..2906a5b --- /dev/null +++ b/chatmastermind/commands/tags.py @@ -0,0 +1,17 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB + + +def tags_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'tags' command. + """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + if args.list: + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 857bb5a..88121b4 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,12 +6,14 @@ import sys import argcomplete import argparse from pathlib import Path -from .configuration import Config, default_config_path -from .chat import ChatDB -from .message import Message, MessageFilter, MessageError, Question -from .ai_factory import create_ai -from .ai import AI, AIResponse from typing import Any +from .configuration import Config, default_config_path +from .message import Message +from .commands.question import question_cmd +from .commands.tags import tags_cmd +from .commands.config import config_cmd +from .commands.hist import hist_cmd +from .commands.print import print_cmd def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: @@ -19,101 +21,6 @@ def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) -def tags_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'tags' command. - """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - if args.list: - tags_freq = chat.tags_frequency(args.prefix, args.contain) - for tag, freq in tags_freq.items(): - print(f"- {tag}: {freq}") - # TODO: add renaming - - -def config_cmd(args: argparse.Namespace) -> None: - """ - Handler for the 'config' command. - """ - if args.create: - Config.create_default(Path(args.create)) - - -def question_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'question' command. - """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - # if it's a new question, create and store it immediately - if args.ask or args.create: - # FIXME: add sources to the question - message = Message(question=Question(args.question), - tags=args.ouput_tags, # FIXME - ai=args.ai, - model=args.model) - chat.add_to_cache([message]) - if args.create: - return - - # create the correct AI instance - ai: AI = create_ai(args, config) - if args.ask: - response: AIResponse = ai.request(message, - chat, - args.num_answers, # FIXME - args.otags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass - elif args.repeat: - lmessage = chat.latest_message() - assert lmessage - # TODO: repeat either the last question or the - # one(s) given in 'args.repeat' (overwrite - # existing ones if 'args.overwrite' is True) - pass - elif args.process: - # TODO: process either all questions without an - # answer or the one(s) given in 'args.process' - pass - - -def hist_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'hist' command. - """ - - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags, - question_contains=args.question, - answer_contains=args.answer) - chat = ChatDB.from_dir(Path('.'), - Path(config.db), - mfilter=mfilter) - chat.print(args.source_code_only, - args.with_tags, - args.with_files) - - -def print_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'print' command. - """ - fname = Path(args.file) - try: - message = Message.from_file(fname) - if message: - print(message.to_str(source_code_only=args.source_code_only)) - except MessageError: - print(f"File is not a valid message: {args.file}") - sys.exit(1) - - def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") @@ -128,20 +35,28 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', - help='List of tag names (one must match)', metavar='OTAGS') + help='List of tags (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', - help='List of tag names (all must match)', metavar='ATAGS') + help='List of tags (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', - help='List of tag names to exclude', metavar='XTAGS') + help='List of tags to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OUTTAGS') + help='List of output tags (default: use input tags)', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # a parent parser for all commands that support AI configuration + ai_parser = argparse.ArgumentParser(add_help=False) + ai_parser.add_argument('-A', '--AI', help='AI ID to use') + ai_parser.add_argument('-M', '--model', help='Model to use') + ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) + ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) + ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) + # 'question' command parser - question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], help="ask, create and process questions.", aliases=['q']) question_cmd_parser.set_defaults(func=question_cmd) @@ -152,12 +67,6 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - question_cmd_parser.add_argument('-A', '--AI', help='AI to use') - question_cmd_parser.add_argument('-M', '--model', help='Model to use') - question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') diff --git a/setup.py b/setup.py index 8484629..a311605 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages() + ["chatmastermind.ais"], + packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py new file mode 100644 index 0000000..d63970e --- /dev/null +++ b/tests/test_ai_factory.py @@ -0,0 +1,48 @@ +import argparse +import unittest +from unittest.mock import MagicMock +from chatmastermind.ai_factory import create_ai +from chatmastermind.configuration import Config +from chatmastermind.ai import AIError +from chatmastermind.ais.openai import OpenAI + + +class TestCreateAI(unittest.TestCase): + def setUp(self) -> None: + self.args = MagicMock(spec=argparse.Namespace) + self.args.ai = 'default' + self.args.model = None + self.args.max_tokens = None + self.args.temperature = None + + def test_create_ai_from_args(self) -> None: + # Create an AI with the default configuration + config = Config() + self.args.ai = 'default' + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_ai_from_default(self) -> None: + self.args.ai = None + # Create an AI with the default configuration + config = Config() + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_empty_ai_error(self) -> None: + self.args.ai = None + # Create Config with empty AIs + config = Config() + config.ais = {} + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) + + def test_create_unsupported_ai_error(self) -> None: + # Mock argparse.Namespace with ai='invalid_ai' + self.args.ai = 'invalid_ai' + # Create default Config + config = Config() + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) -- 2.36.6 From eaa399bcb90fa90ec5c2e6bc6e1fb100f8ade338 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 22:52:03 +0200 Subject: [PATCH 082/170] configuration et al: implemented new Config format --- chatmastermind/ai.py | 13 ++-- chatmastermind/ai_factory.py | 29 ++++++-- chatmastermind/ais/openai.py | 9 +-- chatmastermind/configuration.py | 119 ++++++++++++++++++++++++++------ 4 files changed, 134 insertions(+), 36 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index 4a8b914..e94de8e 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -33,18 +33,23 @@ class AI(Protocol): The base class for AI clients. """ + ID: str name: str config: AIConfig def request(self, question: Message, - context: Chat, + chat: Chat, num_answers: int = 1, otags: Optional[set[Tag]] = None) -> AIResponse: """ - Make an AI request, asking the given question with the given - context (i. e. chat history). The nr. of requested answers - corresponds to the nr. of messages in the 'AIResponse'. + Make an AI request. Parameters: + * question: the question to ask + * chat: the chat history to be added as context + * num_answers: nr. of requested answers (corresponds + to the nr. of messages in the 'AIResponse') + * otags: the output tags, i. e. the tags that all + returned messages should contain """ raise NotImplementedError diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c90366b..c4a063a 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -3,18 +3,35 @@ Creates different AI instances, based on the given configuration. """ import argparse -from .configuration import Config +from typing import cast +from .configuration import Config, OpenAIConfig, default_ai_ID from .ai import AI, AIError from .ais.openai import OpenAI def create_ai(args: argparse.Namespace, config: Config) -> AI: """ - Creates an AI subclass instance from the given args and configuration. + Creates an AI subclass instance from the given arguments + and configuration file. """ - if args.ai == 'openai': - # FIXME: create actual 'OpenAIConfig' and set values from 'args' - # FIXME: use actual name from config - return OpenAI("openai", config.openai) + if args.ai: + try: + ai_conf = config.ais[args.ai] + except KeyError: + raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + elif default_ai_ID in config.ais: + ai_conf = config.ais[default_ai_ID] + else: + raise AIError("No AI name given and no default exists") + + if ai_conf.name == 'openai': + ai = OpenAI(cast(OpenAIConfig, ai_conf)) + if args.model: + ai.config.model = args.model + if args.max_tokens: + ai.config.max_tokens = args.max_tokens + if args.temperature: + ai.config.temperature = args.temperature + return ai else: raise AIError(f"AI '{args.ai}' is not supported") diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 74438b8..14ce33f 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -17,9 +17,11 @@ class OpenAI(AI): The OpenAI AI client. """ - def __init__(self, name: str, config: OpenAIConfig) -> None: - self.name = name + def __init__(self, config: OpenAIConfig) -> None: + self.ID = config.ID + self.name = config.name self.config = config + openai.api_key = config.api_key def request(self, question: Message, @@ -31,8 +33,7 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ - # FIXME: use real 'system' message (store in OpenAIConfig) - oai_chat = self.openai_chat(chat, "system", question) + oai_chat = self.openai_chat(chat, self.config.system, question) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0780604..d82f913 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,17 +1,40 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional, ClassVar +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] +default_ai_ID: str = 'default' +default_config_path = '.config.yaml' + + +class ConfigError(Exception): + pass + + @dataclass class AIConfig: """ The base class of all AI configurations. """ - name: str + # the name of the AI the config class represents + # -> it's a class variable and thus not part of the + # dataclass constructor + name: ClassVar[str] + # a user-defined ID for an AI configuration entry + ID: str + + # the name must not be changed + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name': + raise AttributeError("'{name}' is not allowed to be changed") + else: + super().__setattr__(name, value) @dataclass @@ -19,21 +42,27 @@ class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + name: ClassVar[str] = 'openai' + + # all members have default values, so we can easily create + # a default configuration + ID: str = 'default' + api_key: str = '0123456789' + system: str = 'You are an assistant' + model: str = 'gpt-3.5-turbo-16k' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: """ Create OpenAIConfig from a dict. """ - return cls( - name='OpenAI', + res = cls( + system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -42,6 +71,30 @@ class OpenAIConfig(AIConfig): frequency_penalty=float(source['frequency_penalty']), presence_penalty=float(source['presence_penalty']) ) + # overwrite default ID if provided + if 'ID' in source: + res.ID = source['ID'] + return res + + +def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given name. + """ + if name.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"AI '{name}' is not supported") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais} @dataclass @@ -49,30 +102,52 @@ class Config: """ The configuration file structure. """ - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create Config from a dict. + Create Config from a dict (with the same format as the config file). """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for ID, conf in source['ais'].items(): + # add the AI ID to the config (for easy internal access) + conf['ID'] = ID + ai_conf = ai_config_instance(conf['name'], conf) + ais[ID] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f, sort_keys=False) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for conf in data['ais'].values(): + del (conf['ID']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) def as_dict(self) -> dict[str, Any]: - return asdict(self) + res = asdict(self) + for ID, conf in res['ais'].items(): + conf.update({'name': self.ais[ID].name}) + return res -- 2.36.6 From 76f23733972f6add31353ca987303e590c3e5b76 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 10:40:22 +0200 Subject: [PATCH 083/170] configuration: added tests --- chatmastermind/configuration.py | 2 +- tests/test_configuration.py | 160 ++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 tests/test_configuration.py diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index d82f913..398fa03 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -87,7 +87,7 @@ def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> else: return OpenAIConfig.from_dict(conf_dict) else: - raise ConfigError(f"AI '{name}' is not supported") + raise ConfigError(f"Unknown AI '{name}'") def create_default_ai_configs() -> dict[str, AIConfig]: diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 0000000..f3f9a98 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,160 @@ +import os +import unittest +import yaml +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import cast +from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config + + +class TestAIConfigInstance(unittest.TestCase): + def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None: + ai_config = cast(OpenAIConfig, ai_config_instance('openai')) + ai_reference = OpenAIConfig() + self.assertEqual(ai_config.ID, ai_reference.ID) + self.assertEqual(ai_config.name, ai_reference.name) + self.assertEqual(ai_config.api_key, ai_reference.api_key) + self.assertEqual(ai_config.system, ai_reference.system) + self.assertEqual(ai_config.model, ai_reference.model) + self.assertEqual(ai_config.temperature, ai_reference.temperature) + self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens) + self.assertEqual(ai_config.top_p, ai_reference.top_p) + self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty) + self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty) + + def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None: + conf_dict = { + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict)) + self.assertEqual(ai_config.system, 'Custom system') + self.assertEqual(ai_config.api_key, '9876543210') + self.assertEqual(ai_config.model, 'custom_model') + self.assertEqual(ai_config.max_tokens, 5000) + self.assertAlmostEqual(ai_config.temperature, 0.5) + self.assertAlmostEqual(ai_config.top_p, 0.8) + self.assertAlmostEqual(ai_config.frequency_penalty, 0.7) + self.assertAlmostEqual(ai_config.presence_penalty, 0.2) + + def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None: + with self.assertRaises(ConfigError): + ai_config_instance('invalid_name') + + +class TestConfig(unittest.TestCase): + def setUp(self) -> None: + self.test_file = NamedTemporaryFile(delete=False) + + def tearDown(self) -> None: + os.remove(self.test_file.name) + + def test_from_dict_should_create_config_from_dict(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + config = Config.from_dict(source_dict) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertEqual(config.ais['default'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + # check that 'ID' has been added + self.assertEqual(config.ais['default'].ID, 'default') + + def test_create_default_should_create_default_config(self) -> None: + Config.create_default(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + default_config = yaml.load(f, Loader=yaml.FullLoader) + config_reference = Config() + self.assertEqual(default_config['db'], config_reference.db) + + def test_from_file_should_load_config_from_file(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + config = Config.from_file(self.test_file.name) + self.assertIsInstance(config, Config) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertIsInstance(config.ais['default'], AIConfig) + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + + def test_to_file_should_save_config_to_file(self) -> None: + config = Config( + db='./test_db/', + ais={ + 'default': OpenAIConfig( + ID='default', + system='Custom system', + api_key='9876543210', + model='custom_model', + max_tokens=5000, + temperature=0.5, + top_p=0.8, + frequency_penalty=0.7, + presence_penalty=0.2 + ) + } + ) + config.to_file(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['db'], './test_db/') + self.assertEqual(len(saved_config['ais']), 1) + self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + + def test_from_file_error_unknown_ai(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'foobla', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + with self.assertRaises(ConfigError): + Config.from_file(self.test_file.name) -- 2.36.6 From c0b7d17587f45a4b28a8083cd067b7b427816627 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:51:17 +0200 Subject: [PATCH 084/170] question_cmd: fixes --- chatmastermind/commands/question.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 9c56ced..1709a3c 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path +from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB from ..message import Message, Question @@ -11,8 +12,26 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Creates (and writes) a new message from the given arguments. """ - # FIXME: add sources to the question - message = Message(question=Question(args.question), + question_parts = [] + question_list = args.question if args.question is not None else [] + source_list = args.source if args.source is not None else [] + + # FIXME: don't surround all sourced files with ``` + # -> do it only if '--source-code-only' is True and no source code + # could be extracted from that file + for question, source in zip_longest(question_list, source_list, fillvalue=None): + if question is not None and source is not None: + with open(source) as r: + question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") + elif question is not None: + question_parts.append(question) + elif source is not None: + with open(source) as r: + question_parts.append(f"```\n{r.read().strip()}\n```") + + full_question = '\n\n'.join(question_parts) + + message = Message(question=Question(full_question), tags=args.output_tags, # FIXME ai=args.ai, model=args.model) -- 2.36.6 From 5fb5dde550539d0612d9ce3b9ae223bcebdef6a2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:31:30 +0200 Subject: [PATCH 085/170] question cmd: added tests --- tests/test_question_cmd.py | 111 +++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_question_cmd.py diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py new file mode 100644 index 0000000..96b2fdf --- /dev/null +++ b/tests/test_question_cmd.py @@ -0,0 +1,111 @@ +import os +import unittest +import argparse +import tempfile +from pathlib import Path +from unittest.mock import MagicMock +from chatmastermind.commands.question import create_message +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB + + +class TestMessageCreate(unittest.TestCase): + """ + Test if messages created by the 'question' command have + the correct format. + """ + def setUp(self) -> None: + # create ChatDB structure + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), + db_path=Path(self.db_path.name)) + # create arguments mock + self.args = MagicMock(spec=argparse.Namespace) + self.args.source = None + self.args.source_code_only = False + self.args.ai = None + self.args.model = None + self.args.output_tags = None + # create some files for sourcing + self.source_file1 = tempfile.NamedTemporaryFile(delete=False) + self.source_file1_content = """This is just text. +No source code. +Nope. Go look elsewhere!""" + with open(self.source_file1.name, 'w') as f: + f.write(self.source_file1_content) + self.source_file2 = tempfile.NamedTemporaryFile(delete=False) + self.source_file2_content = """This is just text. +``` +This is embedded source code. +``` +And some text again.""" + with open(self.source_file2.name, 'w') as f: + f.write(self.source_file2_content) + self.source_file3 = tempfile.NamedTemporaryFile(delete=False) + self.source_file3_content = """This is all source code. +Yes, really. +Language is called 'brainfart'.""" + with open(self.source_file3.name, 'w') as f: + f.write(self.source_file3_content) + + def tearDown(self) -> None: + os.remove(self.source_file1.name) + os.remove(self.source_file2.name) + os.remove(self.source_file3.name) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return list(Path(tmp_dir.name).glob('*.[ty]*')) + + def test_message_file_created(self) -> None: + self.args.question = ["What is this?"] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + create_message(self.chat, self.args) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + message = Message.from_file(cache_dir_files[0]) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] + + def test_single_question(self) -> None: + self.args.question = ["What is this?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) + self.assertEqual(len(message.question.source_code()), 0) + + def test_multipart_question(self) -> None: + self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("""What is this + +'bard' thing? + +Is it good?""")) + + def test_single_question_with_text_only_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file1.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains no source code + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file1_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file2_content}""")) -- 2.36.6 From 3ef1339cc0d9b58103599c068b464b2b93183641 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 11:53:32 +0200 Subject: [PATCH 086/170] Fix extracting source file with type specification. --- chatmastermind/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index bd80e4f..6543ce1 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -55,9 +55,10 @@ def message_to_chat(message: dict[str, str], def display_source_code(content: str) -> None: try: content_start = content.index('```') + content_start = content.index('\n', content_start) + 1 content_end = content.rindex('```') - if content_start + 3 < content_end: - print(content[content_start + 3:content_end].strip()) + if content_start < content_end: + print(content[content_start:content_end].strip()) except ValueError: pass -- 2.36.6 From 7cf62c54efd0a6ca6a23eacbd7a7bd716962f97f Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:16:17 +0200 Subject: [PATCH 087/170] Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. --- chatmastermind/commands/question.py | 22 ++++++++++++---------- chatmastermind/main.py | 5 ++--- tests/test_question_cmd.py | 22 ++++++++++++++++++---- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 1709a3c..818b1de 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -15,19 +15,21 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: question_parts = [] question_list = args.question if args.question is not None else [] source_list = args.source if args.source is not None else [] + code_list = args.source_code if args.source_code is not None else [] - # FIXME: don't surround all sourced files with ``` - # -> do it only if '--source-code-only' is True and no source code - # could be extracted from that file - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: + for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + if question is not None and len(question.strip()) > 0: question_parts.append(question) - elif source is not None: + if source is not None and len(source) > 0: with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + if code is not None and len(code) > 0: + with open(code) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 88121b4..f7163ab 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,9 +67,8 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 96b2fdf..06cc527 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -23,7 +23,7 @@ class TestMessageCreate(unittest.TestCase): # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source = None - self.args.source_code_only = False + self.args.source_code = None self.args.ai = None self.args.model = None self.args.output_tags = None @@ -94,11 +94,11 @@ Is it good?""")) # source file contains no source code # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_embedded_source_source(self) -> None: self.args.question = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) @@ -106,6 +106,20 @@ Is it good?""")) # source file contains 1 source code block # -> expect it in the question self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file2_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source_code = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question(f"""What is this? + +``` +{self.source_file2_content} +```""")) -- 2.36.6 From d22877a0f1a206ed7697fe4d773ef576bbf30aa3 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:38:40 +0200 Subject: [PATCH 088/170] Port print arguments -q/-a/-S from main to restructuring. --- chatmastermind/commands/print.py | 10 +++++++++- chatmastermind/main.py | 6 ++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py index 51e76f8..3d2b990 100644 --- a/chatmastermind/commands/print.py +++ b/chatmastermind/commands/print.py @@ -13,7 +13,15 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: try: message = Message.from_file(fname) if message: - print(message.to_str(source_code_only=args.source_code_only)) + if args.question: + print(message.question) + elif args.answer: + print(message.answer) + elif message.answer and args.only_source_code: + for code in message.answer.source_code(): + print(code) + else: + print(message.to_str()) except MessageError: print(f"File is not a valid message: {args.file}") sys.exit(1) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index f7163ab..eadb095 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -113,8 +113,10 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', - action='store_true') + print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() + print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') + print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') + print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') argcomplete.autocomplete(parser) return parser -- 2.36.6 From 39b518a8a60f335fb952995ee151440f899c7f85 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 16:05:27 +0200 Subject: [PATCH 089/170] Small fixes. --- chatmastermind/ai_factory.py | 8 ++++---- chatmastermind/commands/question.py | 6 +++--- tests/test_ai_factory.py | 10 +++++----- tests/test_question_cmd.py | 14 +++++++------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c4a063a..bc4583c 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -14,11 +14,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: Creates an AI subclass instance from the given arguments and configuration file. """ - if args.ai: + if args.AI: try: - ai_conf = config.ais[args.ai] + ai_conf = config.ais[args.AI] except KeyError: - raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") elif default_ai_ID in config.ais: ai_conf = config.ais[default_ai_ID] else: @@ -34,4 +34,4 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: ai.config.temperature = args.temperature return ai else: - raise AIError(f"AI '{args.ai}' is not supported") + raise AIError(f"AI '{args.AI}' is not supported") diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 818b1de..90b782b 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -13,7 +13,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: Creates (and writes) a new message from the given arguments. """ question_parts = [] - question_list = args.question if args.question is not None else [] + question_list = args.ask if args.ask is not None else [] source_list = args.source if args.source is not None else [] code_list = args.source_code if args.source_code is not None else [] @@ -35,7 +35,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: message = Message(question=Question(full_question), tags=args.output_tags, # FIXME - ai=args.ai, + ai=args.AI, model=args.model) chat.add_to_cache([message]) return message @@ -59,7 +59,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME - args.otags) # FIXME + args.output_tags) # FIXME assert response # TODO: # * add answer to the message above (and create diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d63970e..d00b319 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.ai = 'default' + self.args.AI = 'default' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,19 +18,19 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.ai = 'default' + self.args.AI = 'default' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_ai_from_default(self) -> None: - self.args.ai = None + self.args.AI = None # Create an AI with the default configuration config = Config() ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_empty_ai_error(self) -> None: - self.args.ai = None + self.args.AI = None # Create Config with empty AIs config = Config() config.ais = {} @@ -40,7 +40,7 @@ class TestCreateAI(unittest.TestCase): def test_create_unsupported_ai_error(self) -> None: # Mock argparse.Namespace with ai='invalid_ai' - self.args.ai = 'invalid_ai' + self.args.AI = 'invalid_ai' # Create default Config config = Config() # Call create_ai function and assert that it raises AIError diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 06cc527..aa0dc25 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -24,7 +24,7 @@ class TestMessageCreate(unittest.TestCase): self.args = MagicMock(spec=argparse.Namespace) self.args.source = None self.args.source_code = None - self.args.ai = None + self.args.AI = None self.args.model = None self.args.output_tags = None # create some files for sourcing @@ -59,7 +59,7 @@ Language is called 'brainfart'.""" return list(Path(tmp_dir.name).glob('*.[ty]*')) def test_message_file_created(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) create_message(self.chat, self.args) @@ -70,14 +70,14 @@ Language is called 'brainfart'.""" self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] def test_single_question(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("What is this?")) self.assertEqual(len(message.question.source_code()), 0) def test_multipart_question(self) -> None: - self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + self.args.ask = ["What is this", "'bard' thing?", "Is it good?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("""What is this @@ -87,7 +87,7 @@ Language is called 'brainfart'.""" Is it good?""")) def test_single_question_with_text_only_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -99,7 +99,7 @@ Is it good?""")) {self.source_file1_content}""")) def test_single_question_with_embedded_source_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -111,7 +111,7 @@ Is it good?""")) {self.source_file2_content}""")) def test_single_question_with_embedded_source_code_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) -- 2.36.6 From 53582a71239e01b0c2a6cef6a0529e2a082d3118 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 18:28:10 +0200 Subject: [PATCH 090/170] question_cmd: fixed source code extraction and added a testcase --- chatmastermind/commands/question.py | 17 +++++-- chatmastermind/main.py | 2 +- chatmastermind/message.py | 2 +- tests/test_question_cmd.py | 79 +++++++++++++++++++++-------- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 90b782b..756a051 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question +from ..message import Message, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ question_parts = [] question_list = args.ask if args.ask is not None else [] - source_list = args.source if args.source is not None else [] - code_list = args.source_code if args.source_code is not None else [] + text_files = args.source_text if args.source_text is not None else [] + code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) if source is not None and len(source) > 0: @@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code is not None and len(code) > 0: with open(code) as r: content = r.read().strip() - if len(content) > 0: + if len(content) == 0: + continue + # try to extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + # if there's none, add the whole file + else: question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index eadb095..99aca09 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -67,7 +67,7 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 35de3b9..7107c13 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -414,7 +414,7 @@ class Message(): return '\n'.join(output) def __str__(self) -> str: - return self.to_str(False, False, False) + return self.to_str(True, True, False) def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index aa0dc25..40ea4d8 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -22,18 +22,19 @@ class TestMessageCreate(unittest.TestCase): db_path=Path(self.db_path.name)) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) - self.args.source = None + self.args.source_text = None self.args.source_code = None self.args.AI = None self.args.model = None self.args.output_tags = None - # create some files for sourcing + # File 1 : no source code block, only text self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1_content = """This is just text. No source code. Nope. Go look elsewhere!""" with open(self.source_file1.name, 'w') as f: f.write(self.source_file1_content) + # File 2 : one embedded source code block self.source_file2 = tempfile.NamedTemporaryFile(delete=False) self.source_file2_content = """This is just text. ``` @@ -42,12 +43,26 @@ This is embedded source code. And some text again.""" with open(self.source_file2.name, 'w') as f: f.write(self.source_file2_content) + # File 3 : all source code self.source_file3 = tempfile.NamedTemporaryFile(delete=False) self.source_file3_content = """This is all source code. Yes, really. Language is called 'brainfart'.""" with open(self.source_file3.name, 'w') as f: f.write(self.source_file3_content) + # File 4 : two source code blocks + self.source_file4 = tempfile.NamedTemporaryFile(delete=False) + self.source_file4_content = """This is just text. +``` +This is embedded source code. +``` +And some text again. +``` +This is embedded source code. +``` +Aaaand again some text.""" + with open(self.source_file4.name, 'w') as f: + f.write(self.source_file4_content) def tearDown(self) -> None: os.remove(self.source_file1.name) @@ -86,40 +101,62 @@ Language is called 'brainfart'.""" Is it good?""")) - def test_single_question_with_text_only_source(self) -> None: + def test_single_question_with_text_only_file(self) -> None: self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file1.name}"] + self.args.source_text = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains no source code + # file contains no source code (only text) # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_source(self) -> None: - self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file2.name}"] - message = create_message(self.chat, self.args) - self.assertIsInstance(message, Message) - # source file contains 1 source code block - # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question(f"""What is this? - -{self.source_file2_content}""")) - - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_text_file_and_embedded_code(self) -> None: self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains 1 source code block + # file contains 1 source code block # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` +""")) + + def test_single_question_with_code_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file3.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file is complete source code + self.assertEqual(len(message.question.source_code()), 1) self.assertEqual(message.question, Question(f"""What is this? ``` -{self.source_file2_content} +{self.source_file3_content} ```""")) + + def test_single_question_with_text_file_and_multi_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file4.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 2 source code blocks + # -> expect them in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` + + +``` +This is embedded source code. +``` +""")) -- 2.36.6 From 1e3bfdd67fc13437f6f7da72468572e24f9e9818 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:39:00 +0200 Subject: [PATCH 091/170] chat: added 'update_messages()' function and test --- chatmastermind/chat.py | 16 ++++++++++++++++ tests/test_chat.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4e8fb20..ddabb56 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -386,3 +386,19 @@ class ChatDB(Chat): msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): m.to_file() + + def update_messages(self, messages: list[Message], write: bool = True) -> None: + """ + Update existing messages. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. Only accepts + existing messages. + """ + if any(not message_in(m, self.messages) for m in messages): + raise ChatError("Can't update messages that are not in the internal list") + # remove old versions and add new ones + self.messages = [m for m in self.messages if not message_in(m, messages)] + self.messages += messages + self.sort() + # write the UPDATED messages if requested + if write: + self.write_messages(messages) diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e4aa8c..ed630a4 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -440,3 +440,31 @@ class TestChatDB(unittest.TestCase): cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + + def test_chat_db_update_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + message = chat_db.messages[0] + message.answer = Answer("New answer") + # update message without writing + chat_db.update_messages([message], write=False) + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # re-read the message and check for old content + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) + # now check with writing (message should be overwritten) + chat_db.update_messages([message], write=True) + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # test without file_path -> expect error + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.update_messages([message1]) -- 2.36.6 From dd3d3ffc82abd110b66a4e0af6ce2d990d702b7c Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:18:14 +0200 Subject: [PATCH 092/170] chat: added check for existing files when creating new filenames --- chatmastermind/chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ddabb56..7c4dd35 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -62,7 +62,10 @@ def make_file_path(dir_path: Path, Create a file_path for the given directory using the given file_suffix and ID generator function. """ - return dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + while file_path.exists(): + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + return file_path def write_dir(dir_path: Path, -- 2.36.6 From cf50818f28f389e2a66ad919f78c83aa41933dfe Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:52:07 +0200 Subject: [PATCH 093/170] question_cmd: fixed '--ask' command --- chatmastermind/ai.py | 6 ++++++ chatmastermind/ais/openai.py | 19 ++++++++++++++----- chatmastermind/commands/question.py | 15 ++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index e94de8e..b97b5f1 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -66,3 +66,9 @@ class AI(Protocol): and is not implemented for all AIs. """ raise NotImplementedError + + def print(self) -> None: + """ + Print some info about the current AI, like system message. + """ + pass diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 14ce33f..1db4d20 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -43,16 +43,20 @@ class OpenAI(AI): n=num_answers, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - answers: list[Message] = [] - for choice in response['choices']: # type: ignore + question.answer = Answer(response['choices'][0]['message']['content']) + question.tags = otags + question.ai = self.name + question.model = self.config.model + answers: list[Message] = [question] + for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, ai=self.name, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt'], - response['usage']['completion'], - response['usage']['total'])) + return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], + response['usage']['completion_tokens'], + response['usage']['total_tokens'])) def models(self) -> list[str]: """ @@ -95,3 +99,8 @@ class OpenAI(AI): def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError + + def print(self) -> None: + print(f"MODEL: {self.config.model}") + print("=== SYSTEM ===") + print(self.config.system) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 756a051..fdabd62 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: # create the correct AI instance ai: AI = create_ai(args, config) if args.ask: + ai.print() + chat.print(paged=False) response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME args.output_tags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass + chat.update_messages([response.messages[0]]) + chat.add_to_cache(response.messages[1:]) + for idx, msg in enumerate(response.messages): + print(f"=== ANSWER {idx+1} ===") + print(msg.answer) + if response.tokens: + print("===============") + print(response.tokens) elif args.repeat: lmessage = chat.latest_message() assert lmessage -- 2.36.6 From 533ee1c1a94d40f0ed9fa4ed70f09947010ba65b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:54:17 +0200 Subject: [PATCH 094/170] question_cmd: added message filtering by tags --- chatmastermind/commands/question.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index fdabd62..f439447 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question, source_code +from ..message import Message, MessageFilter, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -52,8 +52,12 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. """ + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags) chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) + db_path=Path(config.db), + mfilter=mfilter) # if it's a new question, create and store it immediately if args.ask or args.create: message = create_message(chat, args) @@ -77,14 +81,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: if response.tokens: print("===============") print(response.tokens) - elif args.repeat: + elif args.repeat is not None: lmessage = chat.latest_message() assert lmessage # TODO: repeat either the last question or the # one(s) given in 'args.repeat' (overwrite # existing ones if 'args.overwrite' is True) pass - elif args.process: + elif args.process is not None: # TODO: process either all questions without an # answer or the one(s) given in 'args.process' pass -- 2.36.6 From b48667bfa0347e1237bb555fd5b2fb2e7514c621 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:55:47 +0200 Subject: [PATCH 095/170] openai: stores AI.ID instead of AI.name in message --- chatmastermind/ais/openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 1db4d20..a388a7a 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -45,14 +45,14 @@ class OpenAI(AI): presence_penalty=self.config.presence_penalty) question.answer = Answer(response['choices'][0]['message']['content']) question.tags = otags - question.ai = self.name + question.ai = self.ID question.model = self.config.model answers: list[Message] = [question] for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, - ai=self.name, + ai=self.ID, model=self.config.model)) return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], response['usage']['completion_tokens'], -- 2.36.6 From eca44b14cb9b810c62dc896c30b6994a3eb0f757 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:24:20 +0200 Subject: [PATCH 096/170] message: fixed matching with empty tag sets --- chatmastermind/message.py | 4 ++-- tests/test_chat.py | 22 ++++++++++++++++++++-- tests/test_message.py | 6 ++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 7107c13..df59ed6 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -312,7 +312,7 @@ class Message(): mfilter.tags_not if mfilter else None) else: message = cls.__from_file_yaml(file_path) - if message and (not mfilter or (mfilter and message.match(mfilter))): + if message and (mfilter is None or message.match(mfilter)): return message else: return None @@ -508,7 +508,7 @@ class Message(): Return True if all attributes match, else False. """ mytags = self.tags or set() - if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 diff --git a/tests/test_chat.py b/tests/test_chat.py index ed630a4..1916a2b 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_filter(self) -> None: + def test_chat_db_from_dir_filter_tags(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or={Tag('tag1')})) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + + def test_chat_db_from_dir_filter_tags_empty(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or=set(), + tags_and=set(), + tags_not=set())) + self.assertEqual(len(chat_db.messages), 0) + + def test_chat_db_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messges(self) -> None: + def test_chat_db_from_messages(self) -> None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, diff --git a/tests/test_message.py b/tests/test_message.py index 57d5982..1f440df 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -300,6 +300,12 @@ This is a question. MessageFilter(tags_or={Tag('tag1')})) self.assertIsNone(message) + def test_from_file_txt_empty_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or=set(), + tags_and=set())) + self.assertIsNone(message) + def test_from_file_txt_no_tags_match_tags_not(self) -> None: message = Message.from_file(self.file_path_min, MessageFilter(tags_not={Tag('tag1')})) -- 2.36.6 From 6f71a2ff691105b25593ae00d5053443a1ab768b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:56:50 +0200 Subject: [PATCH 097/170] message: to_file() now uses intermediate temporary file --- chatmastermind/message.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index df59ed6..64929a3 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,6 +3,8 @@ Module implementing message related functions and classes. """ import pathlib import yaml +import tempfile +import shutil from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -445,16 +447,18 @@ class Message(): * Answer.txt_header * Answer """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) if self.tags: - fd.write(f'{TagLine.from_set(self.tags)}\n') + temp_fd.write(f'{TagLine.from_set(self.tags)}\n') if self.ai: - fd.write(f'{AILine.from_ai(self.ai)}\n') + temp_fd.write(f'{AILine.from_ai(self.ai)}\n') if self.model: - fd.write(f'{ModelLine.from_model(self.model)}\n') - fd.write(f'{Question.txt_header}\n{self.question}\n') + temp_fd.write(f'{ModelLine.from_model(self.model)}\n') + temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: """ @@ -466,7 +470,8 @@ class Message(): * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) data: YamlDict = {Question.yaml_key: str(self.question)} if self.answer: data[Answer.yaml_key] = str(self.answer) @@ -476,7 +481,8 @@ class Message(): data[self.model_yaml_key] = self.model if self.tags: data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) - yaml.dump(data, fd, sort_keys=False) + yaml.dump(data, temp_fd, sort_keys=False) + shutil.move(temp_file_path, file_path) def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ -- 2.36.6 From 59b851650ad59ea61df4774c33ed7e624e98e13b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:25:33 +0200 Subject: [PATCH 098/170] question_cmd: when no tags are specified, no tags are selected --- chatmastermind/commands/question.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index f439447..4936d8f 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -52,9 +52,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. """ - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags) + mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), + tags_and=args.and_tags if args.and_tags is not None else set(), + tags_not=args.exclude_tags if args.exclude_tags is not None else set()) chat = ChatDB.from_dir(cache_path=Path('.'), db_path=Path(config.db), mfilter=mfilter) -- 2.36.6 From c143c001f905dda3154a153ad8ccaed1bc24a5f4 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:37:06 +0200 Subject: [PATCH 099/170] configuration: improved config file format --- chatmastermind/configuration.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 398fa03..08f6cbe 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -17,6 +17,18 @@ class ConfigError(Exception): pass +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + @dataclass class AIConfig: """ @@ -48,13 +60,13 @@ class OpenAIConfig(AIConfig): # a default configuration ID: str = 'default' api_key: str = '0123456789' - system: str = 'You are an assistant' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 max_tokens: int = 4000 top_p: float = 1.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + system: str = 'You are an assistant' @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: @@ -62,14 +74,14 @@ class OpenAIConfig(AIConfig): Create OpenAIConfig from a dict. """ res = cls( - system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), temperature=float(source['temperature']), top_p=float(source['top_p']), frequency_penalty=float(source['frequency_penalty']), - presence_penalty=float(source['presence_penalty']) + presence_penalty=float(source['presence_penalty']), + system=str(source['system']) ) # overwrite default ID if provided if 'ID' in source: @@ -148,6 +160,8 @@ class Config: def as_dict(self) -> dict[str, Any]: res = asdict(self) + # add the AI name manually (as first element) + # (not done by 'asdict' because it's a class variable) for ID, conf in res['ais'].items(): - conf.update({'name': self.ais[ID].name}) + res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf} return res -- 2.36.6 From d4021eeb110c4d7e9ac0ee41f68e92ad1e12cf22 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 11 Sep 2023 07:38:49 +0200 Subject: [PATCH 100/170] configuration: made 'default' AI ID optional --- chatmastermind/ai_factory.py | 18 ++++++++++++------ chatmastermind/configuration.py | 3 +-- tests/test_ai_factory.py | 4 ++-- tests/test_configuration.py | 14 +++++++------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index bc4583c..420b287 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration. import argparse from typing import cast -from .configuration import Config, OpenAIConfig, default_ai_ID +from .configuration import Config, AIConfig, OpenAIConfig from .ai import AI, AIError from .ais.openai import OpenAI -def create_ai(args: argparse.Namespace, config: Config) -> AI: +def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 """ Creates an AI subclass instance from the given arguments - and configuration file. + and configuration file. If AI has not been set in the + arguments, it searches for the ID 'default'. If that + is not found, it uses the first AI in the list. """ + ai_conf: AIConfig if args.AI: try: ai_conf = config.ais[args.AI] except KeyError: raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") - elif default_ai_ID in config.ais: - ai_conf = config.ais[default_ai_ID] + elif 'default' in config.ais: + ai_conf = config.ais['default'] else: - raise AIError("No AI name given and no default exists") + try: + ai_conf = next(iter(config.ais.values())) + except StopIteration: + raise AIError("No AI found in this configuration") if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 08f6cbe..5397f4a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') supported_ais: list[str] = ['openai'] -default_ai_ID: str = 'default' default_config_path = '.config.yaml' @@ -58,7 +57,7 @@ class OpenAIConfig(AIConfig): # all members have default values, so we can easily create # a default configuration - ID: str = 'default' + ID: str = 'myopenai' api_key: str = '0123456789' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d00b319..9cb94d3 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.AI = 'default' + self.args.AI = 'myopenai' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.AI = 'default' + self.args.AI = 'myopenai' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index f3f9a98..ba8a5aa 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase): source_dict = { 'db': './test_db/', 'ais': { - 'default': { + 'myopenai': { 'name': 'openai', 'system': 'Custom system', 'api_key': '9876543210', @@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase): config = Config.from_dict(source_dict) self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) - self.assertEqual(config.ais['default'].name, 'openai') - self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + self.assertEqual(config.ais['myopenai'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') # check that 'ID' has been added - self.assertEqual(config.ais['default'].ID, 'default') + self.assertEqual(config.ais['myopenai'].ID, 'myopenai') def test_create_default_should_create_default_config(self) -> None: Config.create_default(Path(self.test_file.name)) @@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase): config = Config( db='./test_db/', ais={ - 'default': OpenAIConfig( - ID='default', + 'myopenai': OpenAIConfig( + ID='myopenai', system='Custom system', api_key='9876543210', model='custom_model', @@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase): saved_config = yaml.load(f, Loader=yaml.FullLoader) self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(len(saved_config['ais']), 1) - self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') def test_from_file_error_unknown_ai(self) -> None: source_dict = { -- 2.36.6 From 8bd659e888b37a45faf133b2ac2f4eaaca825a39 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 16 Aug 2023 17:07:01 +0200 Subject: [PATCH 101/170] added new module 'tags.py' with classes 'Tag' and 'TagLine' --- chatmastermind/tags.py | 130 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 chatmastermind/tags.py diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py new file mode 100644 index 0000000..28583a2 --- /dev/null +++ b/chatmastermind/tags.py @@ -0,0 +1,130 @@ +""" +Module implementing tag related functions and classes. +""" +from typing import Type, TypeVar, Optional + +TagInst = TypeVar('TagInst', bound='Tag') +TagLineInst = TypeVar('TagLineInst', bound='TagLine') + + +class TagError(Exception): + pass + + +class Tag(str): + """ + A single tag. A string that can contain anything but the default separator (' '). + """ + # default separator + default_separator = ' ' + # alternative separators (e. g. for backwards compatibility) + alternative_separators = [','] + + def __new__(cls: Type[TagInst], string: str) -> TagInst: + """ + Make sure the tag string does not contain the default separator. + """ + if cls.default_separator in string: + raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'") + instance = super().__new__(cls, string) + return instance + + +class TagLine(str): + """ + A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, + separated by the defaut separator (' '). Any operations on a TagLine will sort + the tags. + """ + # the prefix + prefix = 'TAGS:' + + def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: + """ + Make sure the tagline string starts with the prefix. + """ + if not string.startswith(cls.prefix): + raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst: + """ + Create a new TagLine from a set of tags. + """ + return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + + def tags(self) -> set[Tag]: + """ + Returns all tags contained in this line as a set. + """ + tagstr = self[len(self.prefix):].strip() + separator = Tag.default_separator + # look for alternative separators and use the first one found + # -> we don't support different separators in the same TagLine + for s in Tag.alternative_separators: + if s in tagstr: + separator = s + break + return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + + def merge(self, taglines: set['TagLine']) -> 'TagLine': + """ + Merges the tags of all given taglines into the current one + and returns a new TagLine. + """ + merged_tags = self.tags() + for tl in taglines: + merged_tags |= tl.tags() + return self.from_set(set(sorted(merged_tags))) + + def delete_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Deletes the given tags and returns a new TagLine. + """ + return self.from_set(self.tags().difference(tags)) + + def add_tags(self, tags: set[Tag]) -> 'TagLine': + """ + Adds the given tags and returns a new TagLine. + """ + return self.from_set(set(sorted(self.tags() | tags))) + + def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + """ + Renames the given tags and returns a new TagLine. The first + tuple element is the old name, the second one is the new name. + """ + new_tags = self.tags() + for t in tags: + if t[0] in new_tags: + new_tags.remove(t[0]) + new_tags.add(t[1]) + return self.from_set(set(sorted(new_tags))) + + def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the current TagLine matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + tag_set = self.tags() + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing -- 2.36.6 From 2d456e68f187cbd89cf048865bb5256abf2630c0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 17 Aug 2023 08:28:15 +0200 Subject: [PATCH 102/170] added testcases for Tag and TagLine classes --- tests/test_main.py | 114 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index db5fcdb..eb13dc5 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,6 +7,7 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data +from chatmastermind.tags import Tag, TagLine, TagError from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -231,3 +232,116 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From 061e5f8682be086d641db1dc6c0e02a10910ee83 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 12:11:56 +0200 Subject: [PATCH 103/170] tags.py: converted most TagLine functions to module functions --- chatmastermind/tags.py | 99 ++++++++++++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 28 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 28583a2..bfe5fd5 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -30,6 +30,67 @@ class Tag(str): return instance +def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]: + """ + Deletes the given tags and returns a new set. + """ + return tags.difference(tags_delete) + + +def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]: + """ + Adds the given tags and returns a new set. + """ + return set(sorted(tags | tags_add)) + + +def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]: + """ + Merges the tags in 'tags_merge' into the current one and returns a new set. + """ + for ts in tags_merge: + tags |= ts + return tags + + +def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]: + """ + Renames the given tags and returns a new set. The first tuple element + is the old name, the second one is the new name. + """ + for t in tags_rename: + if t[0] in tags: + tags.remove(t[0]) + tags.add(t[1]) + return set(sorted(tags)) + + +def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], + tags_not: Optional[set[Tag]]) -> bool: + """ + Checks if the given set 'tags' matches the given tag requirements: + - 'tags_or' : matches if this TagLine contains ANY of those tags + - 'tags_and': matches if this TagLine contains ALL of those tags + - 'tags_not': matches if this TagLine contains NONE of those tags + + Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and', + i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' + or all of the tags in 'tags_and' but it must never contain any of the tags in + 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag + exclusion is still done if 'tags_not' is not 'None'). + """ + required_tags_present = False + excluded_tags_missing = False + if ((tags_or is None and tags_and is None) + or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503 + or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503 + required_tags_present = True + if ((tags_not is None) + or (not any(tag in tags for tag in tags_not))): # noqa: W503 + excluded_tags_missing = True + return required_tags_present and excluded_tags_missing + + class TagLine(str): """ A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, @@ -71,37 +132,29 @@ class TagLine(str): def merge(self, taglines: set['TagLine']) -> 'TagLine': """ - Merges the tags of all given taglines into the current one - and returns a new TagLine. + Merges the tags of all given taglines into the current one and returns a new TagLine. """ - merged_tags = self.tags() - for tl in taglines: - merged_tags |= tl.tags() - return self.from_set(set(sorted(merged_tags))) + tags_merge = [tl.tags() for tl in taglines] + return self.from_set(merge_tags(self.tags(), tags_merge)) - def delete_tags(self, tags: set[Tag]) -> 'TagLine': + def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine': """ Deletes the given tags and returns a new TagLine. """ - return self.from_set(self.tags().difference(tags)) + return self.from_set(delete_tags(self.tags(), tags_delete)) - def add_tags(self, tags: set[Tag]) -> 'TagLine': + def add_tags(self, tags_add: set[Tag]) -> 'TagLine': """ Adds the given tags and returns a new TagLine. """ - return self.from_set(set(sorted(self.tags() | tags))) + return self.from_set(add_tags(self.tags(), tags_add)) - def rename_tags(self, tags: set[tuple[Tag, Tag]]) -> 'TagLine': + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine': """ Renames the given tags and returns a new TagLine. The first tuple element is the old name, the second one is the new name. """ - new_tags = self.tags() - for t in tags: - if t[0] in new_tags: - new_tags.remove(t[0]) - new_tags.add(t[1]) - return self.from_set(set(sorted(new_tags))) + return self.from_set(rename_tags(self.tags(), tags_rename)) def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]], tags_not: Optional[set[Tag]]) -> bool: @@ -117,14 +170,4 @@ class TagLine(str): 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag exclusion is still done if 'tags_not' is not 'None'). """ - tag_set = self.tags() - required_tags_present = False - excluded_tags_missing = False - if ((tags_or is None and tags_and is None) - or (tags_or and any(tag in tag_set for tag in tags_or)) # noqa: W503 - or (tags_and and all(tag in tag_set for tag in tags_and))): # noqa: W503 - required_tags_present = True - if ((tags_not is None) - or (not any(tag in tag_set for tag in tags_not))): # noqa: W503 - excluded_tags_missing = True - return required_tags_present and excluded_tags_missing + return match_tags(self.tags(), tags_or, tags_and, tags_not) -- 2.36.6 From 264979a60dff629abf36ea464c921c422c64c0f0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:07:50 +0200 Subject: [PATCH 104/170] added new module 'message.py' --- chatmastermind/message.py | 430 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 430 insertions(+) create mode 100644 chatmastermind/message.py diff --git a/chatmastermind/message.py b/chatmastermind/message.py new file mode 100644 index 0000000..157cd46 --- /dev/null +++ b/chatmastermind/message.py @@ -0,0 +1,430 @@ +""" +Module implementing message related functions and classes. +""" +import pathlib +import yaml +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from dataclasses import dataclass, asdict, field +from .tags import Tag, TagLine, TagError, match_tags + +QuestionInst = TypeVar('QuestionInst', bound='Question') +AnswerInst = TypeVar('AnswerInst', bound='Answer') +MessageInst = TypeVar('MessageInst', bound='Message') +AILineInst = TypeVar('AILineInst', bound='AILine') +ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine') +YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]] + + +class MessageError(Exception): + pass + + +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + +def source_code(text: str, include_delims: bool = False) -> list[str]: + """ + Extract all source code sections from the given text, i. e. all lines + surrounded by lines tarting with '```'. If 'include_delims' is True, + the surrounding lines are included, otherwise they are omitted. The + result list contains every source code section as a single string. + The order in the list represents the order of the sections in the text. + """ + code_sections: list[str] = [] + code_lines: list[str] = [] + in_code_block = False + + for line in text.split('\n'): + if line.strip().startswith('```'): + if include_delims: + code_lines.append(line) + if in_code_block: + code_sections.append('\n'.join(code_lines) + '\n') + code_lines.clear() + in_code_block = not in_code_block + elif in_code_block: + code_lines.append(line) + + return code_sections + + +@dataclass(kw_only=True) +class MessageFilter: + """ + Various filters for a Message. + """ + tags_or: Optional[set[Tag]] = None + tags_and: Optional[set[Tag]] = None + tags_not: Optional[set[Tag]] = None + ai: Optional[str] = None + model: Optional[str] = None + question_contains: Optional[str] = None + answer_contains: Optional[str] = None + answer_state: Optional[Literal['available', 'missing']] = None + ai_state: Optional[Literal['available', 'missing']] = None + model_state: Optional[Literal['available', 'missing']] = None + + +class AILine(str): + """ + A line that represents the AI name in a '.txt' file.. + """ + prefix: Final[str] = 'AI:' + + def __new__(cls: Type[AILineInst], string: str) -> AILineInst: + if not string.startswith(cls.prefix): + raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def ai(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst: + return cls(' '.join([cls.prefix, ai])) + + +class ModelLine(str): + """ + A line that represents the model name in a '.txt' file.. + """ + prefix: Final[str] = 'MODEL:' + + def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: + if not string.startswith(cls.prefix): + raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + instance = super().__new__(cls, string) + return instance + + def model(self) -> str: + return self[len(self.prefix):].strip() + + @classmethod + def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst: + return cls(' '.join([cls.prefix, model])) + + +class Question(str): + """ + A single question with a defined header. + """ + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' + + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + """ + Make sure the question string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +class Answer(str): + """ + A single answer with a defined header. + """ + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' + + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + """ + Make sure the answer string does not contain the header. + """ + if cls.txt_header in string: + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + instance = super().__new__(cls, string) + return instance + + @classmethod + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + """ + Build Question from a list of strings. Make sure strings do not contain the header. + """ + if any(cls.txt_header in string for string in strings): + raise MessageError(f"Question contains the header '{cls.txt_header}'") + instance = super().__new__(cls, '\n'.join(strings).strip()) + return instance + + def source_code(self, include_delims: bool = False) -> list[str]: + """ + Extract and return all source code sections. + """ + return source_code(self, include_delims) + + +@dataclass +class Message(): + """ + Single message. Consists of a question and optionally an answer, a set of tags + and a file path. + """ + question: Question + answer: Optional[Answer] = None + # metadata, ignored when comparing messages + tags: Optional[set[Tag]] = field(default=None, compare=False) + ai: Optional[str] = field(default=None, compare=False) + model: Optional[str] = field(default=None, compare=False) + file_path: Optional[pathlib.Path] = field(default=None, compare=False) + # class variables + file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml'] + tags_yaml_key: ClassVar[str] = 'tags' + file_yaml_key: ClassVar[str] = 'file_path' + ai_yaml_key: ClassVar[str] = 'ai' + model_yaml_key: ClassVar[str] = 'model' + + def __hash__(self) -> int: + """ + The hash value is computed based on immutable members. + """ + return hash((self.question, self.answer)) + + @classmethod + def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst: + """ + Create a Message from the given dict. + """ + return cls(question=data[Question.yaml_key], + answer=data.get(Answer.yaml_key, None), + tags=set(data.get(cls.tags_yaml_key, [])), + ai=data.get(cls.ai_yaml_key, None), + model=data.get(cls.model_yaml_key, None), + file_path=data.get(cls.file_yaml_key, None)) + + @classmethod + def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + """ + Return only the tags from the given Message file. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + if file_path.suffix == '.txt': + with open(file_path, "r") as fd: + tags = TagLine(fd.readline()).tags() + else: # '.yaml' + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + tags = set(sorted(data[cls.tags_yaml_key])) + return tags + + @classmethod + def from_file(cls: Type[MessageInst], file_path: pathlib.Path, + mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]: + """ + Create a Message from the given file. Returns 'None' if the message does + not fulfill the filter requirements. For TXT files, the tags are matched + before building the whole message. The other filters are applied afterwards. + """ + if not file_path.exists(): + raise MessageError(f"Message file '{file_path}' does not exist") + if file_path.suffix not in cls.file_suffixes: + raise MessageError(f"File type '{file_path.suffix}' is not supported") + + if file_path.suffix == '.txt': + message = cls.__from_file_txt(file_path, + mfilter.tags_or if mfilter else None, + mfilter.tags_and if mfilter else None, + mfilter.tags_not if mfilter else None) + else: + message = cls.__from_file_yaml(file_path) + if message and (not mfilter or (mfilter and message.match(mfilter))): + return message + else: + return None + + @classmethod + def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11 + tags_or: Optional[set[Tag]] = None, + tags_and: Optional[set[Tag]] = None, + tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]: + """ + Create a Message from the given TXT file. Expects the following file structures: + For '.txt': + * TagLine [Optional] + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header [Optional] + * Answer [Optional] + + Returns 'None' if the message does not fulfill the tag requirements. + """ + tags: set[Tag] = set() + question: Question + answer: Optional[Answer] = None + ai: Optional[str] = None + model: Optional[str] = None + with open(file_path, "r") as fd: + # TagLine (Optional) + try: + pos = fd.tell() + tags = TagLine(fd.readline()).tags() + except TagError: + fd.seek(pos) + if tags_or or tags_and or tags_not: + # match with an empty set if the file has no tags + if not match_tags(tags, tags_or, tags_and, tags_not): + return None + # AILine (Optional) + try: + pos = fd.tell() + ai = AILine(fd.readline()).ai() + except TagError: + fd.seek(pos) + # ModelLine (Optional) + try: + pos = fd.tell() + model = ModelLine(fd.readline()).model() + except TagError: + fd.seek(pos) + # Question and Answer + text = fd.read().strip().split('\n') + question_idx = text.index(Question.txt_header) + 1 + try: + answer_idx = text.index(Answer.txt_header) + question = Question.from_list(text[question_idx:answer_idx]) + answer = Answer.from_list(text[answer_idx + 1:]) + except ValueError: + question = Question.from_list(text[question_idx:]) + return cls(question, answer, tags, ai, model, file_path) + + @classmethod + def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst: + """ + Create a Message from the given YAML file. Expects the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string [Optional] + * Message.tags_yaml_key: list of strings [Optional] + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "r") as fd: + data = yaml.load(fd, Loader=yaml.FullLoader) + data[cls.file_yaml_key] = file_path + return cls.from_dict(data) + + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 + """ + Write a Message to the given file. Type is determined based on the suffix. + Currently supported suffixes: ['.txt', '.yaml'] + """ + if file_path: + self.file_path = file_path + if not self.file_path: + raise MessageError("Got no valid path to write message") + if self.file_path.suffix not in self.file_suffixes: + raise MessageError(f"File type '{self.file_path.suffix}' is not supported") + # TXT + if self.file_path.suffix == '.txt': + return self.__to_file_txt(self.file_path) + elif self.file_path.suffix == '.yaml': + return self.__to_file_yaml(self.file_path) + + def __to_file_txt(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in TXT format. + Creates the following file structures: + * TagLine + * AI [Optional] + * Model [Optional] + * Question.txt_header + * Question + * Answer.txt_header + * Answer + """ + with open(file_path, "w") as fd: + if self.tags: + fd.write(f'{TagLine.from_set(self.tags)}\n') + if self.ai: + fd.write(f'{AILine.from_ai(self.ai)}\n') + if self.model: + fd.write(f'{ModelLine.from_model(self.model)}\n') + fd.write(f'{Question.txt_header}\n{self.question}\n') + if self.answer: + fd.write(f'{Answer.txt_header}\n{self.answer}\n') + + def __to_file_yaml(self, file_path: pathlib.Path) -> None: + """ + Write a Message to the given file in YAML format. + Creates the following file structures: + * Question.yaml_key: single or multiline string + * Answer.yaml_key: single or multiline string + * Message.tags_yaml_key: list of strings + * Message.ai_yaml_key: str [Optional] + * Message.model_yaml_key: str [Optional] + """ + with open(file_path, "w") as fd: + data: YamlDict = {Question.yaml_key: str(self.question)} + if self.answer: + data[Answer.yaml_key] = str(self.answer) + if self.ai: + data[self.ai_yaml_key] = self.ai + if self.model: + data[self.model_yaml_key] = self.model + if self.tags: + data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) + yaml.dump(data, fd, sort_keys=False) + + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 + """ + Matches the current Message to the given filter atttributes. + Return True if all attributes match, else False. + """ + mytags = self.tags or set() + if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 + or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 + or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 + or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503 + or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503 + or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503 + or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503 + or (mfilter.model_state == 'available' and not self.model) # noqa: W503 + or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503 + or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503 + or (mfilter.model_state == 'missing' and self.model)): # noqa: W503 + return False + return True + + def msg_id(self) -> str: + """ + Returns an ID that is unique throughout all messages in the same (DB) directory. + Currently this is the file name. The ID is also used for sorting messages. + """ + if self.file_path: + return self.file_path.name + else: + raise MessageError("Can't create file ID without a file path") + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From 33567df15fcfbcf84118c59c15ecb055bf9b05da Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 18 Aug 2023 16:08:22 +0200 Subject: [PATCH 105/170] added testcases for messages.py --- tests/test_main.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_main.py b/tests/test_main.py index eb13dc5..8ce06cb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -8,6 +8,7 @@ from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data from chatmastermind.tags import Tag, TagLine, TagError +from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -345,3 +346,79 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") -- 2.36.6 From 09da312657537d4eb802e7a79e4e2a9ef1f72e90 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:04:41 +0200 Subject: [PATCH 106/170] configuration: added 'as_dict()' as an instance function --- chatmastermind/configuration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0037916..5ae32d6 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -63,4 +63,7 @@ class Config(): def to_file(self, path: str) -> None: with open(path, 'w') as f: - yaml.dump(asdict(self), f) + yaml.dump(asdict(self), f, sort_keys=False) + + def as_dict(self) -> dict[str, Any]: + return asdict(self) -- 2.36.6 From 30ccec2462a7610cf707cc13584df9bc3497b342 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 19 Aug 2023 08:30:24 +0200 Subject: [PATCH 107/170] tags: TagLine constructor now supports multiline taglines and multiple spaces --- chatmastermind/tags.py | 20 +++++++++++--------- tests/test_main.py | 9 +++++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bfe5fd5..544270c 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -1,7 +1,7 @@ """ Module implementing tag related functions and classes. """ -from typing import Type, TypeVar, Optional +from typing import Type, TypeVar, Optional, Final TagInst = TypeVar('TagInst', bound='Tag') TagLineInst = TypeVar('TagLineInst', bound='TagLine') @@ -16,9 +16,9 @@ class Tag(str): A single tag. A string that can contain anything but the default separator (' '). """ # default separator - default_separator = ' ' + default_separator: Final[str] = ' ' # alternative separators (e. g. for backwards compatibility) - alternative_separators = [','] + alternative_separators: Final[list[str]] = [','] def __new__(cls: Type[TagInst], string: str) -> TagInst: """ @@ -93,19 +93,21 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s class TagLine(str): """ - A line of tags. It starts with a prefix ('TAGS:'), followed by a list of tags, - separated by the defaut separator (' '). Any operations on a TagLine will sort - the tags. + A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by + a list of tags, separated by the defaut separator (' '). Any operations on a + TagLine will sort the tags. """ # the prefix - prefix = 'TAGS:' + prefix: Final[str] = 'TAGS:' def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst: """ - Make sure the tagline string starts with the prefix. + Make sure the tagline string starts with the prefix. Also replace newlines + and multiple spaces with ' ', in order to support multiline TagLines. """ if not string.startswith(cls.prefix): raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'") + string = ' '.join(string.split()) instance = super().__new__(cls, string) return instance @@ -114,7 +116,7 @@ class TagLine(str): """ Create a new TagLine from a set of tags. """ - return cls(' '.join([TagLine.prefix] + sorted([t for t in tags]))) + return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) def tags(self) -> set[Tag]: """ diff --git a/tests/test_main.py b/tests/test_main.py index 8ce06cb..25cdc37 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -256,6 +256,10 @@ class TestTagLine(CmmTestCase): tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + def test_invalid_tagline(self) -> None: with self.assertRaises(TagError): TagLine('tag1 tag2') @@ -273,6 +277,11 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From c0f50bace5d8f0e55ca1ed46cd7ff9c589c15a83 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 21 Aug 2023 08:29:48 +0200 Subject: [PATCH 108/170] gitignore: added vim session file --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 4ade1df..89bf5fb 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,5 @@ dmypy.json .config.yaml db -noweb \ No newline at end of file +noweb +Session.vim -- 2.36.6 From acec5f1d552d5537120cc0b496f97cf3c9fadefe Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 08:46:03 +0200 Subject: [PATCH 109/170] tests: splitted 'test_main.py' into 3 modules --- tests/test_main.py | 200 ------------------------------------------ tests/test_message.py | 78 ++++++++++++++++ tests/test_tags.py | 124 ++++++++++++++++++++++++++ 3 files changed, 202 insertions(+), 200 deletions(-) create mode 100644 tests/test_message.py create mode 100644 tests/test_tags.py diff --git a/tests/test_main.py b/tests/test_main.py index 25cdc37..db5fcdb 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -7,8 +7,6 @@ from chatmastermind.main import create_parser, ask_cmd from chatmastermind.api_client import ai from chatmastermind.configuration import Config from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from chatmastermind.tags import Tag, TagLine, TagError -from chatmastermind.message import source_code, MessageError, Question, Answer from unittest import mock from unittest.mock import patch, MagicMock, Mock, ANY @@ -233,201 +231,3 @@ class TestCreateParser(CmmTestCase): mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) - - -class TestTag(CmmTestCase): - def test_valid_tag(self) -> None: - tag = Tag('mytag') - self.assertEqual(tag, 'mytag') - - def test_invalid_tag(self) -> None: - with self.assertRaises(TagError): - Tag('tag with space') - - def test_default_separator(self) -> None: - self.assertEqual(Tag.default_separator, ' ') - - def test_alternative_separators(self) -> None: - self.assertEqual(Tag.alternative_separators, [',']) - - -class TestTagLine(CmmTestCase): - def test_valid_tagline(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_valid_tagline_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_invalid_tagline(self) -> None: - with self.assertRaises(TagError): - TagLine('tag1 tag2') - - def test_prefix(self) -> None: - self.assertEqual(TagLine.prefix, 'TAGS:') - - def test_from_set(self) -> None: - tags = {Tag('tag1'), Tag('tag2')} - tagline = TagLine.from_set(tags) - self.assertEqual(tagline, 'TAGS: tag1 tag2') - - def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_tags_with_newline(self) -> None: - tagline = TagLine('TAGS: tag1\n tag2') - tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) - - def test_merge(self) -> None: - tagline1 = TagLine('TAGS: tag1 tag2') - tagline2 = TagLine('TAGS: tag2 tag3') - merged_tagline = tagline1.merge({tagline2}) - self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') - - def test_delete_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag2') - - def test_add_tags(self) -> None: - tagline = TagLine('TAGS: tag1') - new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) - self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') - - def test_rename_tags(self) -> None: - tagline = TagLine('TAGS: old1 old2') - new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) - self.assertEqual(new_tagline, 'TAGS: new1 new2') - - def test_match_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2 tag3') - - # Test case 1: Match any tag in 'tags_or' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and: set[Tag] = set() - tags_not: set[Tag] = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 2: Match all tags in 'tags_and' - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = set() - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' - tags_or = {Tag('tag1'), Tag('tag4')} - tags_and = {Tag('tag1'), Tag('tag2')} - tags_not = {Tag('tag5')} - self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 5: No matching tags in 'tags_or' - tags_or = {Tag('tag4'), Tag('tag5')} - tags_and = set() - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 6: Not all tags in 'tags_and' are present - tags_or = set() - tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} - tags_not = set() - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 7: Some tags in 'tags_not' are present - tags_or = {Tag('tag1')} - tags_and = set() - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) - - # Test case 8: 'tags_or' and 'tags_and' are None, match all tags - tags_not = set() - self.assertTrue(tagline.match_tags(None, None, tags_not)) - - # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags - tags_not = {Tag('tag2')} - self.assertFalse(tagline.match_tags(None, None, tags_not)) - - -class SourceCodeTestCase(CmmTestCase): - def test_source_code_with_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - print(x + y) - ``` - """ - expected_result = [ - " ```python\n print(\"Hello, World!\")\n ```\n", - " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" - ] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_without_include_delims(self) -> None: - text = """ - Some text before the code block - ```python - print("Hello, World!") - ``` - Some text after the code block - ```python - x = 10 - y = 20 - print(x + y) - ``` - """ - expected_result = [ - " print(\"Hello, World!\")\n", - " x = 10\n y = 20\n print(x + y)\n" - ] - result = source_code(text, include_delims=False) - self.assertEqual(result, expected_result) - - def test_source_code_with_single_code_block(self) -> None: - text = "```python\nprint(\"Hello, World!\")\n```" - expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - def test_source_code_with_no_code_blocks(self) -> None: - text = "Some text without any code blocks" - expected_result: list[str] = [] - result = source_code(text, include_delims=True) - self.assertEqual(result, expected_result) - - -class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") - - def test_question_without_prefix(self) -> None: - question = Question("What is your favorite color?") - self.assertIsInstance(question, Question) - self.assertEqual(question, "What is your favorite color?") - - -class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: - with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") - - def test_answer_without_prefix(self) -> None: - answer = Answer("No") - self.assertIsInstance(answer, Answer) - self.assertEqual(answer, "No") diff --git a/tests/test_message.py b/tests/test_message.py new file mode 100644 index 0000000..220fef2 --- /dev/null +++ b/tests/test_message.py @@ -0,0 +1,78 @@ +from .test_main import CmmTestCase +from chatmastermind.message import source_code, MessageError, Question, Answer + + +class SourceCodeTestCase(CmmTestCase): + def test_source_code_with_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " ```python\n print(\"Hello, World!\")\n ```\n", + " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n" + ] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_without_include_delims(self) -> None: + text = """ + Some text before the code block + ```python + print("Hello, World!") + ``` + Some text after the code block + ```python + x = 10 + y = 20 + print(x + y) + ``` + """ + expected_result = [ + " print(\"Hello, World!\")\n", + " x = 10\n y = 20\n print(x + y)\n" + ] + result = source_code(text, include_delims=False) + self.assertEqual(result, expected_result) + + def test_source_code_with_single_code_block(self) -> None: + text = "```python\nprint(\"Hello, World!\")\n```" + expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + def test_source_code_with_no_code_blocks(self) -> None: + text = "Some text without any code blocks" + expected_result: list[str] = [] + result = source_code(text, include_delims=True) + self.assertEqual(result, expected_result) + + +class QuestionTestCase(CmmTestCase): + def test_question_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Question("=== QUESTION === What is your name?") + + def test_question_without_prefix(self) -> None: + question = Question("What is your favorite color?") + self.assertIsInstance(question, Question) + self.assertEqual(question, "What is your favorite color?") + + +class AnswerTestCase(CmmTestCase): + def test_answer_with_prefix(self) -> None: + with self.assertRaises(MessageError): + Answer("=== ANSWER === Yes") + + def test_answer_without_prefix(self) -> None: + answer = Answer("No") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, "No") diff --git a/tests/test_tags.py b/tests/test_tags.py new file mode 100644 index 0000000..9ac9746 --- /dev/null +++ b/tests/test_tags.py @@ -0,0 +1,124 @@ +from .test_main import CmmTestCase +from chatmastermind.tags import Tag, TagLine, TagError + + +class TestTag(CmmTestCase): + def test_valid_tag(self) -> None: + tag = Tag('mytag') + self.assertEqual(tag, 'mytag') + + def test_invalid_tag(self) -> None: + with self.assertRaises(TagError): + Tag('tag with space') + + def test_default_separator(self) -> None: + self.assertEqual(Tag.default_separator, ' ') + + def test_alternative_separators(self) -> None: + self.assertEqual(Tag.alternative_separators, [',']) + + +class TestTagLine(CmmTestCase): + def test_valid_tagline(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_valid_tagline_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_invalid_tagline(self) -> None: + with self.assertRaises(TagError): + TagLine('tag1 tag2') + + def test_prefix(self) -> None: + self.assertEqual(TagLine.prefix, 'TAGS:') + + def test_from_set(self) -> None: + tags = {Tag('tag1'), Tag('tag2')} + tagline = TagLine.from_set(tags) + self.assertEqual(tagline, 'TAGS: tag1 tag2') + + def test_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_with_newline(self) -> None: + tagline = TagLine('TAGS: tag1\n tag2') + tags = tagline.tags() + self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_merge(self) -> None: + tagline1 = TagLine('TAGS: tag1 tag2') + tagline2 = TagLine('TAGS: tag2 tag3') + merged_tagline = tagline1.merge({tagline2}) + self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3') + + def test_delete_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag2') + + def test_add_tags(self) -> None: + tagline = TagLine('TAGS: tag1') + new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')}) + self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3') + + def test_rename_tags(self) -> None: + tagline = TagLine('TAGS: old1 old2') + new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))}) + self.assertEqual(new_tagline, 'TAGS: new1 new2') + + def test_match_tags(self) -> None: + tagline = TagLine('TAGS: tag1 tag2 tag3') + + # Test case 1: Match any tag in 'tags_or' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and: set[Tag] = set() + tags_not: set[Tag] = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 2: Match all tags in 'tags_and' + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = set() + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not' + tags_or = {Tag('tag1'), Tag('tag4')} + tags_and = {Tag('tag1'), Tag('tag2')} + tags_not = {Tag('tag5')} + self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 5: No matching tags in 'tags_or' + tags_or = {Tag('tag4'), Tag('tag5')} + tags_and = set() + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 6: Not all tags in 'tags_and' are present + tags_or = set() + tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')} + tags_not = set() + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 7: Some tags in 'tags_not' are present + tags_or = {Tag('tag1')} + tags_and = set() + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not)) + + # Test case 8: 'tags_or' and 'tags_and' are None, match all tags + tags_not = set() + self.assertTrue(tagline.match_tags(None, None, tags_not)) + + # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags + tags_not = {Tag('tag2')} + self.assertFalse(tagline.match_tags(None, None, tags_not)) -- 2.36.6 From 9c2598a4b82db3b304caee342859ae3cc15bf0ed Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 20 Aug 2023 19:59:38 +0200 Subject: [PATCH 110/170] tests: added testcases for Message.from/to_file() and others --- tests/test_message.py | 545 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 544 insertions(+), 1 deletion(-) diff --git a/tests/test_message.py b/tests/test_message.py index 220fef2..0e326b4 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,5 +1,9 @@ +import pathlib +import tempfile +from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, MessageError, Question, Answer +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.tags import Tag, TagLine class SourceCodeTestCase(CmmTestCase): @@ -76,3 +80,542 @@ class AnswerTestCase(CmmTestCase): answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") + + +class MessageToFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_txt_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""" + self.assertEqual(content, expected_content) + + def test_to_file_txt_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.txt_header} +This is a question. +""" + self.assertEqual(content, expected_content) + + def test_to_file_unsupported_file_type(self) -> None: + unsupported_file_path = pathlib.Path("example.doc") + with self.assertRaises(MessageError) as cm: + self.message_complete.to_file(unsupported_file_path) + self.assertEqual(str(cm.exception), "File type '.doc' is not supported") + + def test_to_file_no_file_path(self) -> None: + """ + Provoke an exception using an empty path. + """ + with self.assertRaises(MessageError) as cm: + # clear the internal file_path + self.message_complete.file_path = None + self.message_complete.to_file(None) + self.assertEqual(str(cm.exception), "Got no valid path to write message") + # reset the internal file_path + self.message_complete.file_path = self.file_path + + +class MessageToFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + self.message_complete = Message(Question('This is a question.'), + Answer('This is an answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_multiline = Message(Question('This is a\nmultiline question.'), + Answer('This is a\nmultiline answer.'), + {Tag('tag1'), Tag('tag2')}, + ai='ChatGPT', + model='gpt-3.5-turbo', + file_path=self.file_path) + self.message_min = Message(Question('This is a question.'), + file_path=self.file_path) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_to_file_yaml_complete(self) -> None: + self.message_complete.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: This is a question. +{Answer.yaml_key}: This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_multiline(self) -> None: + self.message_multiline.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"""{Question.yaml_key}: |- + This is a + multiline question. +{Answer.yaml_key}: |- + This is a + multiline answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: +- tag1 +- tag2 +""" + self.assertEqual(content, expected_content) + + def test_to_file_yaml_min(self) -> None: + self.message_min.to_file(self.file_path) + + with open(self.file_path, "r") as fd: + content = fd.read() + expected_content = f"{Question.yaml_key}: This is a question.\n" + self.assertEqual(content, expected_content) + + +class MessageFromFileTxtTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{AILine.prefix} ChatGPT +{ModelLine.prefix} gpt-3.5-turbo +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_min.close() + self.file_path.unlink() + self.file_path_min.unlink() + + def test_from_file_txt_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_min(self) -> None: + """ + Read a message with only required values. + """ + message = Message.from_file(self.file_path_min) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_txt_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_txt_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_txt_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.txt") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_txt_question_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_txt_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_txt_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_txt_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_txt_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class MessageFromFileYamlTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path = pathlib.Path(self.file.name) + with open(self.file_path, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.ai_yaml_key}: ChatGPT +{Message.model_yaml_key}: gpt-3.5-turbo +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_min = pathlib.Path(self.file_min.name) + with open(self.file_path_min, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +""") + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + self.file_min.close() + self.file_path_min.unlink() + + def test_from_file_yaml_complete(self) -> None: + """ + Read a complete message (with all optional values). + """ + message = Message.from_file(self.file_path) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.ai, 'ChatGPT') + self.assertEqual(message.model, 'gpt-3.5-turbo') + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_min(self) -> None: + """ + Read a message with only the required values. + """ + message = Message.from_file(self.file_path_min) + self.assertIsInstance(message, Message) + self.assertIsNotNone(message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + self.assertIsNone(message.answer) + + def test_from_file_not_exists(self) -> None: + file_not_exists = pathlib.Path("example.yaml") + with self.assertRaises(MessageError) as cm: + Message.from_file(file_not_exists) + self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist") + + def test_from_file_yaml_tags_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertEqual(message.answer, 'This is an answer.') + self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')}) + self.assertEqual(message.file_path, self.file_path) + + def test_from_file_yaml_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(tags_or={Tag('tag3')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or={Tag('tag1')})) + self.assertIsNone(message) + + def test_from_file_yaml_no_tags_match_tags_not(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_not={Tag('tag1')})) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + if message: # mypy bug + self.assertEqual(message.question, 'This is a question.') + self.assertSetEqual(cast(set[Tag], message.tags), set()) + self.assertEqual(message.file_path, self.file_path_min) + + def test_from_file_yaml_question_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='question')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='answer')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_available(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='available')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_answer_missing(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='missing')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_question_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(question_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_contains='question')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_exists(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_contains='answer')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_available(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(answer_state='available')) + self.assertIsNone(message) + + def test_from_file_yaml_answer_not_missing(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(answer_state='missing')) + self.assertIsNone(message) + + def test_from_file_yaml_ai_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='ChatGPT')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_ai_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(ai='Foo')) + self.assertIsNone(message) + + def test_from_file_yaml_model_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='gpt-3.5-turbo')) + self.assertIsNotNone(message) + self.assertIsInstance(message, Message) + + def test_from_file_yaml_model_doesnt_match(self) -> None: + message = Message.from_file(self.file_path, + MessageFilter(model='Bar')) + self.assertIsNone(message) + + +class TagsFromFileTestCase(CmmTestCase): + def setUp(self) -> None: + self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt = pathlib.Path(self.file_txt.name) + with open(self.file_path_txt, "w") as fd: + fd.write(f"""{TagLine.prefix} tag1 tag2 +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. +""") + self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml = pathlib.Path(self.file_yaml.name) + with open(self.file_path_yaml, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. +{Message.tags_yaml_key}: + - tag1 + - tag2 +""") + + def tearDown(self) -> None: + self.file_txt.close() + self.file_path_txt.unlink() + self.file_yaml.close() + self.file_path_yaml.unlink() + + def test_tags_from_file_txt(self) -> None: + tags = Message.tags_from_file(self.file_path_txt) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + def test_tags_from_file_yaml(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + + +class MessageIDTestCase(CmmTestCase): + def setUp(self) -> None: + self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path = pathlib.Path(self.file.name) + self.message = Message(Question('This is a question.'), + file_path=self.file_path) + self.message_no_file_path = Message(Question('This is a question.')) + + def tearDown(self) -> None: + self.file.close() + self.file_path.unlink() + + def test_msg_id_txt(self) -> None: + self.assertEqual(self.message.msg_id(), self.file_path.name) + + def test_msg_id_txt_exception(self) -> None: + with self.assertRaises(MessageError): + self.message_no_file_path.msg_id() + + +class MessageHashTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a new question.'), + file_path=pathlib.Path('/tmp/foo/bla')) + self.message3 = Message(Question('This is a question.'), + Answer('This is an answer.'), + file_path=pathlib.Path('/tmp/foo/bla')) + # message4 is a copy of message1, because only question and + # answer are used for hashing and comparison + self.message4 = Message(Question('This is a question.'), + tags={Tag('tag1'), Tag('tag2')}, + ai='Blabla', + file_path=pathlib.Path('foobla')) + + def test_set_hashing(self) -> None: + msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4} + self.assertEqual(len(msgs), 3) + for msg in [self.message1, self.message2, self.message3]: + self.assertIn(msg, msgs) -- 2.36.6 From 17f7b2fb452ccb946d6c9344d46c70d50ce86a06 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 26 Aug 2023 12:50:47 +0200 Subject: [PATCH 111/170] Added tags filtering (prefix and contained string) to TagLine and Message --- chatmastermind/message.py | 71 ++++++++++++++++++++++-- chatmastermind/tags.py | 12 +++- tests/test_message.py | 113 +++++++++++++++++++++++++++++++++++++- tests/test_tags.py | 22 +++++++- 4 files changed, 204 insertions(+), 14 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 157cd46..902aaa2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -219,21 +219,57 @@ class Message(): file_path=data.get(cls.file_yaml_key, None)) @classmethod - def tags_from_file(cls: Type[MessageInst], file_path: pathlib.Path) -> set[Tag]: + def tags_from_file(cls: Type[MessageInst], + file_path: pathlib.Path, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: """ - Return only the tags from the given Message file. + Return only the tags from the given Message file, + optionally filtered based on prefix or contained string. """ + tags: set[Tag] = set() if not file_path.exists(): raise MessageError(f"Message file '{file_path}' does not exist") if file_path.suffix not in cls.file_suffixes: raise MessageError(f"File type '{file_path.suffix}' is not supported") + # for TXT, it's enough to read the TagLine if file_path.suffix == '.txt': with open(file_path, "r") as fd: - tags = TagLine(fd.readline()).tags() + try: + tags = TagLine(fd.readline()).tags(prefix, contain) + except TagError: + pass # message without tags else: # '.yaml' - with open(file_path, "r") as fd: - data = yaml.load(fd, Loader=yaml.FullLoader) - tags = set(sorted(data[cls.tags_yaml_key])) + try: + message = cls.from_file(file_path) + if message: + msg_tags = message.filter_tags(prefix=prefix, contain=contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + if msg_tags: + tags = msg_tags + return tags + + @classmethod + def tags_from_dir(cls: Type[MessageInst], + path: pathlib.Path, + glob: Optional[str] = None, + prefix: Optional[str] = None, + contain: Optional[str] = None) -> set[Tag]: + + """ + Return only the tags from message files in the given directory. + The files can be filtered using 'glob', the tags by using 'prefix' + and 'contain'. + """ + tags: set[Tag] = set() + file_iter = path.glob(glob) if glob else path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + tags |= cls.tags_from_file(file_path, prefix, contain) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") return tags @classmethod @@ -395,6 +431,29 @@ class Message(): data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) yaml.dump(data, fd, sort_keys=False) + def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Filter tags based on their prefix (i. e. the tag starts with a given string) + or some contained string. + """ + res_tags = self.tags + if res_tags: + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() + + def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: + """ + Returns all tags as a string with the TagLine prefix. Optionally filtered + using 'Message.filter_tags()'. + """ + if self.tags: + return str(TagLine.from_set(self.filter_tags(prefix, contain))) + else: + return str(TagLine.from_set(set())) + def match(self, mfilter: MessageFilter) -> bool: # noqa: 13 """ Matches the current Message to the given filter atttributes. diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index 544270c..c438db9 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -118,9 +118,10 @@ class TagLine(str): """ return cls(' '.join([cls.prefix] + sorted([t for t in tags]))) - def tags(self) -> set[Tag]: + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ - Returns all tags contained in this line as a set. + Returns all tags contained in this line as a set, optionally + filtered based on prefix or contained string. """ tagstr = self[len(self.prefix):].strip() separator = Tag.default_separator @@ -130,7 +131,12 @@ class TagLine(str): if s in tagstr: separator = s break - return set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)])) + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags or set() def merge(self, taglines: set['TagLine']) -> 'TagLine': """ diff --git a/tests/test_message.py b/tests/test_message.py index 0e326b4..7b8aee9 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -543,11 +543,19 @@ class TagsFromFileTestCase(CmmTestCase): self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) with open(self.file_path_txt, "w") as fd: - fd.write(f"""{TagLine.prefix} tag1 tag2 + fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3 {Question.txt_header} This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name) + with open(self.file_path_txt_no_tags, "w") as fd: + fd.write(f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -560,6 +568,16 @@ This is an answer. {Message.tags_yaml_key}: - tag1 - tag2 + - ptag3 +""") + self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') + self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name) + with open(self.file_path_yaml_no_tags, "w") as fd: + fd.write(f""" +{Question.yaml_key}: |- + This is a question. +{Answer.yaml_key}: |- + This is an answer. """) def tearDown(self) -> None: @@ -570,11 +588,90 @@ This is an answer. def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_txt_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_no_tags) + self.assertSetEqual(tags, set()) def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) - self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) + + def test_tags_from_file_yaml_no_tags(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml_no_tags) + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_prefix(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, prefix='p') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_txt_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_txt, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_txt, contain='R') + self.assertSetEqual(tags, set()) + + def test_tags_from_file_yaml_contain(self) -> None: + tags = Message.tags_from_file(self.file_path_yaml, contain='3') + self.assertSetEqual(tags, {Tag('ptag3')}) + tags = Message.tags_from_file(self.file_path_yaml, contain='R') + self.assertSetEqual(tags, set()) + + +class TagsFromDirTestCase(CmmTestCase): + def setUp(self) -> None: + self.temp_dir = tempfile.TemporaryDirectory() + self.temp_dir_no_tags = tempfile.TemporaryDirectory() + self.tag_sets = [ + {Tag('atag1'), Tag('atag2')}, + {Tag('btag3'), Tag('btag4')}, + {Tag('ctag5'), Tag('ctag6')} + ] + self.files = [ + pathlib.Path(self.temp_dir.name, 'file1.txt'), + pathlib.Path(self.temp_dir.name, 'file2.yaml'), + pathlib.Path(self.temp_dir.name, 'file3.txt') + ] + self.files_no_tags = [ + pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'), + pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'), + pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt') + ] + for file, tags in zip(self.files, self.tag_sets): + message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags) + message.to_file(file) + for file in self.files_no_tags: + message = Message(Question('This is a question.'), + Answer('This is an answer.')) + message.to_file(file) + + def tearDown(self) -> None: + self.temp_dir.cleanup() + + def test_tags_from_dir(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) + expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2] + self.assertEqual(all_tags, expected_tags) + + def test_tags_from_dir_prefix(self) -> None: + atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a') + expected_tags = self.tag_sets[0] + self.assertEqual(atags, expected_tags) + + def test_tags_from_dir_no_tags(self) -> None: + all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name)) + self.assertSetEqual(all_tags, set()) class MessageIDTestCase(CmmTestCase): @@ -619,3 +716,13 @@ class MessageHashTestCase(CmmTestCase): self.assertEqual(len(msgs), 3) for msg in [self.message1, self.message2, self.message3]: self.assertIn(msg, msgs) + + +class MessageTagsStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('tag1')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_tags_str(self) -> None: + self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') diff --git a/tests/test_tags.py b/tests/test_tags.py index 9ac9746..bd2b685 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -40,15 +40,33 @@ class TestTagLine(CmmTestCase): self.assertEqual(tagline, 'TAGS: tag1 tag2') def test_tags(self) -> None: - tagline = TagLine('TAGS: tag1 tag2') + tagline = TagLine('TAGS: atag1 btag2') tags = tagline.tags() - self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() self.assertEqual(tags, {Tag('tag1'), Tag('tag2')}) + def test_tags_prefix(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(prefix='a') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(prefix='s') + self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(prefix='R') + self.assertSetEqual(tags, set()) + + def test_tags_contain(self) -> None: + tagline = TagLine('TAGS: atag1 stag2 stag3') + tags = tagline.tags(contain='t') + self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')}) + tags = tagline.tags(contain='1') + self.assertSetEqual(tags, {Tag('atag1')}) + tags = tagline.tags(contain='R') + self.assertSetEqual(tags, set()) + def test_merge(self) -> None: tagline1 = TagLine('TAGS: tag1 tag2') tagline2 = TagLine('TAGS: tag2 tag3') -- 2.36.6 From 238dbbee6061c5604a0bcf58d751bc43d517054c Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 27 Aug 2023 18:07:38 +0200 Subject: [PATCH 112/170] fixed handling empty tags in TXT file --- chatmastermind/tags.py | 2 ++ tests/test_message.py | 13 +++++++++++++ tests/test_tags.py | 4 ++++ 3 files changed, 19 insertions(+) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index c438db9..bb45a08 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -124,6 +124,8 @@ class TagLine(str): filtered based on prefix or contained string. """ tagstr = self[len(self.prefix):].strip() + if tagstr == '': + return set() # no tags, only prefix separator = Tag.default_separator # look for alternative separators and use the first one found # -> we don't support different separators in the same TagLine diff --git a/tests/test_message.py b/tests/test_message.py index 7b8aee9..9cfb30a 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -556,6 +556,15 @@ This is an answer. This is a question. {Answer.txt_header} This is an answer. +""") + self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') + self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name) + with open(self.file_path_txt_tags_empty, "w") as fd: + fd.write(f"""TAGS: +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer. """) self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path_yaml = pathlib.Path(self.file_yaml.name) @@ -594,6 +603,10 @@ This is an answer. tags = Message.tags_from_file(self.file_path_txt_no_tags) self.assertSetEqual(tags, set()) + def test_tags_from_file_txt_tags_empty(self) -> None: + tags = Message.tags_from_file(self.file_path_txt_tags_empty) + self.assertSetEqual(tags, set()) + def test_tags_from_file_yaml(self) -> None: tags = Message.tags_from_file(self.file_path_yaml) self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')}) diff --git a/tests/test_tags.py b/tests/test_tags.py index bd2b685..eeab199 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -44,6 +44,10 @@ class TestTagLine(CmmTestCase): tags = tagline.tags() self.assertEqual(tags, {Tag('atag1'), Tag('btag2')}) + def test_tags_empty(self) -> None: + tagline = TagLine('TAGS:') + self.assertSetEqual(tagline.tags(), set()) + def test_tags_with_newline(self) -> None: tagline = TagLine('TAGS: tag1\n tag2') tags = tagline.tags() -- 2.36.6 From fde0ae4652c604d756e5df66e4aa363cc7c427fd Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 29 Aug 2023 11:35:18 +0200 Subject: [PATCH 113/170] fixed test case file cleanup --- tests/test_message.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_message.py b/tests/test_message.py index 9cfb30a..83a73ea 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -594,6 +594,12 @@ This is an answer. self.file_path_txt.unlink() self.file_yaml.close() self.file_path_yaml.unlink() + self.file_txt_no_tags.close + self.file_path_txt_no_tags.unlink() + self.file_txt_tags_empty.close + self.file_path_txt_tags_empty.unlink() + self.file_yaml_no_tags.close() + self.file_path_yaml_no_tags.unlink() def test_tags_from_file_txt(self) -> None: tags = Message.tags_from_file(self.file_path_txt) @@ -671,6 +677,7 @@ class TagsFromDirTestCase(CmmTestCase): def tearDown(self) -> None: self.temp_dir.cleanup() + self.temp_dir_no_tags.cleanup() def test_tags_from_dir(self) -> None: all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name)) -- 2.36.6 From 74c39070d620f79c458497ab9cab6fe356d9b79c Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 30 Aug 2023 08:20:25 +0200 Subject: [PATCH 114/170] fixed Message.filter_tags --- chatmastermind/message.py | 15 ++++++++------- tests/test_message.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 902aaa2..820d104 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -436,13 +436,14 @@ class Message(): Filter tags based on their prefix (i. e. the tag starts with a given string) or some contained string. """ - res_tags = self.tags - if res_tags: - if prefix and len(prefix) > 0: - res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} - if contain and len(contain) > 0: - res_tags -= {tag for tag in res_tags if contain not in tag} - return res_tags or set() + if not self.tags: + return set() + res_tags = self.tags.copy() + if prefix and len(prefix) > 0: + res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)} + if contain and len(contain) > 0: + res_tags -= {tag for tag in res_tags if contain not in tag} + return res_tags def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str: """ diff --git a/tests/test_message.py b/tests/test_message.py index 83a73ea..2a9d0ff 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -746,3 +746,18 @@ class MessageTagsStrTestCase(CmmTestCase): def test_tags_str(self) -> None: self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') + + +class MessageFilterTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_filter_tags(self) -> None: + tags_all = self.message.filter_tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.message.filter_tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.message.filter_tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) -- 2.36.6 From dc3f3dc168b8b5fb19bb5b1a88c42638414a19ec Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 09:19:38 +0200 Subject: [PATCH 115/170] added 'message_in()' function and test --- chatmastermind/message.py | 16 +++++++++++++++- tests/test_message.py | 16 +++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 820d104..3eca26e 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,7 +3,7 @@ Module implementing message related functions and classes. """ import pathlib import yaml -from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal +from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags @@ -57,6 +57,20 @@ def source_code(text: str, include_delims: bool = False) -> list[str]: return code_sections +def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool: + """ + Searches the given message list for a message with the same file + name as the given one (i. e. it compares Message.file_path.name). + If the given message has no file_path, False is returned. + """ + if not message.file_path: + return False + for m in messages: + if m.file_path and m.file_path.name == message.file_path.name: + return True + return False + + @dataclass(kw_only=True) class MessageFilter: """ diff --git a/tests/test_message.py b/tests/test_message.py index 2a9d0ff..0d7953e 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -2,7 +2,7 @@ import pathlib import tempfile from typing import cast from .test_main import CmmTestCase -from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter +from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine @@ -761,3 +761,17 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_pref, {Tag('atag1')}) tags_cont = self.message.filter_tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + + +class MessageInTestCase(CmmTestCase): + def setUp(self) -> None: + self.message1 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + self.message2 = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/bla/foo')) + + def test_message_in(self) -> None: + self.assertTrue(message_in(self.message1, [self.message1])) + self.assertFalse(message_in(self.message1, [self.message2])) -- 2.36.6 From a093f9b86777067439c04de7c3dfeaa5d3a2ec68 Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 31 Aug 2023 15:47:29 +0200 Subject: [PATCH 116/170] tags: some clarification and new tests --- chatmastermind/tags.py | 3 ++- tests/test_tags.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/chatmastermind/tags.py b/chatmastermind/tags.py index bb45a08..5ea1a3a 100644 --- a/chatmastermind/tags.py +++ b/chatmastermind/tags.py @@ -77,7 +77,8 @@ def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[s i. e. you can select a TagLine if it either contains one of the tags in 'tags_or' or all of the tags in 'tags_and' but it must never contain any of the tags in 'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag - exclusion is still done if 'tags_not' is not 'None'). + exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()), + they match no tags. """ required_tags_present = False excluded_tags_missing = False diff --git a/tests/test_tags.py b/tests/test_tags.py index eeab199..aa89a06 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -144,3 +144,20 @@ class TestTagLine(CmmTestCase): # Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags tags_not = {Tag('tag2')} self.assertFalse(tagline.match_tags(None, None, tags_not)) + + # Test case 10: 'tags_or' and 'tags_and' are empty, match no tags + self.assertFalse(tagline.match_tags(set(), set(), None)) + + # Test case 11: 'tags_or' is empty, match no tags + self.assertFalse(tagline.match_tags(set(), None, None)) + + # Test case 12: 'tags_and' is empty, match no tags + self.assertFalse(tagline.match_tags(None, set(), None)) + + # Test case 13: 'tags_or' is empty, match 'tags_and' + tags_and = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(None, tags_and, None)) + + # Test case 14: 'tags_and' is empty, match 'tags_or' + tags_or = {Tag('tag1'), Tag('tag2')} + self.assertTrue(tagline.match_tags(tags_or, None, None)) -- 2.36.6 From 64893949a4193cdfcd03c0d268325b1347d71c0a Mon Sep 17 00:00:00 2001 From: juk0de Date: Thu, 24 Aug 2023 16:49:54 +0200 Subject: [PATCH 117/170] added new module 'chat.py' --- chatmastermind/chat.py | 278 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 chatmastermind/chat.py diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py new file mode 100644 index 0000000..c5d8bf3 --- /dev/null +++ b/chatmastermind/chat.py @@ -0,0 +1,278 @@ +""" +Module implementing various chat classes and functions for managing a chat history. +""" +import shutil +import pathlib +from pprint import PrettyPrinter +from pydoc import pager +from dataclasses import dataclass +from typing import TypeVar, Type, Optional, ClassVar, Any, Callable +from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .tags import Tag + +ChatInst = TypeVar('ChatInst', bound='Chat') +ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB') + + +class ChatError(Exception): + pass + + +def terminal_width() -> int: + return shutil.get_terminal_size().columns + + +def pp(*args: Any, **kwargs: Any) -> None: + return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) + + +def print_paged(text: str) -> None: + pager(text) + + +def read_dir(dir_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> list[Message]: + """ + Reads the messages from the given folder. + Parameters: + * 'dir_path': source directory + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages: list[Message] = [] + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in sorted(file_iter): + if file_path.is_file(): + try: + message = Message.from_file(file_path, mfilter) + if message: + messages.append(message) + except MessageError as e: + print(f"Error processing message in '{file_path}': {str(e)}") + return messages + + +def write_dir(dir_path: pathlib.Path, + messages: list[Message], + file_suffix: str, + next_fid: Callable[[], int]) -> None: + """ + Write all messages to the given directory. If a message has no file_path, + a new one will be created. If message.file_path exists, it will be modified + to point to the given directory. + Parameters: + * 'dir_path': destination directory + * 'messages': list of messages to write + * 'file_suffix': suffix for the message files ['.txt'|'.yaml'] + * 'next_fid': callable that returns the next file ID + """ + for message in messages: + file_path = message.file_path + # message has no file_path: create one + if not file_path: + fid = next_fid() + fname = f"{fid:04d}{file_suffix}" + file_path = dir_path / fname + # file_path does not point to given directory: modify it + elif not file_path.parent.samefile(dir_path): + file_path = dir_path / file_path.name + message.to_file(file_path) + + +@dataclass +class Chat: + """ + A class containing a complete chat history. + """ + + messages: list[Message] + + def filter(self, mfilter: MessageFilter) -> None: + """ + Use 'Message.match(mfilter) to remove all messages that + don't fulfill the filter requirements. + """ + self.messages = [m for m in self.messages if m.match(mfilter)] + + def sort(self, reverse: bool = False) -> None: + """ + Sort the messages according to 'Message.msg_id()'. + """ + try: + # the message may not have an ID if it doesn't have a file_path + self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse) + except MessageError: + pass + + def clear(self) -> None: + """ + Delete all messages. + """ + self.messages = [] + + def add_msgs(self, msgs: list[Message]) -> None: + """ + Add new messages and sort them if possible. + """ + self.messages += msgs + self.sort() + + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: + """ + Get the tags of all messages, optionally filtered by prefix or substring. + """ + tags: set[Tag] = set() + for m in self.messages: + tags |= m.filter_tags(prefix, contain) + return tags + + def print(self, dump: bool = False, source_code_only: bool = False, + with_tags: bool = False, with_file: bool = False, + paged: bool = True) -> None: + if dump: + pp(self) + return + output: list[str] = [] + for message in self.messages: + if source_code_only: + output.extend(source_code(message.question, include_delims=True)) + continue + output.append('-' * terminal_width()) + output.append(Question.txt_header) + output.append(message.question) + if message.answer: + output.append(Answer.txt_header) + output.append(message.answer) + if with_tags: + output.append(message.tags_str()) + if with_file: + output.append('FILE: ' + str(message.file_path)) + if paged: + print_paged('\n'.join(output)) + else: + print(*output, sep='\n') + + +@dataclass +class ChatDB(Chat): + """ + A 'Chat' class that is bound to a given directory structure. Supports reading + and writing messages from / to that structure. Such a structure consists of + two directories: a 'cache directory', where all messages are temporarily + stored, and a 'DB' directory, where selected messages can be stored + persistently. + """ + + default_file_suffix: ClassVar[str] = '.txt' + + cache_path: pathlib.Path + db_path: pathlib.Path + # a MessageFilter that all messages must match (if given) + mfilter: Optional[MessageFilter] = None + file_suffix: str = default_file_suffix + # the glob pattern for all messages + glob: Optional[str] = None + + def __post_init__(self) -> None: + # contains the latest message ID + self.next_fname = self.db_path / '.next' + # make all paths absolute + self.cache_path = self.cache_path.absolute() + self.db_path = self.db_path.absolute() + + @classmethod + def from_dir(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + glob: Optional[str] = None, + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a 'ChatDB' instance from the given directory structure. + Reads all messages from 'db_path' into the local message list. + Parameters: + * 'cache_path': path to the directory for temporary messages + * 'db_path': path to the directory for persistent messages + * 'glob': if specified, files will be filtered using 'path.glob()', + otherwise it uses 'path.iterdir()'. + * 'mfilter': use with 'Message.from_file()' to filter messages + when reading them. + """ + messages = read_dir(db_path, glob, mfilter) + return cls(messages, cache_path, db_path, mfilter, + cls.default_file_suffix, glob) + + @classmethod + def from_messages(cls: Type[ChatDBInst], + cache_path: pathlib.Path, + db_path: pathlib.Path, + messages: list[Message], + mfilter: Optional[MessageFilter] = None) -> ChatDBInst: + """ + Create a ChatDB instance from the given message list. + """ + return cls(messages, cache_path, db_path, mfilter) + + def get_next_fid(self) -> int: + try: + with open(self.next_fname, 'r') as f: + next_fid = int(f.read()) + 1 + self.set_next_fid(next_fid) + return next_fid + except Exception: + self.set_next_fid(1) + return 1 + + def set_next_fid(self, fid: int) -> None: + with open(self.next_fname, 'w') as f: + f.write(f'{fid}') + + def read_db(self) -> None: + """ + Reads new messages from the DB directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.db_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def read_cache(self) -> None: + """ + Reads new messages from the cache directory. New ones are added to the internal list, + existing ones are replaced. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. + """ + new_messages = read_dir(self.cache_path, self.glob, self.mfilter) + # remove all messages from self.messages that are in the new list + self.messages = [m for m in self.messages if not message_in(m, new_messages)] + # copy the messages from the temporary list to self.messages and sort them + self.messages += new_messages + self.sort() + + def write_db(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the DB directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point + to the DB directory. + """ + write_dir(self.db_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) + + def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + """ + Write messages to the cache directory. If a message has no file_path, a new one + will be created. If message.file_path exists, it will be modified to point to + the cache directory. + """ + write_dir(self.cache_path, + msgs if msgs else self.messages, + self.file_suffix, + self.get_next_fid) -- 2.36.6 From 815a21893c70e4bf1186dc063b74229891915746 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 28 Aug 2023 14:24:24 +0200 Subject: [PATCH 118/170] added tests for 'chat.py' --- tests/test_chat.py | 297 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 297 insertions(+) create mode 100644 tests/test_chat.py diff --git a/tests/test_chat.py b/tests/test_chat.py new file mode 100644 index 0000000..2d0ffa0 --- /dev/null +++ b/tests/test_chat.py @@ -0,0 +1,297 @@ +import pathlib +import tempfile +import time +from io import StringIO +from unittest.mock import patch +from chatmastermind.tags import TagLine +from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter +from chatmastermind.chat import Chat, ChatDB, terminal_width +from .test_main import CmmTestCase + + +class TestChat(CmmTestCase): + def setUp(self) -> None: + self.chat = Chat([]) + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('atag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('0002.txt')) + + def test_filter(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.filter(MessageFilter(answer_contains='Answer 1')) + + self.assertEqual(len(self.chat.messages), 1) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + + def test_sort(self) -> None: + self.chat.add_msgs([self.message2, self.message1]) + self.chat.sort() + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + self.chat.sort(reverse=True) + self.assertEqual(self.chat.messages[0].question, 'Question 2') + self.assertEqual(self.chat.messages[1].question, 'Question 1') + + def test_clear(self) -> None: + self.chat.add_msgs([self.message1]) + self.chat.clear() + self.assertEqual(len(self.chat.messages), 0) + + def test_add_msgs(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.assertEqual(len(self.chat.messages), 2) + self.assertEqual(self.chat.messages[0].question, 'Question 1') + self.assertEqual(self.chat.messages[1].question, 'Question 2') + + def test_tags(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_all = self.chat.tags() + self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) + tags_pref = self.chat.tags(prefix='a') + self.assertSetEqual(tags_pref, {Tag('atag1')}) + tags_cont = self.chat.tags(contain='2') + self.assertSetEqual(tags_cont, {Tag('btag2')}) + + @patch('sys.stdout', new_callable=StringIO) + def test_print(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + @patch('sys.stdout', new_callable=StringIO) + def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: + self.chat.add_msgs([self.message1, self.message2]) + self.chat.print(paged=False, with_tags=True, with_file=True) + expected_output = f"""{'-'*terminal_width()} +{Question.txt_header} +Question 1 +{Answer.txt_header} +Answer 1 +{TagLine.prefix} atag1 +FILE: 0001.txt +{'-'*terminal_width()} +{Question.txt_header} +Question 2 +{Answer.txt_header} +Answer 2 +{TagLine.prefix} btag2 +FILE: 0002.txt +""" + self.assertEqual(mock_stdout.getvalue(), expected_output) + + +class TestChatDB(CmmTestCase): + def setUp(self) -> None: + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + + self.message1 = Message(Question('Question 1'), + Answer('Answer 1'), + {Tag('tag1')}, + file_path=pathlib.Path('0001.txt')) + self.message2 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('tag2')}, + file_path=pathlib.Path('0002.yaml')) + self.message3 = Message(Question('Question 3'), + Answer('Answer 3'), + {Tag('tag3')}, + file_path=pathlib.Path('0003.txt')) + self.message4 = Message(Question('Question 4'), + Answer('Answer 4'), + {Tag('tag4')}, + file_path=pathlib.Path('0004.yaml')) + + self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt')) + self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) + self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) + self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + + def tearDown(self) -> None: + self.db_path.cleanup() + self.cache_path.cleanup() + pass + + def test_chat_db_from_dir(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + # check that the files are sorted + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, + pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_from_dir_glob(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + glob='*.txt') + self.assertEqual(len(chat_db.messages), 2) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, + pathlib.Path(self.db_path.name, '0003.txt')) + + def test_chat_db_filter(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(answer_contains='Answer 2')) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[0].answer, 'Answer 2') + + def test_chat_db_from_messges(self) -> None: + chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + messages=[self.message1, self.message2, + self.message3, self.message4]) + self.assertEqual(len(chat_db.messages), 4) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + + def test_chat_db_fids(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.get_next_fid(), 1) + self.assertEqual(chat_db.get_next_fid(), 2) + self.assertEqual(chat_db.get_next_fid(), 3) + with open(chat_db.next_fname, 'r') as f: + self.assertEqual(f.read(), '3') + + def test_chat_db_write(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files) + self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files) + # check that Message.file_path has been correctly updated + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) + + # check the timestamp of the files in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} + # overwrite the messages in the db directory + time.sleep(0.05) + chat_db.write_db() + # check if the written files are in the DB directory + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + # check if all files in the DB dir have actually been overwritten + for file in db_dir_files: + self.assertGreater(file.stat().st_mtime, old_timestamps[file]) + # check that Message.file_path has been correctly updated (again) + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + def test_chat_db_read(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + self.assertEqual(len(chat_db.messages), 4) + + # create 2 new files in the DB directory + new_message1 = Message(Question('Question 5'), + Answer('Answer 5'), + {Tag('tag5')}) + new_message2 = Message(Question('Question 6'), + Answer('Answer 6'), + {Tag('tag6')}) + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 6) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # create 2 new files in the cache directory + new_message3 = Message(Question('Question 7'), + Answer('Answer 5'), + {Tag('tag7')}) + new_message4 = Message(Question('Question 8'), + Answer('Answer 6'), + {Tag('tag8')}) + new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml')) + # read and check them + chat_db.read_cache() + self.assertEqual(len(chat_db.messages), 8) + # check that the new message have the cache dir path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml')) + # an the old ones keep their path (since they have not been replaced) + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now overwrite two messages in the DB directory + new_message1.question = Question('New Question 1') + new_message2.question = Question('New Question 2') + new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt')) + new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml')) + # read from the DB dir and check if the modified messages have been updated + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + self.assertEqual(chat_db.messages[4].question, 'New Question 1') + self.assertEqual(chat_db.messages[5].question, 'New Question 2') + self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt')) + self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml')) + + # now write the messages from the cache to the DB directory + new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt')) + new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml')) + # read and check them + chat_db.read_db() + self.assertEqual(len(chat_db.messages), 8) + # check that they now have the DB path + self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) + self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) -- 2.36.6 From 6737fa98c73a1db51f9ee9bf25b0765e2c193c96 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 08:57:54 +0200 Subject: [PATCH 119/170] added tokens() function to Message and Chat --- chatmastermind/chat.py | 7 +++++++ chatmastermind/message.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c5d8bf3..4a458df 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -129,6 +129,13 @@ class Chat: tags |= m.filter_tags(prefix, contain) return tags + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by all messages in this chat. + If unknown, 0 is returned. + """ + return sum(m.tokens() for m in self.messages) + def print(self, dump: bool = False, source_code_only: bool = False, with_tags: bool = False, with_file: bool = False, paged: bool = True) -> None: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 3eca26e..675ab3a 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -132,6 +132,7 @@ class Question(str): """ A single question with a defined header. """ + tokens: int = 0 # tokens used by this question txt_header: ClassVar[str] = '=== QUESTION ===' yaml_key: ClassVar[str] = 'question' @@ -165,6 +166,7 @@ class Answer(str): """ A single answer with a defined header. """ + tokens: int = 0 # tokens used by this answer txt_header: ClassVar[str] = '=== ANSWER ===' yaml_key: ClassVar[str] = 'answer' @@ -502,3 +504,13 @@ class Message(): def as_dict(self) -> dict[str, Any]: return asdict(self) + + def tokens(self) -> int: + """ + Returns the nr. of AI language tokens used by this message. + If unknown, 0 is returned. + """ + if self.answer: + return self.question.tokens + self.answer.tokens + else: + return self.question.tokens -- 2.36.6 From 33565d351dc575660955b32a63f5c427998ec80c Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:07:58 +0200 Subject: [PATCH 120/170] configuration: added AIConfig class --- chatmastermind/configuration.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 5ae32d6..0780604 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -7,7 +7,15 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') @dataclass -class OpenAIConfig(): +class AIConfig: + """ + The base class of all AI configurations. + """ + name: str + + +@dataclass +class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ @@ -25,6 +33,7 @@ class OpenAIConfig(): Create OpenAIConfig from a dict. """ return cls( + name='OpenAI', api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -36,7 +45,7 @@ class OpenAIConfig(): @dataclass -class Config(): +class Config: """ The configuration file structure. """ @@ -47,7 +56,7 @@ class Config(): @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create OpenAIConfig from a dict. + Create Config from a dict. """ return cls( system=str(source['system']), -- 2.36.6 From b22a4b07ed99ef9e3c13159479bec9cd07b4b9f9 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:35:32 +0200 Subject: [PATCH 121/170] chat: added tags_frequency() function and test --- chatmastermind/chat.py | 11 ++++++++++- tests/test_chat.py | 9 +++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4a458df..759467d 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -127,7 +127,16 @@ class Chat: tags: set[Tag] = set() for m in self.messages: tags |= m.filter_tags(prefix, contain) - return tags + return set(sorted(tags)) + + def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]: + """ + Get the frequency of all tags of all messages, optionally filtered by prefix or substring. + """ + tags: list[Tag] = [] + for m in self.messages: + tags += [tag for tag in m.filter_tags(prefix, contain)] + return {tag: tags.count(tag) for tag in sorted(tags)} def tokens(self) -> int: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index 2d0ffa0..5f1fcb6 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -14,7 +14,7 @@ class TestChat(CmmTestCase): self.chat = Chat([]) self.message1 = Message(Question('Question 1'), Answer('Answer 1'), - {Tag('atag1')}, + {Tag('atag1'), Tag('btag2')}, file_path=pathlib.Path('0001.txt')) self.message2 = Message(Question('Question 2'), Answer('Answer 2'), @@ -57,6 +57,11 @@ class TestChat(CmmTestCase): tags_cont = self.chat.tags(contain='2') self.assertSetEqual(tags_cont, {Tag('btag2')}) + def test_tags_frequency(self) -> None: + self.chat.add_msgs([self.message1, self.message2]) + tags_freq = self.chat.tags_frequency() + self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) @@ -83,7 +88,7 @@ Answer 2 Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 +{TagLine.prefix} atag1 btag2 FILE: 0001.txt {'-'*terminal_width()} {Question.txt_header} -- 2.36.6 From 48c8e951e1d439426e8a22a89b7dc2a24fdd0898 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:44:27 +0200 Subject: [PATCH 122/170] chat: fixed handling of unsupported files in DB and chache dir --- chatmastermind/chat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 759467d..11f1d74 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -45,7 +45,7 @@ def read_dir(dir_path: pathlib.Path, messages: list[Message] = [] file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() for file_path in sorted(file_iter): - if file_path.is_file(): + if file_path.is_file() and file_path.suffix in Message.file_suffixes: try: message = Message.from_file(file_path, mfilter) if message: -- 2.36.6 From c318b99671be511d7c79226d65974befd4241932 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:18:41 +0200 Subject: [PATCH 123/170] chat: improved history printing --- chatmastermind/chat.py | 15 ++++++--------- tests/test_chat.py | 10 +++++----- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 11f1d74..e4e8ab6 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -145,27 +145,24 @@ class Chat: """ return sum(m.tokens() for m in self.messages) - def print(self, dump: bool = False, source_code_only: bool = False, - with_tags: bool = False, with_file: bool = False, + def print(self, source_code_only: bool = False, + with_tags: bool = False, with_files: bool = False, paged: bool = True) -> None: - if dump: - pp(self) - return output: list[str] = [] for message in self.messages: if source_code_only: output.extend(source_code(message.question, include_delims=True)) continue output.append('-' * terminal_width()) + if with_tags: + output.append(message.tags_str()) + if with_files: + output.append('FILE: ' + str(message.file_path)) output.append(Question.txt_header) output.append(message.question) if message.answer: output.append(Answer.txt_header) output.append(message.answer) - if with_tags: - output.append(message.tags_str()) - if with_file: - output.append('FILE: ' + str(message.file_path)) if paged: print_paged('\n'.join(output)) else: diff --git a/tests/test_chat.py b/tests/test_chat.py index 5f1fcb6..8e1ad0d 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -82,21 +82,21 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_msgs([self.message1, self.message2]) - self.chat.print(paged=False, with_tags=True, with_file=True) + self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} +{TagLine.prefix} atag1 btag2 +FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 -{TagLine.prefix} atag1 btag2 -FILE: 0001.txt {'-'*terminal_width()} +{TagLine.prefix} btag2 +FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 -{TagLine.prefix} btag2 -FILE: 0002.txt """ self.assertEqual(mock_stdout.getvalue(), expected_output) -- 2.36.6 From 8e63831701741705799fd7baee065c4fe6b420b0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 09:19:47 +0200 Subject: [PATCH 124/170] chat: added clear_cache() function and test --- chatmastermind/chat.py | 20 +++++++++++++++++++ tests/test_chat.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index e4e8ab6..9fc0a27 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -82,6 +82,17 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) +def clear_dir(dir_path: pathlib.Path, + glob: Optional[str] = None) -> None: + """ + Deletes all Message files in the given directory. + """ + file_iter = dir_path.glob(glob) if glob else dir_path.iterdir() + for file_path in file_iter: + if file_path.is_file() and file_path.suffix in Message.file_suffixes: + file_path.unlink(missing_ok=True) + + @dataclass class Chat: """ @@ -289,3 +300,12 @@ class ChatDB(Chat): msgs if msgs else self.messages, self.file_suffix, self.get_next_fid) + + def clear_cache(self) -> None: + """ + Deletes all Message files from the cache dir and removes those messages from + the internal list. + """ + clear_dir(self.cache_path, self.glob) + # only keep messages from DB dir (or those that have not yet been written) + self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e1ad0d..9e74061 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -300,3 +300,48 @@ class TestChatDB(CmmTestCase): # check that they now have the DB path self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt')) self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml')) + + def test_chat_db_clear(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + # check that Message.file_path is correct + self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt')) + self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml')) + self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt')) + self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml')) + + # write the messages to the cache directory + chat_db.write_cache() + # check if the written files are in the cache directory + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 4) + + # now rewrite them to the DB dir and check for modified paths + chat_db.write_db() + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files) + self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files) + + # add a new message with empty file_path + message_empty = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You don't belong here!")) + # and one for the cache dir + message_cache = Message(question=Question("What the hell am I doing here?"), + answer=Answer("You're a creep!"), + file_path=pathlib.Path(self.cache_path.name, '0005.txt')) + chat_db.add_msgs([message_empty, message_cache]) + + # clear the cache and check the cache dir + chat_db.clear_cache() + cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + self.assertEqual(len(cache_dir_files), 0) + # make sure that the DB messages (and the new message) are still there + self.assertEqual(len(chat_db.messages), 5) + db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + self.assertEqual(len(db_dir_files), 4) + # but not the message with the cache dir path + self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) -- 2.36.6 From aba3eb783d3ac9b1a644225ee509393673cf21ab Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 16:00:24 +0200 Subject: [PATCH 125/170] message: improved robustness of Question and Answer content checks and tests --- chatmastermind/message.py | 48 +++++++++++++++++++++------------------ tests/test_message.py | 29 ++++++++++++++++++----- 2 files changed, 49 insertions(+), 28 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 675ab3a..384fb96 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -128,29 +128,29 @@ class ModelLine(str): return cls(' '.join([cls.prefix, model])) -class Question(str): +class Answer(str): """ - A single question with a defined header. + A single answer with a defined header. """ - tokens: int = 0 # tokens used by this question - txt_header: ClassVar[str] = '=== QUESTION ===' - yaml_key: ClassVar[str] = 'question' + tokens: int = 0 # tokens used by this answer + txt_header: ClassVar[str] = '=== ANSWER ===' + yaml_key: ClassVar[str] = 'answer' - def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: + def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: """ - Make sure the question string does not contain the header. + Make sure the answer string does not contain the header as a whole line. """ - if cls.txt_header in string: - raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if cls.txt_header in string.split('\n'): + raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: + def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance @@ -162,29 +162,33 @@ class Question(str): return source_code(self, include_delims) -class Answer(str): +class Question(str): """ - A single answer with a defined header. + A single question with a defined header. """ - tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' - yaml_key: ClassVar[str] = 'answer' + tokens: int = 0 # tokens used by this question + txt_header: ClassVar[str] = '=== QUESTION ===' + yaml_key: ClassVar[str] = 'question' - def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: + def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst: """ - Make sure the answer string does not contain the header. + Make sure the question string does not contain the header as a whole line + (also not that from 'Answer', so it's always clear where the answer starts). """ - if cls.txt_header in string: - raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'") + string_lines = string.split('\n') + if cls.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'") + if Answer.txt_header in string_lines: + raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'") instance = super().__new__(cls, string) return instance @classmethod - def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst: + def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst: """ Build Question from a list of strings. Make sure strings do not contain the header. """ - if any(cls.txt_header in string for string in strings): + if cls.txt_header in strings: raise MessageError(f"Question contains the header '{cls.txt_header}'") instance = super().__new__(cls, '\n'.join(strings).strip()) return instance diff --git a/tests/test_message.py b/tests/test_message.py index 0d7953e..e01de66 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -61,22 +61,39 @@ class SourceCodeTestCase(CmmTestCase): class QuestionTestCase(CmmTestCase): - def test_question_with_prefix(self) -> None: + def test_question_with_header(self) -> None: with self.assertRaises(MessageError): - Question("=== QUESTION === What is your name?") + Question(f"{Question.txt_header}\nWhat is your name?") - def test_question_without_prefix(self) -> None: + def test_question_with_answer_header(self) -> None: + with self.assertRaises(MessageError): + Question(f"{Answer.txt_header}\nBob") + + def test_question_with_legal_header(self) -> None: + """ + If the header is just a part of a line, it's fine. + """ + question = Question(f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + self.assertIsInstance(question, Question) + self.assertEqual(question, f"This is a line contaning '{Question.txt_header}'\nWhat does that mean?") + + def test_question_without_header(self) -> None: question = Question("What is your favorite color?") self.assertIsInstance(question, Question) self.assertEqual(question, "What is your favorite color?") class AnswerTestCase(CmmTestCase): - def test_answer_with_prefix(self) -> None: + def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): - Answer("=== ANSWER === Yes") + Answer(f"{Answer.txt_header}\nno") - def test_answer_without_prefix(self) -> None: + def test_answer_with_legal_header(self) -> None: + answer = Answer(f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + self.assertIsInstance(answer, Answer) + self.assertEqual(answer, f"This is a line contaning '{Answer.txt_header}'\nIt is what it is.") + + def test_answer_without_header(self) -> None: answer = Answer("No") self.assertIsInstance(answer, Answer) self.assertEqual(answer, "No") -- 2.36.6 From d35de86c67fc963a3d3b97c48757ce05bd31cc04 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:00:08 +0200 Subject: [PATCH 126/170] message: fixed Answer header for TXT format --- chatmastermind/message.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 384fb96..87de8e2 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -96,7 +96,7 @@ class AILine(str): def __new__(cls: Type[AILineInst], string: str) -> AILineInst: if not string.startswith(cls.prefix): - raise TagError(f"AILine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -116,7 +116,7 @@ class ModelLine(str): def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst: if not string.startswith(cls.prefix): - raise TagError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") + raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'") instance = super().__new__(cls, string) return instance @@ -133,7 +133,7 @@ class Answer(str): A single answer with a defined header. """ tokens: int = 0 # tokens used by this answer - txt_header: ClassVar[str] = '=== ANSWER ===' + txt_header: ClassVar[str] = '==== ANSWER ====' yaml_key: ClassVar[str] = 'answer' def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst: @@ -355,17 +355,20 @@ class Message(): try: pos = fd.tell() ai = AILine(fd.readline()).ai() - except TagError: + except MessageError: fd.seek(pos) # ModelLine (Optional) try: pos = fd.tell() model = ModelLine(fd.readline()).model() - except TagError: + except MessageError: fd.seek(pos) # Question and Answer text = fd.read().strip().split('\n') - question_idx = text.index(Question.txt_header) + 1 + try: + question_idx = text.index(Question.txt_header) + 1 + except ValueError: + raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'") try: answer_idx = text.index(Answer.txt_header) question = Question.from_list(text[question_idx:answer_idx]) -- 2.36.6 From 713b55482a61195ce7163bd839b6eb257619fb03 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 10:19:14 +0200 Subject: [PATCH 127/170] message: added rename_tags() function and test --- chatmastermind/message.py | 10 +++++++++- tests/test_message.py | 12 ++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 87de8e2..0fb949c 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -5,7 +5,7 @@ import pathlib import yaml from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field -from .tags import Tag, TagLine, TagError, match_tags +from .tags import Tag, TagLine, TagError, match_tags, rename_tags QuestionInst = TypeVar('QuestionInst', bound='Question') AnswerInst = TypeVar('AnswerInst', bound='Answer') @@ -499,6 +499,14 @@ class Message(): return False return True + def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None: + """ + Renames the given tags. The first tuple element is the old name, + the second one is the new name. + """ + if self.tags: + self.tags = rename_tags(self.tags, tags_rename) + def msg_id(self) -> str: """ Returns an ID that is unique throughout all messages in the same (DB) directory. diff --git a/tests/test_message.py b/tests/test_message.py index e01de66..e860538 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -792,3 +792,15 @@ class MessageInTestCase(CmmTestCase): def test_message_in(self) -> None: self.assertTrue(message_in(self.message1, [self.message1])) self.assertFalse(message_in(self.message1, [self.message2])) + + +class MessageRenameTagsTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_rename_tags(self) -> None: + self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) + self.assertIsNotNone(self.message.tags) + self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -- 2.36.6 From 2e2228bd60e4aa7761e64833fe517604a9784096 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 3 Sep 2023 10:18:16 +0200 Subject: [PATCH 128/170] chat: new possibilites for adding messages and better tests --- chatmastermind/chat.py | 75 ++++++++++++++++++++++++---- tests/test_chat.py | 109 ++++++++++++++++++++++++++++++++--------- 2 files changed, 153 insertions(+), 31 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 9fc0a27..7e6df8f 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -55,6 +55,16 @@ def read_dir(dir_path: pathlib.Path, return messages +def make_file_path(dir_path: pathlib.Path, + file_suffix: str, + next_fid: Callable[[], int]) -> pathlib.Path: + """ + Create a file_path for the given directory using the + given file_suffix and ID generator function. + """ + return dir_path / f"{next_fid():04d}{file_suffix}" + + def write_dir(dir_path: pathlib.Path, messages: list[Message], file_suffix: str, @@ -73,9 +83,7 @@ def write_dir(dir_path: pathlib.Path, file_path = message.file_path # message has no file_path: create one if not file_path: - fid = next_fid() - fname = f"{fid:04d}{file_suffix}" - file_path = dir_path / fname + file_path = make_file_path(dir_path, file_suffix, next_fid) # file_path does not point to given directory: modify it elif not file_path.parent.samefile(dir_path): file_path = dir_path / file_path.name @@ -124,11 +132,11 @@ class Chat: """ self.messages = [] - def add_msgs(self, msgs: list[Message]) -> None: + def add_messages(self, messages: list[Message]) -> None: """ Add new messages and sort them if possible. """ - self.messages += msgs + self.messages += messages self.sort() def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: @@ -279,25 +287,25 @@ class ChatDB(Chat): self.messages += new_messages self.sort() - def write_db(self, msgs: Optional[list[Message]] = None) -> None: + def write_db(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the DB directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the DB directory. """ write_dir(self.db_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) - def write_cache(self, msgs: Optional[list[Message]] = None) -> None: + def write_cache(self, messages: Optional[list[Message]] = None) -> None: """ Write messages to the cache directory. If a message has no file_path, a new one will be created. If message.file_path exists, it will be modified to point to the cache directory. """ write_dir(self.cache_path, - msgs if msgs else self.messages, + messages if messages else self.messages, self.file_suffix, self.get_next_fid) @@ -309,3 +317,52 @@ class ChatDB(Chat): clear_dir(self.cache_path, self.glob) # only keep messages from DB dir (or those that have not yet been written) self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)] + + def add_to_db(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the DB directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.db_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def add_to_cache(self, messages: list[Message], write: bool = True) -> None: + """ + Add the given new messages and set the file_path to the cache directory. + Only accepts messages without a file_path. + """ + if any(m.file_path is not None for m in messages): + raise ChatError("Can't add new messages with existing file_path") + if write: + write_dir(self.cache_path, + messages, + self.file_suffix, + self.get_next_fid) + else: + for m in messages: + m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid) + self.messages += messages + self.sort() + + def write_messages(self, messages: Optional[list[Message]] = None) -> None: + """ + Write either the given messages or the internal ones to their current file_path. + If messages are given, they all must have a valid file_path. When writing the + internal messages, the ones with a valid file_path are written, the others + are ignored. + """ + if messages and any(m.file_path is None for m in messages): + raise ChatError("Can't write files without a valid file_path") + msgs = iter(messages if messages else self.messages) + while (m := next(msgs, None)): + m.to_file() diff --git a/tests/test_chat.py b/tests/test_chat.py index 9e74061..a1c020e 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -5,7 +5,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, terminal_width +from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError from .test_main import CmmTestCase @@ -22,14 +22,14 @@ class TestChat(CmmTestCase): file_path=pathlib.Path('0002.txt')) def test_filter(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.filter(MessageFilter(answer_contains='Answer 1')) self.assertEqual(len(self.chat.messages), 1) self.assertEqual(self.chat.messages[0].question, 'Question 1') def test_sort(self) -> None: - self.chat.add_msgs([self.message2, self.message1]) + self.chat.add_messages([self.message2, self.message1]) self.chat.sort() self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') @@ -38,18 +38,18 @@ class TestChat(CmmTestCase): self.assertEqual(self.chat.messages[1].question, 'Question 1') def test_clear(self) -> None: - self.chat.add_msgs([self.message1]) + self.chat.add_messages([self.message1]) self.chat.clear() self.assertEqual(len(self.chat.messages), 0) - def test_add_msgs(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + def test_add_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) self.assertEqual(len(self.chat.messages), 2) self.assertEqual(self.chat.messages[0].question, 'Question 1') self.assertEqual(self.chat.messages[1].question, 'Question 2') def test_tags(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_all = self.chat.tags() self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')}) tags_pref = self.chat.tags(prefix='a') @@ -58,13 +58,13 @@ class TestChat(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) def test_tags_frequency(self) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) expected_output = f"""{'-'*terminal_width()} {Question.txt_header} @@ -81,7 +81,7 @@ Answer 2 @patch('sys.stdout', new_callable=StringIO) def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: - self.chat.add_msgs([self.message1, self.message2]) + self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) expected_output = f"""{'-'*terminal_width()} {TagLine.prefix} atag1 btag2 @@ -127,6 +127,17 @@ class TestChatDB(CmmTestCase): self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml')) self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt')) self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml')) + # make the next FID match the current state + next_fname = pathlib.Path(self.db_path.name) / '.next' + with open(next_fname, 'w') as f: + f.write('4') + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]: + """ + List all Message files in the given TemporaryDirectory. + """ + # exclude '.next' + return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*')) def tearDown(self) -> None: self.db_path.cleanup() @@ -184,11 +195,11 @@ class TestChatDB(CmmTestCase): def test_chat_db_fids(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name)) - self.assertEqual(chat_db.get_next_fid(), 1) - self.assertEqual(chat_db.get_next_fid(), 2) - self.assertEqual(chat_db.get_next_fid(), 3) + self.assertEqual(chat_db.get_next_fid(), 5) + self.assertEqual(chat_db.get_next_fid(), 6) + self.assertEqual(chat_db.get_next_fid(), 7) with open(chat_db.next_fname, 'r') as f: - self.assertEqual(f.read(), '3') + self.assertEqual(f.read(), '7') def test_chat_db_write(self) -> None: # create a new ChatDB instance @@ -203,7 +214,7 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files) self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files) @@ -216,14 +227,14 @@ class TestChatDB(CmmTestCase): self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml')) # check the timestamp of the files in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) old_timestamps = {file: file.stat().st_mtime for file in db_dir_files} # overwrite the messages in the db directory time.sleep(0.05) chat_db.write_db() # check if the written files are in the DB directory - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -314,12 +325,12 @@ class TestChatDB(CmmTestCase): # write the messages to the cache directory chat_db.write_cache() # check if the written files are in the cache directory - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 4) # now rewrite them to the DB dir and check for modified paths chat_db.write_db() - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files) self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files) @@ -333,15 +344,69 @@ class TestChatDB(CmmTestCase): message_cache = Message(question=Question("What the hell am I doing here?"), answer=Answer("You're a creep!"), file_path=pathlib.Path(self.cache_path.name, '0005.txt')) - chat_db.add_msgs([message_empty, message_cache]) + chat_db.add_messages([message_empty, message_cache]) # clear the cache and check the cache dir chat_db.clear_cache() - cache_dir_files = list(pathlib.Path(self.cache_path.name).glob('*')) + cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) # make sure that the DB messages (and the new message) are still there self.assertEqual(len(chat_db.messages), 5) - db_dir_files = list(pathlib.Path(self.db_path.name).glob('*')) + db_dir_files = self.message_list(self.db_path) self.assertEqual(len(db_dir_files), 4) # but not the message with the cache dir path self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages)) + + def test_chat_db_add(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + + # add new messages to the cache dir + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + chat_db.add_to_cache([message1]) + # check if the file_path has been correctly set + self.assertIsNotNone(message1.file_path) + self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name)) # type: ignore [union-attr] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + + # add new messages to the DB dir + message2 = Message(question=Question("Question 2"), + answer=Answer("Answer 2")) + chat_db.add_to_db([message2]) + # check if the file_path has been correctly set + self.assertIsNotNone(message2.file_path) + self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name)) # type: ignore [union-attr] + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 5) + + with self.assertRaises(ChatError): + chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))]) + + def test_chat_db_write_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + # try to write a message without a valid file_path + message = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.write_messages([message]) + + # write a message with a valid file_path + message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt' + chat_db.write_messages([message]) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) -- 2.36.6 From abb7fdacb65a7e266f63f4c2397e76e5e5961338 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 08:49:43 +0200 Subject: [PATCH 129/170] message / chat: output improvements --- chatmastermind/chat.py | 16 ++++------------ chatmastermind/message.py | 24 ++++++++++++++++++++++++ tests/test_chat.py | 16 ++++++++++++---- tests/test_message.py | 24 ++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 7e6df8f..c631dab 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -7,7 +7,7 @@ from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass from typing import TypeVar, Type, Optional, ClassVar, Any, Callable -from .message import Question, Answer, Message, MessageFilter, MessageError, source_code, message_in +from .message import Message, MessageFilter, MessageError, message_in from .tags import Tag ChatInst = TypeVar('ChatInst', bound='Chat') @@ -170,18 +170,10 @@ class Chat: output: list[str] = [] for message in self.messages: if source_code_only: - output.extend(source_code(message.question, include_delims=True)) + output.append(message.to_str(source_code_only=True)) continue - output.append('-' * terminal_width()) - if with_tags: - output.append(message.tags_str()) - if with_files: - output.append('FILE: ' + str(message.file_path)) - output.append(Question.txt_header) - output.append(message.question) - if message.answer: - output.append(Answer.txt_header) - output.append(message.answer) + output.append(message.to_str(with_tags, with_files)) + output.append('\n' + ('-' * terminal_width()) + '\n') if paged: print_paged('\n'.join(output)) else: diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 0fb949c..35de3b9 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -392,6 +392,30 @@ class Message(): data[cls.file_yaml_key] = file_path return cls.from_dict(data) + def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str: + """ + Return the current Message as a string. + """ + output: list[str] = [] + if source_code_only: + # use the source code from answer only + if self.answer: + output.extend(self.answer.source_code(include_delims=True)) + return '\n'.join(output) if len(output) > 0 else '' + if with_tags: + output.append(self.tags_str()) + if with_file: + output.append('FILE: ' + str(self.file_path)) + output.append(Question.txt_header) + output.append(self.question) + if self.answer: + output.append(Answer.txt_header) + output.append(self.answer) + return '\n'.join(output) + + def __str__(self) -> str: + return self.to_str(False, False, False) + def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ Write a Message to the given file. Type is determined based on the suffix. diff --git a/tests/test_chat.py b/tests/test_chat.py index a1c020e..f8302eb 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -66,16 +66,20 @@ class TestChat(CmmTestCase): def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False) - expected_output = f"""{'-'*terminal_width()} -{Question.txt_header} + expected_output = f"""{Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) @@ -83,20 +87,24 @@ Answer 2 def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) self.chat.print(paged=False, with_tags=True, with_files=True) - expected_output = f"""{'-'*terminal_width()} -{TagLine.prefix} atag1 btag2 + expected_output = f"""{TagLine.prefix} atag1 btag2 FILE: 0001.txt {Question.txt_header} Question 1 {Answer.txt_header} Answer 1 + {'-'*terminal_width()} + {TagLine.prefix} btag2 FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 + +{'-'*terminal_width()} + """ self.assertEqual(mock_stdout.getvalue(), expected_output) diff --git a/tests/test_message.py b/tests/test_message.py index e860538..a49c893 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -804,3 +804,27 @@ class MessageRenameTagsTestCase(CmmTestCase): self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))}) self.assertIsNotNone(self.message.tags) self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] + + +class MessageToStrTestCase(CmmTestCase): + def setUp(self) -> None: + self.message = Message(Question('This is a question.'), + Answer('This is an answer.'), + tags={Tag('atag1'), Tag('btag2')}, + file_path=pathlib.Path('/tmp/foo/bla')) + + def test_to_str(self) -> None: + expected_output = f"""{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(), expected_output) + + def test_to_str_with_tags_and_file(self) -> None: + expected_output = f"""{TagLine.prefix} atag1 btag2 +FILE: /tmp/foo/bla +{Question.txt_header} +This is a question. +{Answer.txt_header} +This is an answer.""" + self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output) -- 2.36.6 From e1414835c8c2cdc96d9c425b7d585afb3ffbb261 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 08:16:55 +0200 Subject: [PATCH 130/170] chat: added functions for finding and deleting messages --- chatmastermind/chat.py | 52 ++++++++++++++++++++++++++++++++---------- tests/test_chat.py | 22 ++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index c631dab..4e8fb20 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -2,7 +2,7 @@ Module implementing various chat classes and functions for managing a chat history. """ import shutil -import pathlib +from pathlib import Path from pprint import PrettyPrinter from pydoc import pager from dataclasses import dataclass @@ -30,7 +30,7 @@ def print_paged(text: str) -> None: pager(text) -def read_dir(dir_path: pathlib.Path, +def read_dir(dir_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> list[Message]: """ @@ -55,9 +55,9 @@ def read_dir(dir_path: pathlib.Path, return messages -def make_file_path(dir_path: pathlib.Path, +def make_file_path(dir_path: Path, file_suffix: str, - next_fid: Callable[[], int]) -> pathlib.Path: + next_fid: Callable[[], int]) -> Path: """ Create a file_path for the given directory using the given file_suffix and ID generator function. @@ -65,7 +65,7 @@ def make_file_path(dir_path: pathlib.Path, return dir_path / f"{next_fid():04d}{file_suffix}" -def write_dir(dir_path: pathlib.Path, +def write_dir(dir_path: Path, messages: list[Message], file_suffix: str, next_fid: Callable[[], int]) -> None: @@ -90,7 +90,7 @@ def write_dir(dir_path: pathlib.Path, message.to_file(file_path) -def clear_dir(dir_path: pathlib.Path, +def clear_dir(dir_path: Path, glob: Optional[str] = None) -> None: """ Deletes all Message files in the given directory. @@ -139,6 +139,34 @@ class Chat: self.messages += messages self.sort() + def latest_message(self) -> Optional[Message]: + """ + Returns the last added message (according to the file ID). + """ + if len(self.messages) > 0: + self.sort() + return self.messages[-1] + else: + return None + + def find_messages(self, msg_names: list[str]) -> list[Message]: + """ + Search and return the messages with the given names. Names can either be filenames + (incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the + caller should check the result if he requires all messages). + """ + return [m for m in self.messages + if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + + def remove_messages(self, msg_names: list[str]) -> None: + """ + Remove the messages with the given names. Names can either be filenames + (incl. the suffix) or full paths. + """ + self.messages = [m for m in self.messages + if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)] + self.sort() + def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ Get the tags of all messages, optionally filtered by prefix or substring. @@ -192,8 +220,8 @@ class ChatDB(Chat): default_file_suffix: ClassVar[str] = '.txt' - cache_path: pathlib.Path - db_path: pathlib.Path + cache_path: Path + db_path: Path # a MessageFilter that all messages must match (if given) mfilter: Optional[MessageFilter] = None file_suffix: str = default_file_suffix @@ -209,8 +237,8 @@ class ChatDB(Chat): @classmethod def from_dir(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, glob: Optional[str] = None, mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ @@ -230,8 +258,8 @@ class ChatDB(Chat): @classmethod def from_messages(cls: Type[ChatDBInst], - cache_path: pathlib.Path, - db_path: pathlib.Path, + cache_path: Path, + db_path: Path, messages: list[Message], mfilter: Optional[MessageFilter] = None) -> ChatDBInst: """ diff --git a/tests/test_chat.py b/tests/test_chat.py index f8302eb..d81a97a 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -62,6 +62,28 @@ class TestChat(CmmTestCase): tags_freq = self.chat.tags_frequency() self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2}) + def test_find_remove_messages(self) -> None: + self.chat.add_messages([self.message1, self.message2]) + msgs = self.chat.find_messages(['0001.txt']) + self.assertListEqual(msgs, [self.message1]) + msgs = self.chat.find_messages(['0001.txt', '0002.txt']) + self.assertListEqual(msgs, [self.message1, self.message2]) + # add new Message with full path + message3 = Message(Question('Question 2'), + Answer('Answer 2'), + {Tag('btag2')}, + file_path=pathlib.Path('/foo/bla/0003.txt')) + self.chat.add_messages([message3]) + # find new Message by full path + msgs = self.chat.find_messages(['/foo/bla/0003.txt']) + self.assertListEqual(msgs, [message3]) + # find Message with full path only by filename + msgs = self.chat.find_messages(['0003.txt']) + self.assertListEqual(msgs, [message3]) + # remove last message + self.chat.remove_messages(['0003.txt']) + self.assertListEqual(self.chat.messages, [self.message1, self.message2]) + @patch('sys.stdout', new_callable=StringIO) def test_print(self, mock_stdout: StringIO) -> None: self.chat.add_messages([self.message1, self.message2]) -- 2.36.6 From 8923a13352980ec6af36e1f381f177ae5a5a1841 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 12:46:23 +0200 Subject: [PATCH 131/170] cmm: the 'tags' command now uses the new 'ChatDB' --- chatmastermind/main.py | 34 +++++++++++++++++++++------------- chatmastermind/utils.py | 5 ----- tests/test_main.py | 2 +- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index c30ea4e..f9eccba 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -7,10 +7,11 @@ import sys import argcomplete import argparse import pathlib -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, print_tags_frequency, ChatType -from .storage import save_answers, create_chat_hist, get_tags, get_tags_unique, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType +from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config +from .chat import ChatDB from itertools import zip_longest from typing import Any @@ -56,12 +57,17 @@ def create_question_with_hist(args: argparse.Namespace, return chat, full_question, tags -def tag_cmd(args: argparse.Namespace, config: Config) -> None: +def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ - Handler for the 'tag' command. + Handler for the 'tags' command. """ + chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), + db_path=pathlib.Path(config.db)) if args.list: - print_tags_frequency(get_tags(config, None)) + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming def config_cmd(args: argparse.Namespace, config: Config) -> None: @@ -190,14 +196,16 @@ def create_parser() -> argparse.ArgumentParser: hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') - # 'tag' command parser - tag_cmd_parser = cmdparser.add_parser('tag', - help="Manage tags.", - aliases=['t']) - tag_cmd_parser.set_defaults(func=tag_cmd) - tag_group = tag_cmd_parser.add_mutually_exclusive_group(required=True) - tag_group.add_argument('-l', '--list', help="List all tags and their frequency", - action='store_true') + # 'tags' command parser + tags_cmd_parser = cmdparser.add_parser('tags', + help="Manage tags.", + aliases=['t']) + tags_cmd_parser.set_defaults(func=tags_cmd) + tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True) + tags_group.add_argument('-l', '--list', help="List all tags and their frequency", + action='store_true') + tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix") + tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring") # 'config' command parser config_cmd_parser = cmdparser.add_parser('config', diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py index 6543ce1..4135ae3 100644 --- a/chatmastermind/utils.py +++ b/chatmastermind/utils.py @@ -79,8 +79,3 @@ def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = Fals print(message['content']) else: print(f"{message['role'].upper()}: {message['content']}") - - -def print_tags_frequency(tags: list[str]) -> None: - for tag in sorted(set(tags)): - print(f"- {tag}: {tags.count(tag)}") diff --git a/tests/test_main.py b/tests/test_main.py index db5fcdb..23c3d00 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -227,7 +227,7 @@ class TestCreateParser(CmmTestCase): mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tag', help=ANY, aliases=ANY) + mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From 4c378dde854499771f377b29d8ae40b4984c7eec Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:21:49 +0200 Subject: [PATCH 132/170] cmm: the 'hist' command now uses the new 'ChatDB' --- chatmastermind/main.py | 58 +++++++++++++++++++++++------------------- tests/test_main.py | 15 ++++++----- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index f9eccba..8aef252 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -12,6 +12,7 @@ from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB +from .message import MessageFilter from itertools import zip_longest from typing import Any @@ -31,11 +32,11 @@ def create_question_with_hist(args: argparse.Namespace, by the specified tags. """ tags = args.tags or [] - extags = args.extags or [] + etags = args.etags or [] otags = args.output_tags or [] - if not args.only_source_code: - print_tag_args(tags, extags, otags) + if not args.source_code_only: + print_tag_args(tags, etags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -52,8 +53,10 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, extags, config, - args.match_all_tags, False, False) + chat = create_chat_hist(full_question, tags, etags, config, + match_all_tags=True if args.atags else False, # FIXME + with_tags=False, + with_file=False) return chat, full_question, tags @@ -94,7 +97,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: if args.model: config.openai.model = args.model chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.only_source_code) + print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) save_answers(question, answers, tags, otags, config) @@ -106,14 +109,18 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. """ - tags = args.tags or [] - extags = args.extags or [] - chat = create_chat_hist(None, tags, extags, config, - args.match_all_tags, - args.with_tags, - args.with_files) - print_chat_hist(chat, args.dump, args.only_source_code) + mfilter = MessageFilter(tags_or=args.tags, + tags_and=args.atags, + tags_not=args.etags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) def print_cmd(args: argparse.Namespace, config: Config) -> None: @@ -129,7 +136,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: else: print(f"Unknown file type: {args.file}") sys.exit(1) - if args.only_source_code: + if args.source_code_only: display_source_code(data['answer']) elif args.answer: print(data['answer'].strip()) @@ -153,18 +160,17 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names', metavar='TAGS') + help='List of tag names (one must match)', metavar='TAGS') tag_arg.completer = tags_completer # type: ignore - extag_arg = tag_parser.add_argument('-e', '--extags', nargs='+', - help='List of tag names to exclude', metavar='EXTAGS') - extag_arg.completer = tags_completer # type: ignore + atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', + help='List of tag names (all must match)', metavar='TAGS') + atag_arg.completer = tags_completer # type: ignore + etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', + help='List of tag names to exclude', metavar='ETAGS') + etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', help='List of output tag names, default is input', metavar='OTAGS') otag_arg.completer = tags_completer # type: ignore - tag_parser.add_argument('-a', '--match-all-tags', - help="All given tags must match when selecting chat history entries", - action='store_true') - # enable autocompletion for tags # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], @@ -179,7 +185,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--only-source-code', help='Add pure source code to the chat history', + ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') # 'hist' command parser @@ -187,14 +193,14 @@ def create_parser() -> argparse.ArgumentParser: help="Print chat history.", aliases=['h']) hist_cmd_parser.set_defaults(func=hist_cmd) - hist_cmd_parser.add_argument('-d', '--dump', help="Print chat history as Python structure", - action='store_true') hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true') hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.", action='store_true') - hist_cmd_parser.add_argument('-S', '--only-source-code', help='Print only source code', + hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', action='store_true') + hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring') + hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring') # 'tags' command parser tags_cmd_parser = cmdparser.add_parser('tags', diff --git a/tests/test_main.py b/tests/test_main.py index 23c3d00..bb9aa2a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -115,11 +115,12 @@ class TestHandleQuestion(CmmTestCase): self.question = "test question" self.args = argparse.Namespace( tags=['tag1'], - extags=['extag1'], + atags=None, + etags=['etag1'], output_tags=None, question=[self.question], source=None, - only_source_code=False, + source_code_only=False, number=3, max_tokens=None, temperature=None, @@ -143,16 +144,18 @@ class TestHandleQuestion(CmmTestCase): with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.extags, + self.args.etags, []) mock_create_chat_hist.assert_called_once_with(self.question, self.args.tags, - self.args.extags, + self.args.etags, self.config, - False, False, False) + match_all_tags=False, + with_tags=False, + with_file=False) mock_print_chat_hist.assert_called_once_with('test_chat', False, - self.args.only_source_code) + self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, self.args.number) -- 2.36.6 From 5e4ec70072fe77a888973746594e56df01839dfd Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 2 Sep 2023 08:42:59 +0200 Subject: [PATCH 133/170] cmm: tags completion now uses 'Message.tags_from_dir' (fixes tag completion for me) --- chatmastermind/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 8aef252..1796f69 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,13 +6,13 @@ import yaml import sys import argcomplete import argparse -import pathlib +from pathlib import Path from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, get_tags_unique, read_file, dump_data +from .storage import save_answers, create_chat_hist, read_file, dump_data from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import MessageFilter +from .message import Message, MessageFilter from itertools import zip_longest from typing import Any @@ -64,8 +64,8 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. """ - chat = ChatDB.from_dir(cache_path=pathlib.Path('.'), - db_path=pathlib.Path(config.db)) + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) if args.list: tags_freq = chat.tags_frequency(args.prefix, args.contain) for tag, freq in tags_freq.items(): @@ -127,7 +127,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'print' command. """ - fname = pathlib.Path(args.file) + fname = Path(args.file) if fname.suffix == '.yaml': with open(args.file, 'r') as f: data = yaml.load(f, Loader=yaml.FullLoader) -- 2.36.6 From e186afbef046e04f5588805c6def89ee6a5c5eee Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:07:02 +0200 Subject: [PATCH 134/170] cmm: the 'print' command now uses 'Message.from_file()' --- chatmastermind/main.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 1796f69..ed67f7b 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -2,17 +2,16 @@ # -*- coding: utf-8 -*- # vim: set fileencoding=utf-8 : -import yaml import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, display_source_code, ChatType -from .storage import save_answers, create_chat_hist, read_file, dump_data +from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType +from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter +from .message import Message, MessageFilter, MessageError from itertools import zip_longest from typing import Any @@ -128,13 +127,12 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'print' command. """ fname = Path(args.file) - if fname.suffix == '.yaml': - with open(args.file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif fname.suffix == '.txt': - data = read_file(fname) - else: - print(f"Unknown file type: {args.file}") + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") sys.exit(1) if args.source_code_only: display_source_code(data['answer']) @@ -227,14 +225,22 @@ def create_parser() -> argparse.ArgumentParser: # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', - help="Print files.", + help="Print message files.", aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) +<<<<<<< HEAD print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') +||||||| parent of bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', + action='store_true') +======= + print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', + action='store_true') +>>>>>>> bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') argcomplete.autocomplete(parser) return parser -- 2.36.6 From 4bd144c4d75a2892947e470a863dc87d6f4f0633 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 09:00:15 +0200 Subject: [PATCH 135/170] added new module 'ai.py' --- chatmastermind/ai.py | 63 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 chatmastermind/ai.py diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py new file mode 100644 index 0000000..4a8b914 --- /dev/null +++ b/chatmastermind/ai.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import Protocol, Optional, Union +from .configuration import AIConfig +from .tags import Tag +from .message import Message +from .chat import Chat + + +class AIError(Exception): + pass + + +@dataclass +class Tokens: + prompt: int = 0 + completion: int = 0 + total: int = 0 + + +@dataclass +class AIResponse: + """ + The response to an AI request. Consists of one or more messages + (each containing the question and a single answer) and the nr. + of used tokens. + """ + messages: list[Message] + tokens: Optional[Tokens] = None + + +class AI(Protocol): + """ + The base class for AI clients. + """ + + name: str + config: AIConfig + + def request(self, + question: Message, + context: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + context (i. e. chat history). The nr. of requested answers + corresponds to the nr. of messages in the 'AIResponse'. + """ + raise NotImplementedError + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def tokens(self, data: Union[Message, Chat]) -> int: + """ + Computes the nr. of AI language tokens for the given message + or chat. Note that the computation may not be 100% accurate + and is not implemented for all AIs. + """ + raise NotImplementedError -- 2.36.6 From 823d3bf7dc1ed2bc40dc3604007a21e0a69bb475 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 1 Sep 2023 10:18:09 +0200 Subject: [PATCH 136/170] added new module 'openai.py' --- chatmastermind/ais/openai.py | 96 ++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 chatmastermind/ais/openai.py diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py new file mode 100644 index 0000000..74438b8 --- /dev/null +++ b/chatmastermind/ais/openai.py @@ -0,0 +1,96 @@ +""" +Implements the OpenAI client classes and functions. +""" +import openai +from typing import Optional, Union +from ..tags import Tag +from ..message import Message, Answer +from ..chat import Chat +from ..ai import AI, AIResponse, Tokens +from ..configuration import OpenAIConfig + +ChatType = list[dict[str, str]] + + +class OpenAI(AI): + """ + The OpenAI AI client. + """ + + def __init__(self, name: str, config: OpenAIConfig) -> None: + self.name = name + self.config = config + + def request(self, + question: Message, + chat: Chat, + num_answers: int = 1, + otags: Optional[set[Tag]] = None) -> AIResponse: + """ + Make an AI request, asking the given question with the given + chat history. The nr. of requested answers corresponds to the + nr. of messages in the 'AIResponse'. + """ + # FIXME: use real 'system' message (store in OpenAIConfig) + oai_chat = self.openai_chat(chat, "system", question) + response = openai.ChatCompletion.create( + model=self.config.model, + messages=oai_chat, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + top_p=self.config.top_p, + n=num_answers, + frequency_penalty=self.config.frequency_penalty, + presence_penalty=self.config.presence_penalty) + answers: list[Message] = [] + for choice in response['choices']: # type: ignore + answers.append(Message(question=question.question, + answer=Answer(choice['message']['content']), + tags=otags, + ai=self.name, + model=self.config.model)) + return AIResponse(answers, Tokens(response['usage']['prompt'], + response['usage']['completion'], + response['usage']['total'])) + + def models(self) -> list[str]: + """ + Return all models supported by this AI. + """ + raise NotImplementedError + + def print_models(self) -> None: + """ + Print all models supported by the current AI. + """ + not_ready = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + print(engine['id']) + else: + not_ready.append(engine['id']) + if len(not_ready) > 0: + print('\nNot ready: ' + ', '.join(not_ready)) + + def openai_chat(self, chat: Chat, system: str, + question: Optional[Message] = None) -> ChatType: + """ + Create a chat history with system message in OpenAI format. + Optionally append a new question. + """ + oai_chat: ChatType = [] + + def append(role: str, content: str) -> None: + oai_chat.append({'role': role, 'content': content.replace("''", "'")}) + + append('system', system) + for message in chat.messages: + if message.answer: + append('user', message.question) + append('assistant', message.answer) + if question: + append('user', question.question) + return oai_chat + + def tokens(self, data: Union[Message, Chat]) -> int: + raise NotImplementedError -- 2.36.6 From 7d154522420d112b44a94f3717726331d1ae0af7 Mon Sep 17 00:00:00 2001 From: juk0de Date: Tue, 5 Sep 2023 23:24:20 +0200 Subject: [PATCH 137/170] added new module 'ai_factory' --- chatmastermind/ai_factory.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 chatmastermind/ai_factory.py diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py new file mode 100644 index 0000000..c90366b --- /dev/null +++ b/chatmastermind/ai_factory.py @@ -0,0 +1,20 @@ +""" +Creates different AI instances, based on the given configuration. +""" + +import argparse +from .configuration import Config +from .ai import AI, AIError +from .ais.openai import OpenAI + + +def create_ai(args: argparse.Namespace, config: Config) -> AI: + """ + Creates an AI subclass instance from the given args and configuration. + """ + if args.ai == 'openai': + # FIXME: create actual 'OpenAIConfig' and set values from 'args' + # FIXME: use actual name from config + return OpenAI("openai", config.openai) + else: + raise AIError(f"AI '{args.ai}' is not supported") -- 2.36.6 From 034e4093f1ff65d352ea41b89155eb22153477f8 Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 4 Sep 2023 22:35:53 +0200 Subject: [PATCH 138/170] cmm: added 'question' command --- chatmastermind/main.py | 103 +++++++++++++++++++++++++++++++++-------- tests/test_main.py | 18 +++---- 2 files changed, 93 insertions(+), 28 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index ed67f7b..67eafae 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -11,7 +11,9 @@ from .storage import save_answers, create_chat_hist from .api_client import ai, openai_api_key, print_models from .configuration import Config from .chat import ChatDB -from .message import Message, MessageFilter, MessageError +from .message import Message, MessageFilter, MessageError, Question +from .ai_factory import create_ai +from .ai import AI, AIResponse from itertools import zip_longest from typing import Any @@ -30,12 +32,12 @@ def create_question_with_hist(args: argparse.Namespace, Creates the "AI request", including the question and chat history as determined by the specified tags. """ - tags = args.tags or [] - etags = args.etags or [] + tags = args.or_tags or [] + xtags = args.exclude_tags or [] otags = args.output_tags or [] if not args.source_code_only: - print_tag_args(tags, etags, otags) + print_tag_args(tags, xtags, otags) question_parts = [] question_list = args.question if args.question is not None else [] @@ -52,8 +54,8 @@ def create_question_with_hist(args: argparse.Namespace, question_parts.append(f"```\n{r.read().strip()}\n```") full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, etags, config, - match_all_tags=True if args.atags else False, # FIXME + chat = create_chat_hist(full_question, tags, xtags, config, + match_all_tags=True if args.and_tags else False, # FIXME with_tags=False, with_file=False) return chat, full_question, tags @@ -85,6 +87,47 @@ def config_cmd(args: argparse.Namespace, config: Config) -> None: config.to_file(args.config) +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. + """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = Message(question=Question(args.question), + tags=args.ouput_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass + + def ask_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'ask' command. @@ -98,7 +141,7 @@ def ask_cmd(args: argparse.Namespace, config: Config) -> None: chat, question, tags = create_question_with_hist(args, config) print_chat_hist(chat, False, args.source_code_only) otags = args.output_tags or [] - answers, usage = ai(chat, config, args.number) + answers, usage = ai(chat, config, args.num_answers) save_answers(question, answers, tags, otags, config) print("-" * terminal_width()) print(f"Usage: {usage}") @@ -109,9 +152,9 @@ def hist_cmd(args: argparse.Namespace, config: Config) -> None: Handler for the 'hist' command. """ - mfilter = MessageFilter(tags_or=args.tags, - tags_and=args.atags, - tags_not=args.etags, + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, question_contains=args.question, answer_contains=args.answer) chat = ChatDB.from_dir(Path('.'), @@ -147,7 +190,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-c', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -157,19 +200,41 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) - tag_arg = tag_parser.add_argument('-t', '--tags', nargs='+', - help='List of tag names (one must match)', metavar='TAGS') + tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', + help='List of tag names (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore - atag_arg = tag_parser.add_argument('-a', '--atags', nargs='+', - help='List of tag names (all must match)', metavar='TAGS') + atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', + help='List of tag names (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore - etag_arg = tag_parser.add_argument('-e', '--etags', nargs='+', - help='List of tag names to exclude', metavar='ETAGS') + etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', + help='List of tag names to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OTAGS') + help='List of output tag names, default is input', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # 'question' command parser + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + help="ask, create and process questions.", + aliases=['q']) + question_cmd_parser.set_defaults(func=question_cmd) + question_group = question_cmd_parser.add_mutually_exclusive_group(required=True) + question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question') + question_group.add_argument('-c', '--create', nargs='+', help='Create a question') + question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question') + question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') + question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', + action='store_true') + question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) + question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) + question_cmd_parser.add_argument('-A', '--AI', help='AI to use') + question_cmd_parser.add_argument('-M', '--model', help='Model to use') + question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, + default=1) + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', + action='store_true') + # 'ask' command parser ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], help="Ask a question.", @@ -180,7 +245,7 @@ def create_parser() -> argparse.ArgumentParser: ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, + ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, default=1) ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', diff --git a/tests/test_main.py b/tests/test_main.py index bb9aa2a..ce9121a 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -114,14 +114,14 @@ class TestHandleQuestion(CmmTestCase): def setUp(self) -> None: self.question = "test question" self.args = argparse.Namespace( - tags=['tag1'], - atags=None, - etags=['etag1'], + or_tags=['tag1'], + and_tags=None, + exclude_tags=['xtag1'], output_tags=None, question=[self.question], source=None, source_code_only=False, - number=3, + num_answers=3, max_tokens=None, temperature=None, model=None, @@ -143,12 +143,12 @@ class TestHandleQuestion(CmmTestCase): open_mock = MagicMock() with patch("chatmastermind.storage.open", open_mock): ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.tags, - self.args.etags, + mock_print_tag_args.assert_called_once_with(self.args.or_tags, + self.args.exclude_tags, []) mock_create_chat_hist.assert_called_once_with(self.question, - self.args.tags, - self.args.etags, + self.args.or_tags, + self.args.exclude_tags, self.config, match_all_tags=False, with_tags=False, @@ -158,7 +158,7 @@ class TestHandleQuestion(CmmTestCase): self.args.source_code_only) mock_ai.assert_called_with("test_chat", self.config, - self.args.number) + self.args.num_answers) expected_calls = [] for num, answer in enumerate(mock_ai.return_value[0], start=1): title = f'-- ANSWER {num} ' -- 2.36.6 From d6bb5800b16601a7fd9086f0b6c50991b78aed6e Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 22:12:05 +0200 Subject: [PATCH 139/170] test_main: temporarily disabled all testcases --- tests/test_chat.py | 6 +- tests/test_main.py | 468 +++++++++++++++++++++--------------------- tests/test_message.py | 34 +-- tests/test_tags.py | 6 +- 4 files changed, 257 insertions(+), 257 deletions(-) diff --git a/tests/test_chat.py b/tests/test_chat.py index d81a97a..8e4aa8c 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -1,3 +1,4 @@ +import unittest import pathlib import tempfile import time @@ -6,10 +7,9 @@ from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError -from .test_main import CmmTestCase -class TestChat(CmmTestCase): +class TestChat(unittest.TestCase): def setUp(self) -> None: self.chat = Chat([]) self.message1 = Message(Question('Question 1'), @@ -131,7 +131,7 @@ Answer 2 self.assertEqual(mock_stdout.getvalue(), expected_output) -class TestChatDB(CmmTestCase): +class TestChatDB(unittest.TestCase): def setUp(self) -> None: self.db_path = tempfile.TemporaryDirectory() self.cache_path = tempfile.TemporaryDirectory() diff --git a/tests/test_main.py b/tests/test_main.py index ce9121a..91e6462 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,236 +1,236 @@ -import unittest -import io -import pathlib -import argparse -from chatmastermind.utils import terminal_width -from chatmastermind.main import create_parser, ask_cmd -from chatmastermind.api_client import ai -from chatmastermind.configuration import Config -from chatmastermind.storage import create_chat_hist, save_answers, dump_data -from unittest import mock -from unittest.mock import patch, MagicMock, Mock, ANY +# import unittest +# import io +# import pathlib +# import argparse +# from chatmastermind.utils import terminal_width +# from chatmastermind.main import create_parser, ask_cmd +# from chatmastermind.api_client import ai +# from chatmastermind.configuration import Config +# from chatmastermind.storage import create_chat_hist, save_answers, dump_data +# from unittest import mock +# from unittest.mock import patch, MagicMock, Mock, ANY -class CmmTestCase(unittest.TestCase): - """ - Base class for all cmm testcases. - """ - def dummy_config(self, db: str) -> Config: - """ - Creates a dummy configuration. - """ - return Config.from_dict( - {'system': 'dummy_system', - 'db': db, - 'openai': {'api_key': 'dummy_key', - 'model': 'dummy_model', - 'max_tokens': 4000, - 'temperature': 1.0, - 'top_p': 1, - 'frequency_penalty': 0, - 'presence_penalty': 0}} - ) - - -class TestCreateChat(CmmTestCase): - - def setUp(self) -> None: - self.config = self.dummy_config(db='test_files') - self.question = "test question" - self.tags = ['test_tag'] - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['test_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 4) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( - {'question': 'test_content', 'answer': 'some answer', - 'tags': ['other_tag']})) - - test_chat = create_chat_hist(self.question, self.tags, None, self.config) - - self.assertEqual(len(test_chat), 2) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': self.question}) - - @patch('os.listdir') - @patch('pathlib.Path.iterdir') - @patch('builtins.open') - def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: - listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] - iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] - open_mock.side_effect = ( - io.StringIO(dump_data({'question': 'test_content', - 'answer': 'some answer', - 'tags': ['test_tag']})), - io.StringIO(dump_data({'question': 'test_content2', - 'answer': 'some answer2', - 'tags': ['test_tag2']})), - ) - - test_chat = create_chat_hist(self.question, [], None, self.config) - - self.assertEqual(len(test_chat), 6) - self.assertEqual(test_chat[0], - {'role': 'system', 'content': self.config.system}) - self.assertEqual(test_chat[1], - {'role': 'user', 'content': 'test_content'}) - self.assertEqual(test_chat[2], - {'role': 'assistant', 'content': 'some answer'}) - self.assertEqual(test_chat[3], - {'role': 'user', 'content': 'test_content2'}) - self.assertEqual(test_chat[4], - {'role': 'assistant', 'content': 'some answer2'}) - - -class TestHandleQuestion(CmmTestCase): - - def setUp(self) -> None: - self.question = "test question" - self.args = argparse.Namespace( - or_tags=['tag1'], - and_tags=None, - exclude_tags=['xtag1'], - output_tags=None, - question=[self.question], - source=None, - source_code_only=False, - num_answers=3, - max_tokens=None, - temperature=None, - model=None, - match_all_tags=False, - with_tags=False, - with_file=False, - ) - self.config = self.dummy_config(db='test_files') - - @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") - @patch("chatmastermind.main.print_tag_args") - @patch("chatmastermind.main.print_chat_hist") - @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) - @patch("chatmastermind.utils.pp") - @patch("builtins.print") - def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, - mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, - mock_create_chat_hist: MagicMock) -> None: - open_mock = MagicMock() - with patch("chatmastermind.storage.open", open_mock): - ask_cmd(self.args, self.config) - mock_print_tag_args.assert_called_once_with(self.args.or_tags, - self.args.exclude_tags, - []) - mock_create_chat_hist.assert_called_once_with(self.question, - self.args.or_tags, - self.args.exclude_tags, - self.config, - match_all_tags=False, - with_tags=False, - with_file=False) - mock_print_chat_hist.assert_called_once_with('test_chat', - False, - self.args.source_code_only) - mock_ai.assert_called_with("test_chat", - self.config, - self.args.num_answers) - expected_calls = [] - for num, answer in enumerate(mock_ai.return_value[0], start=1): - title = f'-- ANSWER {num} ' - title_end = '-' * (terminal_width() - len(title)) - expected_calls.append(((f'{title}{title_end}',),)) - expected_calls.append(((answer,),)) - expected_calls.append((("-" * terminal_width(),),)) - expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - self.assertEqual(mock_print.call_args_list, expected_calls) - open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) - open_mock.assert_has_calls(open_expected_calls, any_order=True) - - -class TestSaveAnswers(CmmTestCase): - @mock.patch('builtins.open') - @mock.patch('chatmastermind.storage.print') - def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: - question = "Test question?" - answers = ["Answer 1", "Answer 2"] - tags = ["tag1", "tag2"] - otags = ["otag1", "otag2"] - config = self.dummy_config(db='test_db') - - with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ - mock.patch('chatmastermind.storage.yaml.dump'), \ - mock.patch('io.StringIO') as stringio_mock: - stringio_instance = stringio_mock.return_value - stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] - save_answers(question, answers, tags, otags, config) - - open_calls = [ - mock.call(pathlib.Path('test_db/.next'), 'r'), - mock.call(pathlib.Path('test_db/.next'), 'w'), - ] - open_mock.assert_has_calls(open_calls, any_order=True) - - -class TestAI(CmmTestCase): - - @patch("openai.ChatCompletion.create") - def test_ai(self, mock_create: MagicMock) -> None: - mock_create.return_value = { - 'choices': [ - {'message': {'content': 'response_text_1'}}, - {'message': {'content': 'response_text_2'}} - ], - 'usage': {'tokens': 10} - } - - chat = [{"role": "system", "content": "hello ai"}] - config = self.dummy_config(db='dummy') - config.openai.model = "text-davinci-002" - config.openai.max_tokens = 150 - config.openai.temperature = 0.5 - - result = ai(chat, config, 2) - expected_result = (['response_text_1', 'response_text_2'], - {'tokens': 10}) - self.assertEqual(result, expected_result) - - -class TestCreateParser(CmmTestCase): - def test_create_parser(self) -> None: - with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: - mock_cmdparser = Mock() - mock_add_subparsers.return_value = mock_cmdparser - parser = create_parser() - self.assertIsInstance(parser, argparse.ArgumentParser) - mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) - mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) - mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) - self.assertTrue('.config.yaml' in parser.get_default('config')) +# class CmmTestCase(unittest.TestCase): +# """ +# Base class for all cmm testcases. +# """ +# def dummy_config(self, db: str) -> Config: +# """ +# Creates a dummy configuration. +# """ +# return Config.from_dict( +# {'system': 'dummy_system', +# 'db': db, +# 'openai': {'api_key': 'dummy_key', +# 'model': 'dummy_model', +# 'max_tokens': 4000, +# 'temperature': 1.0, +# 'top_p': 1, +# 'frequency_penalty': 0, +# 'presence_penalty': 0}} +# ) +# +# +# class TestCreateChat(CmmTestCase): +# +# def setUp(self) -> None: +# self.config = self.dummy_config(db='test_files') +# self.question = "test question" +# self.tags = ['test_tag'] +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['test_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 4) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# self.assertEqual(test_chat[3], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( +# {'question': 'test_content', 'answer': 'some answer', +# 'tags': ['other_tag']})) +# +# test_chat = create_chat_hist(self.question, self.tags, None, self.config) +# +# self.assertEqual(len(test_chat), 2) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': self.question}) +# +# @patch('os.listdir') +# @patch('pathlib.Path.iterdir') +# @patch('builtins.open') +# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: +# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] +# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] +# open_mock.side_effect = ( +# io.StringIO(dump_data({'question': 'test_content', +# 'answer': 'some answer', +# 'tags': ['test_tag']})), +# io.StringIO(dump_data({'question': 'test_content2', +# 'answer': 'some answer2', +# 'tags': ['test_tag2']})), +# ) +# +# test_chat = create_chat_hist(self.question, [], None, self.config) +# +# self.assertEqual(len(test_chat), 6) +# self.assertEqual(test_chat[0], +# {'role': 'system', 'content': self.config.system}) +# self.assertEqual(test_chat[1], +# {'role': 'user', 'content': 'test_content'}) +# self.assertEqual(test_chat[2], +# {'role': 'assistant', 'content': 'some answer'}) +# self.assertEqual(test_chat[3], +# {'role': 'user', 'content': 'test_content2'}) +# self.assertEqual(test_chat[4], +# {'role': 'assistant', 'content': 'some answer2'}) +# +# +# class TestHandleQuestion(CmmTestCase): +# +# def setUp(self) -> None: +# self.question = "test question" +# self.args = argparse.Namespace( +# or_tags=['tag1'], +# and_tags=None, +# exclude_tags=['xtag1'], +# output_tags=None, +# question=[self.question], +# source=None, +# source_code_only=False, +# num_answers=3, +# max_tokens=None, +# temperature=None, +# model=None, +# match_all_tags=False, +# with_tags=False, +# with_file=False, +# ) +# self.config = self.dummy_config(db='test_files') +# +# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") +# @patch("chatmastermind.main.print_tag_args") +# @patch("chatmastermind.main.print_chat_hist") +# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) +# @patch("chatmastermind.utils.pp") +# @patch("builtins.print") +# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, +# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, +# mock_create_chat_hist: MagicMock) -> None: +# open_mock = MagicMock() +# with patch("chatmastermind.storage.open", open_mock): +# ask_cmd(self.args, self.config) +# mock_print_tag_args.assert_called_once_with(self.args.or_tags, +# self.args.exclude_tags, +# []) +# mock_create_chat_hist.assert_called_once_with(self.question, +# self.args.or_tags, +# self.args.exclude_tags, +# self.config, +# match_all_tags=False, +# with_tags=False, +# with_file=False) +# mock_print_chat_hist.assert_called_once_with('test_chat', +# False, +# self.args.source_code_only) +# mock_ai.assert_called_with("test_chat", +# self.config, +# self.args.num_answers) +# expected_calls = [] +# for num, answer in enumerate(mock_ai.return_value[0], start=1): +# title = f'-- ANSWER {num} ' +# title_end = '-' * (terminal_width() - len(title)) +# expected_calls.append(((f'{title}{title_end}',),)) +# expected_calls.append(((answer,),)) +# expected_calls.append((("-" * terminal_width(),),)) +# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) +# self.assertEqual(mock_print.call_args_list, expected_calls) +# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) +# open_mock.assert_has_calls(open_expected_calls, any_order=True) +# +# +# class TestSaveAnswers(CmmTestCase): +# @mock.patch('builtins.open') +# @mock.patch('chatmastermind.storage.print') +# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: +# question = "Test question?" +# answers = ["Answer 1", "Answer 2"] +# tags = ["tag1", "tag2"] +# otags = ["otag1", "otag2"] +# config = self.dummy_config(db='test_db') +# +# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ +# mock.patch('chatmastermind.storage.yaml.dump'), \ +# mock.patch('io.StringIO') as stringio_mock: +# stringio_instance = stringio_mock.return_value +# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] +# save_answers(question, answers, tags, otags, config) +# +# open_calls = [ +# mock.call(pathlib.Path('test_db/.next'), 'r'), +# mock.call(pathlib.Path('test_db/.next'), 'w'), +# ] +# open_mock.assert_has_calls(open_calls, any_order=True) +# +# +# class TestAI(CmmTestCase): +# +# @patch("openai.ChatCompletion.create") +# def test_ai(self, mock_create: MagicMock) -> None: +# mock_create.return_value = { +# 'choices': [ +# {'message': {'content': 'response_text_1'}}, +# {'message': {'content': 'response_text_2'}} +# ], +# 'usage': {'tokens': 10} +# } +# +# chat = [{"role": "system", "content": "hello ai"}] +# config = self.dummy_config(db='dummy') +# config.openai.model = "text-davinci-002" +# config.openai.max_tokens = 150 +# config.openai.temperature = 0.5 +# +# result = ai(chat, config, 2) +# expected_result = (['response_text_1', 'response_text_2'], +# {'tokens': 10}) +# self.assertEqual(result, expected_result) +# +# +# class TestCreateParser(CmmTestCase): +# def test_create_parser(self) -> None: +# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: +# mock_cmdparser = Mock() +# mock_add_subparsers.return_value = mock_cmdparser +# parser = create_parser() +# self.assertIsInstance(parser, argparse.ArgumentParser) +# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) +# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) +# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) +# self.assertTrue('.config.yaml' in parser.get_default('config')) diff --git a/tests/test_message.py b/tests/test_message.py index a49c893..57d5982 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -1,12 +1,12 @@ +import unittest import pathlib import tempfile from typing import cast -from .test_main import CmmTestCase from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in from chatmastermind.tags import Tag, TagLine -class SourceCodeTestCase(CmmTestCase): +class SourceCodeTestCase(unittest.TestCase): def test_source_code_with_include_delims(self) -> None: text = """ Some text before the code block @@ -60,7 +60,7 @@ class SourceCodeTestCase(CmmTestCase): self.assertEqual(result, expected_result) -class QuestionTestCase(CmmTestCase): +class QuestionTestCase(unittest.TestCase): def test_question_with_header(self) -> None: with self.assertRaises(MessageError): Question(f"{Question.txt_header}\nWhat is your name?") @@ -83,7 +83,7 @@ class QuestionTestCase(CmmTestCase): self.assertEqual(question, "What is your favorite color?") -class AnswerTestCase(CmmTestCase): +class AnswerTestCase(unittest.TestCase): def test_answer_with_header(self) -> None: with self.assertRaises(MessageError): Answer(f"{Answer.txt_header}\nno") @@ -99,7 +99,7 @@ class AnswerTestCase(CmmTestCase): self.assertEqual(answer, "No") -class MessageToFileTxtTestCase(CmmTestCase): +class MessageToFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -160,7 +160,7 @@ This is a question. self.message_complete.file_path = self.file_path -class MessageToFileYamlTestCase(CmmTestCase): +class MessageToFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -226,7 +226,7 @@ class MessageToFileYamlTestCase(CmmTestCase): self.assertEqual(content, expected_content) -class MessageFromFileTxtTestCase(CmmTestCase): +class MessageFromFileTxtTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -388,7 +388,7 @@ This is a question. self.assertIsNone(message) -class MessageFromFileYamlTestCase(CmmTestCase): +class MessageFromFileYamlTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml') self.file_path = pathlib.Path(self.file.name) @@ -555,7 +555,7 @@ class MessageFromFileYamlTestCase(CmmTestCase): self.assertIsNone(message) -class TagsFromFileTestCase(CmmTestCase): +class TagsFromFileTestCase(unittest.TestCase): def setUp(self) -> None: self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path_txt = pathlib.Path(self.file_txt.name) @@ -663,7 +663,7 @@ This is an answer. self.assertSetEqual(tags, set()) -class TagsFromDirTestCase(CmmTestCase): +class TagsFromDirTestCase(unittest.TestCase): def setUp(self) -> None: self.temp_dir = tempfile.TemporaryDirectory() self.temp_dir_no_tags = tempfile.TemporaryDirectory() @@ -711,7 +711,7 @@ class TagsFromDirTestCase(CmmTestCase): self.assertSetEqual(all_tags, set()) -class MessageIDTestCase(CmmTestCase): +class MessageIDTestCase(unittest.TestCase): def setUp(self) -> None: self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt') self.file_path = pathlib.Path(self.file.name) @@ -731,7 +731,7 @@ class MessageIDTestCase(CmmTestCase): self.message_no_file_path.msg_id() -class MessageHashTestCase(CmmTestCase): +class MessageHashTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -755,7 +755,7 @@ class MessageHashTestCase(CmmTestCase): self.assertIn(msg, msgs) -class MessageTagsStrTestCase(CmmTestCase): +class MessageTagsStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('tag1')}, @@ -765,7 +765,7 @@ class MessageTagsStrTestCase(CmmTestCase): self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1') -class MessageFilterTagsTestCase(CmmTestCase): +class MessageFilterTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -780,7 +780,7 @@ class MessageFilterTagsTestCase(CmmTestCase): self.assertSetEqual(tags_cont, {Tag('btag2')}) -class MessageInTestCase(CmmTestCase): +class MessageInTestCase(unittest.TestCase): def setUp(self) -> None: self.message1 = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -794,7 +794,7 @@ class MessageInTestCase(CmmTestCase): self.assertFalse(message_in(self.message1, [self.message2])) -class MessageRenameTagsTestCase(CmmTestCase): +class MessageRenameTagsTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), tags={Tag('atag1'), Tag('btag2')}, @@ -806,7 +806,7 @@ class MessageRenameTagsTestCase(CmmTestCase): self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type] -class MessageToStrTestCase(CmmTestCase): +class MessageToStrTestCase(unittest.TestCase): def setUp(self) -> None: self.message = Message(Question('This is a question.'), Answer('This is an answer.'), diff --git a/tests/test_tags.py b/tests/test_tags.py index aa89a06..edd3c05 100644 --- a/tests/test_tags.py +++ b/tests/test_tags.py @@ -1,8 +1,8 @@ -from .test_main import CmmTestCase +import unittest from chatmastermind.tags import Tag, TagLine, TagError -class TestTag(CmmTestCase): +class TestTag(unittest.TestCase): def test_valid_tag(self) -> None: tag = Tag('mytag') self.assertEqual(tag, 'mytag') @@ -18,7 +18,7 @@ class TestTag(CmmTestCase): self.assertEqual(Tag.alternative_separators, [',']) -class TestTagLine(CmmTestCase): +class TestTagLine(unittest.TestCase): def test_valid_tagline(self) -> None: tagline = TagLine('TAGS: tag1 tag2') self.assertEqual(tagline, 'TAGS: tag1 tag2') -- 2.36.6 From 6a4cc7a65d9c3cc094a78568e4cded6a92e3f63e Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:23:29 +0200 Subject: [PATCH 140/170] setup: added 'ais' subfolder --- chatmastermind/ais/__init__.py | 0 setup.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 chatmastermind/ais/__init__.py diff --git a/chatmastermind/ais/__init__.py b/chatmastermind/ais/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py index 02d9ab1..8484629 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages(), + packages=find_packages() + ["chatmastermind.ais"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -32,7 +32,7 @@ setup( "openai", "PyYAML", "argcomplete", - "pytest" + "pytest", ], python_requires=">=3.9", test_suite="tests", -- 2.36.6 From 21d39c6c6646213caeaa595b9833dce0ffafbb33 Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 09:43:23 +0200 Subject: [PATCH 141/170] cmm: removed all the old code and modules --- chatmastermind/api_client.py | 45 ------- chatmastermind/main.py | 104 ++------------- chatmastermind/storage.py | 121 ------------------ chatmastermind/utils.py | 81 ------------ tests/test_main.py | 236 ----------------------------------- 5 files changed, 12 insertions(+), 575 deletions(-) delete mode 100644 chatmastermind/api_client.py delete mode 100644 chatmastermind/storage.py delete mode 100644 chatmastermind/utils.py delete mode 100644 tests/test_main.py diff --git a/chatmastermind/api_client.py b/chatmastermind/api_client.py deleted file mode 100644 index 2c4a094..0000000 --- a/chatmastermind/api_client.py +++ /dev/null @@ -1,45 +0,0 @@ -import openai - -from .utils import ChatType -from .configuration import Config - - -def openai_api_key(api_key: str) -> None: - openai.api_key = api_key - - -def print_models() -> None: - """ - Print all models supported by the current AI. - """ - not_ready = [] - for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): - if engine['ready']: - print(engine['id']) - else: - not_ready.append(engine['id']) - if len(not_ready) > 0: - print('\nNot ready: ' + ', '.join(not_ready)) - - -def ai(chat: ChatType, - config: Config, - number: int - ) -> tuple[list[str], dict[str, int]]: - """ - Make AI request with the given chat history and configuration. - Return AI response and tokens used. - """ - response = openai.ChatCompletion.create( - model=config.openai.model, - messages=chat, - temperature=config.openai.temperature, - max_tokens=config.openai.max_tokens, - top_p=config.openai.top_p, - n=number, - frequency_penalty=config.openai.frequency_penalty, - presence_penalty=config.openai.presence_penalty) - result = [] - for choice in response['choices']: # type: ignore - result.append(choice['message']['content'].strip()) - return result, dict(response['usage']) # type: ignore diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 67eafae..58ce9ed 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,61 +6,19 @@ import sys import argcomplete import argparse from pathlib import Path -from .utils import terminal_width, print_tag_args, print_chat_hist, ChatType -from .storage import save_answers, create_chat_hist -from .api_client import ai, openai_api_key, print_models -from .configuration import Config +from .configuration import Config, default_config_path from .chat import ChatDB from .message import Message, MessageFilter, MessageError, Question from .ai_factory import create_ai from .ai import AI, AIResponse -from itertools import zip_longest from typing import Any -default_config = '.config.yaml' - def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) return get_tags_unique(config, prefix) -def create_question_with_hist(args: argparse.Namespace, - config: Config, - ) -> tuple[ChatType, str, list[str]]: - """ - Creates the "AI request", including the question and chat history as determined - by the specified tags. - """ - tags = args.or_tags or [] - xtags = args.exclude_tags or [] - otags = args.output_tags or [] - - if not args.source_code_only: - print_tag_args(tags, xtags, otags) - - question_parts = [] - question_list = args.question if args.question is not None else [] - source_list = args.source if args.source is not None else [] - - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: - question_parts.append(question) - elif source is not None: - with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") - - full_question = '\n\n'.join(question_parts) - chat = create_chat_hist(full_question, tags, xtags, config, - match_all_tags=True if args.and_tags else False, # FIXME - with_tags=False, - with_file=False) - return chat, full_question, tags - - def tags_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'tags' command. @@ -74,17 +32,12 @@ def tags_cmd(args: argparse.Namespace, config: Config) -> None: # TODO: add renaming -def config_cmd(args: argparse.Namespace, config: Config) -> None: +def config_cmd(args: argparse.Namespace) -> None: """ Handler for the 'config' command. """ - if args.list_models: - print_models() - elif args.print_model: - print(config.openai.model) - elif args.model: - config.openai.model = args.model - config.to_file(args.config) + if args.create: + Config.create_default(Path(args.create)) def question_cmd(args: argparse.Namespace, config: Config) -> None: @@ -95,6 +48,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: db_path=Path(config.db)) # if it's a new question, create and store it immediately if args.ask or args.create: + # FIXME: add sources to the question message = Message(question=Question(args.question), tags=args.ouput_tags, # FIXME ai=args.ai, @@ -128,25 +82,6 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: pass -def ask_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'ask' command. - """ - if args.max_tokens: - config.openai.max_tokens = args.max_tokens - if args.temperature: - config.openai.temperature = args.temperature - if args.model: - config.openai.model = args.model - chat, question, tags = create_question_with_hist(args, config) - print_chat_hist(chat, False, args.source_code_only) - otags = args.output_tags or [] - answers, usage = ai(chat, config, args.num_answers) - save_answers(question, answers, tags, otags, config) - print("-" * terminal_width()) - print(f"Usage: {usage}") - - def hist_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'hist' command. @@ -190,7 +125,7 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: def create_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( description="ChatMastermind is a Python application that automates conversation with AI") - parser.add_argument('-C', '--config', help='Config file name.', default=default_config) + parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path) # subcommand-parser cmdparser = parser.add_subparsers(dest='command', @@ -235,22 +170,6 @@ def create_parser() -> argparse.ArgumentParser: question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') - # 'ask' command parser - ask_cmd_parser = cmdparser.add_parser('ask', parents=[tag_parser], - help="Ask a question.", - aliases=['a']) - ask_cmd_parser.set_defaults(func=ask_cmd) - ask_cmd_parser.add_argument('-q', '--question', nargs='+', help='Question to ask', - required=True) - ask_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - ask_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - ask_cmd_parser.add_argument('-M', '--model', help='Model to use') - ask_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) - ask_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - ask_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') - # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], help="Print chat history.", @@ -286,7 +205,7 @@ def create_parser() -> argparse.ArgumentParser: action='store_true') config_group.add_argument('-m', '--print-model', help="Print the currently configured model", action='store_true') - config_group.add_argument('-M', '--model', help="Set model in the config file") + config_group.add_argument('-c', '--create', help="Create config with default settings in the given file") # 'print' command parser print_cmd_parser = cmdparser.add_parser('print', @@ -315,11 +234,12 @@ def main() -> int: parser = create_parser() args = parser.parse_args() command = parser.parse_args() - config = Config.from_file(args.config) - openai_api_key(config.openai.api_key) - - command.func(command, config) + if command.func == config_cmd: + command.func(command) + else: + config = Config.from_file(args.config) + command.func(command, config) return 0 diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py deleted file mode 100644 index 8b9ed97..0000000 --- a/chatmastermind/storage.py +++ /dev/null @@ -1,121 +0,0 @@ -import yaml -import io -import pathlib -from .utils import terminal_width, append_message, message_to_chat, ChatType -from .configuration import Config -from typing import Any, Optional - - -def read_file(fname: pathlib.Path, tags_only: bool = False) -> dict[str, Any]: - with open(fname, "r") as fd: - tagline = fd.readline().strip().split(':', maxsplit=1)[1].strip() - # also support tags separated by ',' (old format) - separator = ',' if ',' in tagline else ' ' - tags = [t.strip() for t in tagline.split(separator)] - if tags_only: - return {"tags": tags} - text = fd.read().strip().split('\n') - question_idx = text.index("=== QUESTION ===") + 1 - answer_idx = text.index("==== ANSWER ====") - question = "\n".join(text[question_idx:answer_idx]).strip() - answer = "\n".join(text[answer_idx + 1:]).strip() - return {"question": question, "answer": answer, "tags": tags, - "file": fname.name} - - -def dump_data(data: dict[str, Any]) -> str: - with io.StringIO() as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - return fd.getvalue() - - -def write_file(fname: str, data: dict[str, Any]) -> None: - with open(fname, "w") as fd: - fd.write(f'TAGS: {" ".join(data["tags"])}\n') - fd.write(f'=== QUESTION ===\n{data["question"]}\n') - fd.write(f'==== ANSWER ====\n{data["answer"]}\n') - - -def save_answers(question: str, - answers: list[str], - tags: list[str], - otags: Optional[list[str]], - config: Config - ) -> None: - wtags = otags or tags - num, inum = 0, 0 - next_fname = pathlib.Path(str(config.db)) / '.next' - try: - with open(next_fname, 'r') as f: - num = int(f.read()) - except Exception: - pass - for answer in answers: - num += 1 - inum += 1 - title = f'-- ANSWER {inum} ' - title_end = '-' * (terminal_width() - len(title)) - print(f'{title}{title_end}') - print(answer) - write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags}) - with open(next_fname, 'w') as f: - f.write(f'{num}') - - -def create_chat_hist(question: Optional[str], - tags: Optional[list[str]], - extags: Optional[list[str]], - config: Config, - match_all_tags: bool = False, - with_tags: bool = False, - with_file: bool = False - ) -> ChatType: - chat: ChatType = [] - append_message(chat, 'system', str(config.system).strip()) - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - data['file'] = file.name - elif file.suffix == '.txt': - data = read_file(file) - else: - continue - data_tags = set(data.get('tags', [])) - tags_match: bool - if match_all_tags: - tags_match = not tags or set(tags).issubset(data_tags) - else: - tags_match = not tags or bool(data_tags.intersection(tags)) - extags_do_not_match = \ - not extags or not data_tags.intersection(extags) - if tags_match and extags_do_not_match: - message_to_chat(data, chat, with_tags, with_file) - if question: - append_message(chat, 'user', question) - return chat - - -def get_tags(config: Config, prefix: Optional[str]) -> list[str]: - result = [] - for file in sorted(pathlib.Path(str(config.db)).iterdir()): - if file.suffix == '.yaml': - with open(file, 'r') as f: - data = yaml.load(f, Loader=yaml.FullLoader) - elif file.suffix == '.txt': - data = read_file(file, tags_only=True) - else: - continue - for tag in data.get('tags', []): - if prefix and len(prefix) > 0: - if tag.startswith(prefix): - result.append(tag) - else: - result.append(tag) - return result - - -def get_tags_unique(config: Config, prefix: Optional[str]) -> list[str]: - return list(set(get_tags(config, prefix))) diff --git a/chatmastermind/utils.py b/chatmastermind/utils.py deleted file mode 100644 index 4135ae3..0000000 --- a/chatmastermind/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import shutil -from pprint import PrettyPrinter -from typing import Any - -ChatType = list[dict[str, str]] - - -def terminal_width() -> int: - return shutil.get_terminal_size().columns - - -def pp(*args: Any, **kwargs: Any) -> None: - return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) - - -def print_tag_args(tags: list[str], extags: list[str], otags: list[str]) -> None: - """ - Prints the tags specified in the given args. - """ - printed_messages = [] - - if tags: - printed_messages.append(f"Tags: {' '.join(tags)}") - if extags: - printed_messages.append(f"Excluding tags: {' '.join(extags)}") - if otags: - printed_messages.append(f"Output tags: {' '.join(otags)}") - - if printed_messages: - print("\n".join(printed_messages)) - print() - - -def append_message(chat: ChatType, - role: str, - content: str - ) -> None: - chat.append({'role': role, 'content': content.replace("''", "'")}) - - -def message_to_chat(message: dict[str, str], - chat: ChatType, - with_tags: bool = False, - with_file: bool = False - ) -> None: - append_message(chat, 'user', message['question']) - append_message(chat, 'assistant', message['answer']) - if with_tags: - tags = " ".join(message['tags']) - append_message(chat, 'tags', tags) - if with_file: - append_message(chat, 'file', message['file']) - - -def display_source_code(content: str) -> None: - try: - content_start = content.index('```') - content_start = content.index('\n', content_start) + 1 - content_end = content.rindex('```') - if content_start < content_end: - print(content[content_start:content_end].strip()) - except ValueError: - pass - - -def print_chat_hist(chat: ChatType, dump: bool = False, source_code: bool = False) -> None: - if dump: - pp(chat) - return - for message in chat: - text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2 - if source_code: - display_source_code(message['content']) - continue - if message['role'] == 'user': - print('-' * terminal_width()) - if text_too_long: - print(f"{message['role'].upper()}:") - print(message['content']) - else: - print(f"{message['role'].upper()}: {message['content']}") diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 91e6462..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,236 +0,0 @@ -# import unittest -# import io -# import pathlib -# import argparse -# from chatmastermind.utils import terminal_width -# from chatmastermind.main import create_parser, ask_cmd -# from chatmastermind.api_client import ai -# from chatmastermind.configuration import Config -# from chatmastermind.storage import create_chat_hist, save_answers, dump_data -# from unittest import mock -# from unittest.mock import patch, MagicMock, Mock, ANY - - -# class CmmTestCase(unittest.TestCase): -# """ -# Base class for all cmm testcases. -# """ -# def dummy_config(self, db: str) -> Config: -# """ -# Creates a dummy configuration. -# """ -# return Config.from_dict( -# {'system': 'dummy_system', -# 'db': db, -# 'openai': {'api_key': 'dummy_key', -# 'model': 'dummy_model', -# 'max_tokens': 4000, -# 'temperature': 1.0, -# 'top_p': 1, -# 'frequency_penalty': 0, -# 'presence_penalty': 0}} -# ) -# -# -# class TestCreateChat(CmmTestCase): -# -# def setUp(self) -> None: -# self.config = self.dummy_config(db='test_files') -# self.question = "test question" -# self.tags = ['test_tag'] -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['test_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 4) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_with_other_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.return_value.__enter__.return_value = io.StringIO(dump_data( -# {'question': 'test_content', 'answer': 'some answer', -# 'tags': ['other_tag']})) -# -# test_chat = create_chat_hist(self.question, self.tags, None, self.config) -# -# self.assertEqual(len(test_chat), 2) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': self.question}) -# -# @patch('os.listdir') -# @patch('pathlib.Path.iterdir') -# @patch('builtins.open') -# def test_create_chat_without_tags(self, open_mock: MagicMock, iterdir_mock: MagicMock, listdir_mock: MagicMock) -> None: -# listdir_mock.return_value = ['testfile.txt', 'testfile2.txt'] -# iterdir_mock.return_value = [pathlib.Path(x) for x in listdir_mock.return_value] -# open_mock.side_effect = ( -# io.StringIO(dump_data({'question': 'test_content', -# 'answer': 'some answer', -# 'tags': ['test_tag']})), -# io.StringIO(dump_data({'question': 'test_content2', -# 'answer': 'some answer2', -# 'tags': ['test_tag2']})), -# ) -# -# test_chat = create_chat_hist(self.question, [], None, self.config) -# -# self.assertEqual(len(test_chat), 6) -# self.assertEqual(test_chat[0], -# {'role': 'system', 'content': self.config.system}) -# self.assertEqual(test_chat[1], -# {'role': 'user', 'content': 'test_content'}) -# self.assertEqual(test_chat[2], -# {'role': 'assistant', 'content': 'some answer'}) -# self.assertEqual(test_chat[3], -# {'role': 'user', 'content': 'test_content2'}) -# self.assertEqual(test_chat[4], -# {'role': 'assistant', 'content': 'some answer2'}) -# -# -# class TestHandleQuestion(CmmTestCase): -# -# def setUp(self) -> None: -# self.question = "test question" -# self.args = argparse.Namespace( -# or_tags=['tag1'], -# and_tags=None, -# exclude_tags=['xtag1'], -# output_tags=None, -# question=[self.question], -# source=None, -# source_code_only=False, -# num_answers=3, -# max_tokens=None, -# temperature=None, -# model=None, -# match_all_tags=False, -# with_tags=False, -# with_file=False, -# ) -# self.config = self.dummy_config(db='test_files') -# -# @patch("chatmastermind.main.create_chat_hist", return_value="test_chat") -# @patch("chatmastermind.main.print_tag_args") -# @patch("chatmastermind.main.print_chat_hist") -# @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage")) -# @patch("chatmastermind.utils.pp") -# @patch("builtins.print") -# def test_ask_cmd(self, mock_print: MagicMock, mock_pp: MagicMock, mock_ai: MagicMock, -# mock_print_chat_hist: MagicMock, mock_print_tag_args: MagicMock, -# mock_create_chat_hist: MagicMock) -> None: -# open_mock = MagicMock() -# with patch("chatmastermind.storage.open", open_mock): -# ask_cmd(self.args, self.config) -# mock_print_tag_args.assert_called_once_with(self.args.or_tags, -# self.args.exclude_tags, -# []) -# mock_create_chat_hist.assert_called_once_with(self.question, -# self.args.or_tags, -# self.args.exclude_tags, -# self.config, -# match_all_tags=False, -# with_tags=False, -# with_file=False) -# mock_print_chat_hist.assert_called_once_with('test_chat', -# False, -# self.args.source_code_only) -# mock_ai.assert_called_with("test_chat", -# self.config, -# self.args.num_answers) -# expected_calls = [] -# for num, answer in enumerate(mock_ai.return_value[0], start=1): -# title = f'-- ANSWER {num} ' -# title_end = '-' * (terminal_width() - len(title)) -# expected_calls.append(((f'{title}{title_end}',),)) -# expected_calls.append(((answer,),)) -# expected_calls.append((("-" * terminal_width(),),)) -# expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) -# self.assertEqual(mock_print.call_args_list, expected_calls) -# open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)]) -# open_mock.assert_has_calls(open_expected_calls, any_order=True) -# -# -# class TestSaveAnswers(CmmTestCase): -# @mock.patch('builtins.open') -# @mock.patch('chatmastermind.storage.print') -# def test_save_answers(self, print_mock: MagicMock, open_mock: MagicMock) -> None: -# question = "Test question?" -# answers = ["Answer 1", "Answer 2"] -# tags = ["tag1", "tag2"] -# otags = ["otag1", "otag2"] -# config = self.dummy_config(db='test_db') -# -# with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ -# mock.patch('chatmastermind.storage.yaml.dump'), \ -# mock.patch('io.StringIO') as stringio_mock: -# stringio_instance = stringio_mock.return_value -# stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] -# save_answers(question, answers, tags, otags, config) -# -# open_calls = [ -# mock.call(pathlib.Path('test_db/.next'), 'r'), -# mock.call(pathlib.Path('test_db/.next'), 'w'), -# ] -# open_mock.assert_has_calls(open_calls, any_order=True) -# -# -# class TestAI(CmmTestCase): -# -# @patch("openai.ChatCompletion.create") -# def test_ai(self, mock_create: MagicMock) -> None: -# mock_create.return_value = { -# 'choices': [ -# {'message': {'content': 'response_text_1'}}, -# {'message': {'content': 'response_text_2'}} -# ], -# 'usage': {'tokens': 10} -# } -# -# chat = [{"role": "system", "content": "hello ai"}] -# config = self.dummy_config(db='dummy') -# config.openai.model = "text-davinci-002" -# config.openai.max_tokens = 150 -# config.openai.temperature = 0.5 -# -# result = ai(chat, config, 2) -# expected_result = (['response_text_1', 'response_text_2'], -# {'tokens': 10}) -# self.assertEqual(result, expected_result) -# -# -# class TestCreateParser(CmmTestCase): -# def test_create_parser(self) -> None: -# with patch('argparse.ArgumentParser.add_subparsers') as mock_add_subparsers: -# mock_cmdparser = Mock() -# mock_add_subparsers.return_value = mock_cmdparser -# parser = create_parser() -# self.assertIsInstance(parser, argparse.ArgumentParser) -# mock_add_subparsers.assert_called_once_with(dest='command', title='commands', description='supported commands', required=True) -# mock_cmdparser.add_parser.assert_any_call('ask', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('hist', parents=ANY, help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('tags', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('config', help=ANY, aliases=ANY) -# mock_cmdparser.add_parser.assert_any_call('print', help=ANY, aliases=ANY) -# self.assertTrue('.config.yaml' in parser.get_default('config')) -- 2.36.6 From 61e710a4b1d7b5570862714376ae6262b26dcb9f Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 13:31:01 +0200 Subject: [PATCH 142/170] cmm: splitted commands into separate modules (and more cleanup) --- chatmastermind/commands/config.py | 11 ++++++ chatmastermind/commands/hist.py | 23 ++++++++++++ chatmastermind/commands/print.py | 19 ++++++++++ chatmastermind/commands/question.py | 57 +++++++++++++++++++++++++++++ chatmastermind/commands/tags.py | 17 +++++++++ chatmastermind/main.py | 44 ++++++++++------------ setup.py | 2 +- tests/test_ai_factory.py | 48 ++++++++++++++++++++++++ 8 files changed, 196 insertions(+), 25 deletions(-) create mode 100644 chatmastermind/commands/config.py create mode 100644 chatmastermind/commands/hist.py create mode 100644 chatmastermind/commands/print.py create mode 100644 chatmastermind/commands/question.py create mode 100644 chatmastermind/commands/tags.py create mode 100644 tests/test_ai_factory.py diff --git a/chatmastermind/commands/config.py b/chatmastermind/commands/config.py new file mode 100644 index 0000000..262164c --- /dev/null +++ b/chatmastermind/commands/config.py @@ -0,0 +1,11 @@ +import argparse +from pathlib import Path +from ..configuration import Config + + +def config_cmd(args: argparse.Namespace) -> None: + """ + Handler for the 'config' command. + """ + if args.create: + Config.create_default(Path(args.create)) diff --git a/chatmastermind/commands/hist.py b/chatmastermind/commands/hist.py new file mode 100644 index 0000000..88ed3be --- /dev/null +++ b/chatmastermind/commands/hist.py @@ -0,0 +1,23 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import MessageFilter + + +def hist_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'hist' command. + """ + + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags, + question_contains=args.question, + answer_contains=args.answer) + chat = ChatDB.from_dir(Path('.'), + Path(config.db), + mfilter=mfilter) + chat.print(args.source_code_only, + args.with_tags, + args.with_files) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py new file mode 100644 index 0000000..51e76f8 --- /dev/null +++ b/chatmastermind/commands/print.py @@ -0,0 +1,19 @@ +import sys +import argparse +from pathlib import Path +from ..configuration import Config +from ..message import Message, MessageError + + +def print_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'print' command. + """ + fname = Path(args.file) + try: + message = Message.from_file(fname) + if message: + print(message.to_str(source_code_only=args.source_code_only)) + except MessageError: + print(f"File is not a valid message: {args.file}") + sys.exit(1) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py new file mode 100644 index 0000000..9c56ced --- /dev/null +++ b/chatmastermind/commands/question.py @@ -0,0 +1,57 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB +from ..message import Message, Question +from ..ai_factory import create_ai +from ..ai import AI, AIResponse + + +def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: + """ + Creates (and writes) a new message from the given arguments. + """ + # FIXME: add sources to the question + message = Message(question=Question(args.question), + tags=args.output_tags, # FIXME + ai=args.ai, + model=args.model) + chat.add_to_cache([message]) + return message + + +def question_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'question' command. + """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + # if it's a new question, create and store it immediately + if args.ask or args.create: + message = create_message(chat, args) + if args.create: + return + + # create the correct AI instance + ai: AI = create_ai(args, config) + if args.ask: + response: AIResponse = ai.request(message, + chat, + args.num_answers, # FIXME + args.otags) # FIXME + assert response + # TODO: + # * add answer to the message above (and create + # more messages for any additional answers) + pass + elif args.repeat: + lmessage = chat.latest_message() + assert lmessage + # TODO: repeat either the last question or the + # one(s) given in 'args.repeat' (overwrite + # existing ones if 'args.overwrite' is True) + pass + elif args.process: + # TODO: process either all questions without an + # answer or the one(s) given in 'args.process' + pass diff --git a/chatmastermind/commands/tags.py b/chatmastermind/commands/tags.py new file mode 100644 index 0000000..2906a5b --- /dev/null +++ b/chatmastermind/commands/tags.py @@ -0,0 +1,17 @@ +import argparse +from pathlib import Path +from ..configuration import Config +from ..chat import ChatDB + + +def tags_cmd(args: argparse.Namespace, config: Config) -> None: + """ + Handler for the 'tags' command. + """ + chat = ChatDB.from_dir(cache_path=Path('.'), + db_path=Path(config.db)) + if args.list: + tags_freq = chat.tags_frequency(args.prefix, args.contain) + for tag, freq in tags_freq.items(): + print(f"- {tag}: {freq}") + # TODO: add renaming diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 58ce9ed..02cdffd 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -6,12 +6,14 @@ import sys import argcomplete import argparse from pathlib import Path -from .configuration import Config, default_config_path -from .chat import ChatDB -from .message import Message, MessageFilter, MessageError, Question -from .ai_factory import create_ai -from .ai import AI, AIResponse from typing import Any +from .configuration import Config, default_config_path +from .message import Message +from .commands.question import question_cmd +from .commands.tags import tags_cmd +from .commands.config import config_cmd +from .commands.hist import hist_cmd +from .commands.print import print_cmd def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: @@ -136,20 +138,28 @@ def create_parser() -> argparse.ArgumentParser: # a parent parser for all commands that support tag selection tag_parser = argparse.ArgumentParser(add_help=False) tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+', - help='List of tag names (one must match)', metavar='OTAGS') + help='List of tags (one must match)', metavar='OTAGS') tag_arg.completer = tags_completer # type: ignore atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+', - help='List of tag names (all must match)', metavar='ATAGS') + help='List of tags (all must match)', metavar='ATAGS') atag_arg.completer = tags_completer # type: ignore etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+', - help='List of tag names to exclude', metavar='XTAGS') + help='List of tags to exclude', metavar='XTAGS') etag_arg.completer = tags_completer # type: ignore otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+', - help='List of output tag names, default is input', metavar='OUTTAGS') + help='List of output tags (default: use input tags)', metavar='OUTTAGS') otag_arg.completer = tags_completer # type: ignore + # a parent parser for all commands that support AI configuration + ai_parser = argparse.ArgumentParser(add_help=False) + ai_parser.add_argument('-A', '--AI', help='AI ID to use') + ai_parser.add_argument('-M', '--model', help='Model to use') + ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1) + ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int) + ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float) + # 'question' command parser - question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser], + question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser], help="ask, create and process questions.", aliases=['q']) question_cmd_parser.set_defaults(func=question_cmd) @@ -160,12 +170,6 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int) - question_cmd_parser.add_argument('-T', '--temperature', help='Temperature to use', type=float) - question_cmd_parser.add_argument('-A', '--AI', help='AI to use') - question_cmd_parser.add_argument('-M', '--model', help='Model to use') - question_cmd_parser.add_argument('-n', '--num-answers', help='Number of answers to produce', type=int, - default=1) question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', action='store_true') @@ -213,18 +217,10 @@ def create_parser() -> argparse.ArgumentParser: aliases=['p']) print_cmd_parser.set_defaults(func=print_cmd) print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True) -<<<<<<< HEAD print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group() print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true') print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true') print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true') -||||||| parent of bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code', - action='store_true') -======= - print_cmd_parser.add_argument('-S', '--source-code-only', help='Print source code only (from the answer, if available)', - action='store_true') ->>>>>>> bf1cbff (cmm: the 'print' command now uses 'Message.from_file()') argcomplete.autocomplete(parser) return parser diff --git a/setup.py b/setup.py index 8484629..a311605 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( long_description=long_description, long_description_content_type="text/markdown", url="https://github.com/ok2/ChatMastermind", - packages=find_packages() + ["chatmastermind.ais"], + packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"], classifiers=[ "Development Status :: 3 - Alpha", "Environment :: Console", diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py new file mode 100644 index 0000000..d63970e --- /dev/null +++ b/tests/test_ai_factory.py @@ -0,0 +1,48 @@ +import argparse +import unittest +from unittest.mock import MagicMock +from chatmastermind.ai_factory import create_ai +from chatmastermind.configuration import Config +from chatmastermind.ai import AIError +from chatmastermind.ais.openai import OpenAI + + +class TestCreateAI(unittest.TestCase): + def setUp(self) -> None: + self.args = MagicMock(spec=argparse.Namespace) + self.args.ai = 'default' + self.args.model = None + self.args.max_tokens = None + self.args.temperature = None + + def test_create_ai_from_args(self) -> None: + # Create an AI with the default configuration + config = Config() + self.args.ai = 'default' + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_ai_from_default(self) -> None: + self.args.ai = None + # Create an AI with the default configuration + config = Config() + ai = create_ai(self.args, config) + self.assertIsInstance(ai, OpenAI) + + def test_create_empty_ai_error(self) -> None: + self.args.ai = None + # Create Config with empty AIs + config = Config() + config.ais = {} + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) + + def test_create_unsupported_ai_error(self) -> None: + # Mock argparse.Namespace with ai='invalid_ai' + self.args.ai = 'invalid_ai' + # Create default Config + config = Config() + # Call create_ai function and assert that it raises AIError + with self.assertRaises(AIError): + create_ai(self.args, config) -- 2.36.6 From ecb699478335c1c054b8dd917762c967270dac5b Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 6 Sep 2023 22:52:03 +0200 Subject: [PATCH 143/170] configuration et al: implemented new Config format --- chatmastermind/ai.py | 13 ++-- chatmastermind/ai_factory.py | 29 ++++++-- chatmastermind/ais/openai.py | 9 +-- chatmastermind/configuration.py | 119 ++++++++++++++++++++++++++------ 4 files changed, 134 insertions(+), 36 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index 4a8b914..e94de8e 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -33,18 +33,23 @@ class AI(Protocol): The base class for AI clients. """ + ID: str name: str config: AIConfig def request(self, question: Message, - context: Chat, + chat: Chat, num_answers: int = 1, otags: Optional[set[Tag]] = None) -> AIResponse: """ - Make an AI request, asking the given question with the given - context (i. e. chat history). The nr. of requested answers - corresponds to the nr. of messages in the 'AIResponse'. + Make an AI request. Parameters: + * question: the question to ask + * chat: the chat history to be added as context + * num_answers: nr. of requested answers (corresponds + to the nr. of messages in the 'AIResponse') + * otags: the output tags, i. e. the tags that all + returned messages should contain """ raise NotImplementedError diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c90366b..c4a063a 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -3,18 +3,35 @@ Creates different AI instances, based on the given configuration. """ import argparse -from .configuration import Config +from typing import cast +from .configuration import Config, OpenAIConfig, default_ai_ID from .ai import AI, AIError from .ais.openai import OpenAI def create_ai(args: argparse.Namespace, config: Config) -> AI: """ - Creates an AI subclass instance from the given args and configuration. + Creates an AI subclass instance from the given arguments + and configuration file. """ - if args.ai == 'openai': - # FIXME: create actual 'OpenAIConfig' and set values from 'args' - # FIXME: use actual name from config - return OpenAI("openai", config.openai) + if args.ai: + try: + ai_conf = config.ais[args.ai] + except KeyError: + raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + elif default_ai_ID in config.ais: + ai_conf = config.ais[default_ai_ID] + else: + raise AIError("No AI name given and no default exists") + + if ai_conf.name == 'openai': + ai = OpenAI(cast(OpenAIConfig, ai_conf)) + if args.model: + ai.config.model = args.model + if args.max_tokens: + ai.config.max_tokens = args.max_tokens + if args.temperature: + ai.config.temperature = args.temperature + return ai else: raise AIError(f"AI '{args.ai}' is not supported") diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 74438b8..14ce33f 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -17,9 +17,11 @@ class OpenAI(AI): The OpenAI AI client. """ - def __init__(self, name: str, config: OpenAIConfig) -> None: - self.name = name + def __init__(self, config: OpenAIConfig) -> None: + self.ID = config.ID + self.name = config.name self.config = config + openai.api_key = config.api_key def request(self, question: Message, @@ -31,8 +33,7 @@ class OpenAI(AI): chat history. The nr. of requested answers corresponds to the nr. of messages in the 'AIResponse'. """ - # FIXME: use real 'system' message (store in OpenAIConfig) - oai_chat = self.openai_chat(chat, "system", question) + oai_chat = self.openai_chat(chat, self.config.system, question) response = openai.ChatCompletion.create( model=self.config.model, messages=oai_chat, diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 0780604..d82f913 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -1,17 +1,40 @@ import yaml -from typing import Type, TypeVar, Any -from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Type, TypeVar, Any, Optional, ClassVar +from dataclasses import dataclass, asdict, field ConfigInst = TypeVar('ConfigInst', bound='Config') +AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig') OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') +supported_ais: list[str] = ['openai'] +default_ai_ID: str = 'default' +default_config_path = '.config.yaml' + + +class ConfigError(Exception): + pass + + @dataclass class AIConfig: """ The base class of all AI configurations. """ - name: str + # the name of the AI the config class represents + # -> it's a class variable and thus not part of the + # dataclass constructor + name: ClassVar[str] + # a user-defined ID for an AI configuration entry + ID: str + + # the name must not be changed + def __setattr__(self, name: str, value: Any) -> None: + if name == 'name': + raise AttributeError("'{name}' is not allowed to be changed") + else: + super().__setattr__(name, value) @dataclass @@ -19,21 +42,27 @@ class OpenAIConfig(AIConfig): """ The OpenAI section of the configuration file. """ - api_key: str - model: str - temperature: float - max_tokens: int - top_p: float - frequency_penalty: float - presence_penalty: float + name: ClassVar[str] = 'openai' + + # all members have default values, so we can easily create + # a default configuration + ID: str = 'default' + api_key: str = '0123456789' + system: str = 'You are an assistant' + model: str = 'gpt-3.5-turbo-16k' + temperature: float = 1.0 + max_tokens: int = 4000 + top_p: float = 1.0 + frequency_penalty: float = 0.0 + presence_penalty: float = 0.0 @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: """ Create OpenAIConfig from a dict. """ - return cls( - name='OpenAI', + res = cls( + system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), @@ -42,6 +71,30 @@ class OpenAIConfig(AIConfig): frequency_penalty=float(source['frequency_penalty']), presence_penalty=float(source['presence_penalty']) ) + # overwrite default ID if provided + if 'ID' in source: + res.ID = source['ID'] + return res + + +def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig: + """ + Creates an AIConfig instance of the given name. + """ + if name.lower() == 'openai': + if conf_dict is None: + return OpenAIConfig() + else: + return OpenAIConfig.from_dict(conf_dict) + else: + raise ConfigError(f"AI '{name}' is not supported") + + +def create_default_ai_configs() -> dict[str, AIConfig]: + """ + Create a dict containing default configurations for all supported AIs. + """ + return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais} @dataclass @@ -49,30 +102,52 @@ class Config: """ The configuration file structure. """ - system: str - db: str - openai: OpenAIConfig + # all members have default values, so we can easily create + # a default configuration + db: str = './db/' + ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs) @classmethod def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst: """ - Create Config from a dict. + Create Config from a dict (with the same format as the config file). """ + # create the correct AI type instances + ais: dict[str, AIConfig] = {} + for ID, conf in source['ais'].items(): + # add the AI ID to the config (for easy internal access) + conf['ID'] = ID + ai_conf = ai_config_instance(conf['name'], conf) + ais[ID] = ai_conf return cls( - system=str(source['system']), db=str(source['db']), - openai=OpenAIConfig.from_dict(source['openai']) + ais=ais ) + @classmethod + def create_default(self, file_path: Path) -> None: + """ + Creates a default Config in the given file. + """ + conf = Config() + conf.to_file(file_path) + @classmethod def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst: with open(path, 'r') as f: source = yaml.load(f, Loader=yaml.FullLoader) return cls.from_dict(source) - def to_file(self, path: str) -> None: - with open(path, 'w') as f: - yaml.dump(asdict(self), f, sort_keys=False) + def to_file(self, file_path: Path) -> None: + # remove the AI name from the config (for a cleaner format) + data = self.as_dict() + for conf in data['ais'].values(): + del (conf['ID']) + with open(file_path, 'w') as f: + yaml.dump(data, f, sort_keys=False) def as_dict(self) -> dict[str, Any]: - return asdict(self) + res = asdict(self) + for ID, conf in res['ais'].items(): + conf.update({'name': self.ais[ID].name}) + return res -- 2.36.6 From c52713c833754290d931f174e0a8aa402e0fd58b Mon Sep 17 00:00:00 2001 From: juk0de Date: Fri, 8 Sep 2023 10:40:22 +0200 Subject: [PATCH 144/170] configuration: added tests --- chatmastermind/configuration.py | 2 +- tests/test_configuration.py | 160 ++++++++++++++++++++++++++++++++ 2 files changed, 161 insertions(+), 1 deletion(-) create mode 100644 tests/test_configuration.py diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index d82f913..398fa03 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -87,7 +87,7 @@ def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> else: return OpenAIConfig.from_dict(conf_dict) else: - raise ConfigError(f"AI '{name}' is not supported") + raise ConfigError(f"Unknown AI '{name}'") def create_default_ai_configs() -> dict[str, AIConfig]: diff --git a/tests/test_configuration.py b/tests/test_configuration.py new file mode 100644 index 0000000..f3f9a98 --- /dev/null +++ b/tests/test_configuration.py @@ -0,0 +1,160 @@ +import os +import unittest +import yaml +from tempfile import NamedTemporaryFile +from pathlib import Path +from typing import cast +from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config + + +class TestAIConfigInstance(unittest.TestCase): + def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None: + ai_config = cast(OpenAIConfig, ai_config_instance('openai')) + ai_reference = OpenAIConfig() + self.assertEqual(ai_config.ID, ai_reference.ID) + self.assertEqual(ai_config.name, ai_reference.name) + self.assertEqual(ai_config.api_key, ai_reference.api_key) + self.assertEqual(ai_config.system, ai_reference.system) + self.assertEqual(ai_config.model, ai_reference.model) + self.assertEqual(ai_config.temperature, ai_reference.temperature) + self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens) + self.assertEqual(ai_config.top_p, ai_reference.top_p) + self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty) + self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty) + + def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None: + conf_dict = { + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict)) + self.assertEqual(ai_config.system, 'Custom system') + self.assertEqual(ai_config.api_key, '9876543210') + self.assertEqual(ai_config.model, 'custom_model') + self.assertEqual(ai_config.max_tokens, 5000) + self.assertAlmostEqual(ai_config.temperature, 0.5) + self.assertAlmostEqual(ai_config.top_p, 0.8) + self.assertAlmostEqual(ai_config.frequency_penalty, 0.7) + self.assertAlmostEqual(ai_config.presence_penalty, 0.2) + + def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None: + with self.assertRaises(ConfigError): + ai_config_instance('invalid_name') + + +class TestConfig(unittest.TestCase): + def setUp(self) -> None: + self.test_file = NamedTemporaryFile(delete=False) + + def tearDown(self) -> None: + os.remove(self.test_file.name) + + def test_from_dict_should_create_config_from_dict(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + config = Config.from_dict(source_dict) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertEqual(config.ais['default'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + # check that 'ID' has been added + self.assertEqual(config.ais['default'].ID, 'default') + + def test_create_default_should_create_default_config(self) -> None: + Config.create_default(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + default_config = yaml.load(f, Loader=yaml.FullLoader) + config_reference = Config() + self.assertEqual(default_config['db'], config_reference.db) + + def test_from_file_should_load_config_from_file(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'openai', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + config = Config.from_file(self.test_file.name) + self.assertIsInstance(config, Config) + self.assertEqual(config.db, './test_db/') + self.assertEqual(len(config.ais), 1) + self.assertIsInstance(config.ais['default'], AIConfig) + self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + + def test_to_file_should_save_config_to_file(self) -> None: + config = Config( + db='./test_db/', + ais={ + 'default': OpenAIConfig( + ID='default', + system='Custom system', + api_key='9876543210', + model='custom_model', + max_tokens=5000, + temperature=0.5, + top_p=0.8, + frequency_penalty=0.7, + presence_penalty=0.2 + ) + } + ) + config.to_file(Path(self.test_file.name)) + with open(self.test_file.name, 'r') as f: + saved_config = yaml.load(f, Loader=yaml.FullLoader) + self.assertEqual(saved_config['db'], './test_db/') + self.assertEqual(len(saved_config['ais']), 1) + self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + + def test_from_file_error_unknown_ai(self) -> None: + source_dict = { + 'db': './test_db/', + 'ais': { + 'default': { + 'name': 'foobla', + 'system': 'Custom system', + 'api_key': '9876543210', + 'model': 'custom_model', + 'max_tokens': 5000, + 'temperature': 0.5, + 'top_p': 0.8, + 'frequency_penalty': 0.7, + 'presence_penalty': 0.2 + } + } + } + with open(self.test_file.name, 'w') as f: + yaml.dump(source_dict, f) + with self.assertRaises(ConfigError): + Config.from_file(self.test_file.name) -- 2.36.6 From c4f7bcc94e87811a5788f0eea65e0e719c29e68b Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:51:17 +0200 Subject: [PATCH 145/170] question_cmd: fixes --- chatmastermind/commands/question.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 9c56ced..1709a3c 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -1,5 +1,6 @@ import argparse from pathlib import Path +from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB from ..message import Message, Question @@ -11,8 +12,26 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Creates (and writes) a new message from the given arguments. """ - # FIXME: add sources to the question - message = Message(question=Question(args.question), + question_parts = [] + question_list = args.question if args.question is not None else [] + source_list = args.source if args.source is not None else [] + + # FIXME: don't surround all sourced files with ``` + # -> do it only if '--source-code-only' is True and no source code + # could be extracted from that file + for question, source in zip_longest(question_list, source_list, fillvalue=None): + if question is not None and source is not None: + with open(source) as r: + question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") + elif question is not None: + question_parts.append(question) + elif source is not None: + with open(source) as r: + question_parts.append(f"```\n{r.read().strip()}\n```") + + full_question = '\n\n'.join(question_parts) + + message = Message(question=Question(full_question), tags=args.output_tags, # FIXME ai=args.ai, model=args.model) -- 2.36.6 From 3eca53998b674d0cc6a218c69170b2f60c110355 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 08:31:30 +0200 Subject: [PATCH 146/170] question cmd: added tests --- tests/test_question_cmd.py | 111 +++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/test_question_cmd.py diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py new file mode 100644 index 0000000..96b2fdf --- /dev/null +++ b/tests/test_question_cmd.py @@ -0,0 +1,111 @@ +import os +import unittest +import argparse +import tempfile +from pathlib import Path +from unittest.mock import MagicMock +from chatmastermind.commands.question import create_message +from chatmastermind.message import Message, Question +from chatmastermind.chat import ChatDB + + +class TestMessageCreate(unittest.TestCase): + """ + Test if messages created by the 'question' command have + the correct format. + """ + def setUp(self) -> None: + # create ChatDB structure + self.db_path = tempfile.TemporaryDirectory() + self.cache_path = tempfile.TemporaryDirectory() + self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), + db_path=Path(self.db_path.name)) + # create arguments mock + self.args = MagicMock(spec=argparse.Namespace) + self.args.source = None + self.args.source_code_only = False + self.args.ai = None + self.args.model = None + self.args.output_tags = None + # create some files for sourcing + self.source_file1 = tempfile.NamedTemporaryFile(delete=False) + self.source_file1_content = """This is just text. +No source code. +Nope. Go look elsewhere!""" + with open(self.source_file1.name, 'w') as f: + f.write(self.source_file1_content) + self.source_file2 = tempfile.NamedTemporaryFile(delete=False) + self.source_file2_content = """This is just text. +``` +This is embedded source code. +``` +And some text again.""" + with open(self.source_file2.name, 'w') as f: + f.write(self.source_file2_content) + self.source_file3 = tempfile.NamedTemporaryFile(delete=False) + self.source_file3_content = """This is all source code. +Yes, really. +Language is called 'brainfart'.""" + with open(self.source_file3.name, 'w') as f: + f.write(self.source_file3_content) + + def tearDown(self) -> None: + os.remove(self.source_file1.name) + os.remove(self.source_file2.name) + os.remove(self.source_file3.name) + + def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]: + # exclude '.next' + return list(Path(tmp_dir.name).glob('*.[ty]*')) + + def test_message_file_created(self) -> None: + self.args.question = ["What is this?"] + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + create_message(self.chat, self.args) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 1) + message = Message.from_file(cache_dir_files[0]) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] + + def test_single_question(self) -> None: + self.args.question = ["What is this?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("What is this?")) + self.assertEqual(len(message.question.source_code()), 0) + + def test_multipart_question(self) -> None: + self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + self.assertEqual(message.question, Question("""What is this + +'bard' thing? + +Is it good?""")) + + def test_single_question_with_text_only_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file1.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains no source code + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file1_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +{self.source_file2_content}""")) -- 2.36.6 From 86eebc39eafa1fcdfa66ac3eec7aa2c1049c9582 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:16:17 +0200 Subject: [PATCH 147/170] Allow in question -s for just sourcing file and -S to source file with ``` encapsulation. --- chatmastermind/commands/question.py | 22 ++++++++++++---------- chatmastermind/main.py | 5 ++--- tests/test_question_cmd.py | 22 ++++++++++++++++++---- 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 1709a3c..818b1de 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -15,19 +15,21 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: question_parts = [] question_list = args.question if args.question is not None else [] source_list = args.source if args.source is not None else [] + code_list = args.source_code if args.source_code is not None else [] - # FIXME: don't surround all sourced files with ``` - # -> do it only if '--source-code-only' is True and no source code - # could be extracted from that file - for question, source in zip_longest(question_list, source_list, fillvalue=None): - if question is not None and source is not None: - with open(source) as r: - question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```") - elif question is not None: + for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + if question is not None and len(question.strip()) > 0: question_parts.append(question) - elif source is not None: + if source is not None and len(source) > 0: with open(source) as r: - question_parts.append(f"```\n{r.read().strip()}\n```") + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + if code is not None and len(code) > 0: + with open(code) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 02cdffd..46bad44 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -170,9 +170,8 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Source add content of a file to the query') - question_cmd_parser.add_argument('-S', '--source-code-only', help='Add pure source code to the chat history', - action='store_true') + question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser], diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 96b2fdf..06cc527 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -23,7 +23,7 @@ class TestMessageCreate(unittest.TestCase): # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source = None - self.args.source_code_only = False + self.args.source_code = None self.args.ai = None self.args.model = None self.args.output_tags = None @@ -94,11 +94,11 @@ Is it good?""")) # source file contains no source code # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_embedded_source_source(self) -> None: self.args.question = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) @@ -106,6 +106,20 @@ Is it good?""")) # source file contains 1 source code block # -> expect it in the question self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question("""What is this? + self.assertEqual(message.question, Question(f"""What is this? {self.source_file2_content}""")) + + def test_single_question_with_embedded_source_code_source(self) -> None: + self.args.question = ["What is this?"] + self.args.source_code = [f"{self.source_file2.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # source file contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question(f"""What is this? + +``` +{self.source_file2_content} +```""")) -- 2.36.6 From 54ece6efeb23f36fa6ffc156ed5dc3d97ec83752 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 15:38:40 +0200 Subject: [PATCH 148/170] Port print arguments -q/-a/-S from main to restructuring. --- chatmastermind/commands/print.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/chatmastermind/commands/print.py b/chatmastermind/commands/print.py index 51e76f8..3d2b990 100644 --- a/chatmastermind/commands/print.py +++ b/chatmastermind/commands/print.py @@ -13,7 +13,15 @@ def print_cmd(args: argparse.Namespace, config: Config) -> None: try: message = Message.from_file(fname) if message: - print(message.to_str(source_code_only=args.source_code_only)) + if args.question: + print(message.question) + elif args.answer: + print(message.answer) + elif message.answer and args.only_source_code: + for code in message.answer.source_code(): + print(code) + else: + print(message.to_str()) except MessageError: print(f"File is not a valid message: {args.file}") sys.exit(1) -- 2.36.6 From 6f3ea9842564f26b86eb7235179962c97c9999b0 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 16:05:27 +0200 Subject: [PATCH 149/170] Small fixes. --- chatmastermind/ai_factory.py | 8 ++++---- chatmastermind/commands/question.py | 6 +++--- tests/test_ai_factory.py | 10 +++++----- tests/test_question_cmd.py | 14 +++++++------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index c4a063a..bc4583c 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -14,11 +14,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: Creates an AI subclass instance from the given arguments and configuration file. """ - if args.ai: + if args.AI: try: - ai_conf = config.ais[args.ai] + ai_conf = config.ais[args.AI] except KeyError: - raise AIError(f"AI ID '{args.ai}' does not exist in this configuration") + raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") elif default_ai_ID in config.ais: ai_conf = config.ais[default_ai_ID] else: @@ -34,4 +34,4 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: ai.config.temperature = args.temperature return ai else: - raise AIError(f"AI '{args.ai}' is not supported") + raise AIError(f"AI '{args.AI}' is not supported") diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 818b1de..90b782b 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -13,7 +13,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: Creates (and writes) a new message from the given arguments. """ question_parts = [] - question_list = args.question if args.question is not None else [] + question_list = args.ask if args.ask is not None else [] source_list = args.source if args.source is not None else [] code_list = args.source_code if args.source_code is not None else [] @@ -35,7 +35,7 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: message = Message(question=Question(full_question), tags=args.output_tags, # FIXME - ai=args.ai, + ai=args.AI, model=args.model) chat.add_to_cache([message]) return message @@ -59,7 +59,7 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME - args.otags) # FIXME + args.output_tags) # FIXME assert response # TODO: # * add answer to the message above (and create diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d63970e..d00b319 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.ai = 'default' + self.args.AI = 'default' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,19 +18,19 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.ai = 'default' + self.args.AI = 'default' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_ai_from_default(self) -> None: - self.args.ai = None + self.args.AI = None # Create an AI with the default configuration config = Config() ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) def test_create_empty_ai_error(self) -> None: - self.args.ai = None + self.args.AI = None # Create Config with empty AIs config = Config() config.ais = {} @@ -40,7 +40,7 @@ class TestCreateAI(unittest.TestCase): def test_create_unsupported_ai_error(self) -> None: # Mock argparse.Namespace with ai='invalid_ai' - self.args.ai = 'invalid_ai' + self.args.AI = 'invalid_ai' # Create default Config config = Config() # Call create_ai function and assert that it raises AIError diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 06cc527..aa0dc25 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -24,7 +24,7 @@ class TestMessageCreate(unittest.TestCase): self.args = MagicMock(spec=argparse.Namespace) self.args.source = None self.args.source_code = None - self.args.ai = None + self.args.AI = None self.args.model = None self.args.output_tags = None # create some files for sourcing @@ -59,7 +59,7 @@ Language is called 'brainfart'.""" return list(Path(tmp_dir.name).glob('*.[ty]*')) def test_message_file_created(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 0) create_message(self.chat, self.args) @@ -70,14 +70,14 @@ Language is called 'brainfart'.""" self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr] def test_single_question(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("What is this?")) self.assertEqual(len(message.question.source_code()), 0) def test_multipart_question(self) -> None: - self.args.question = ["What is this", "'bard' thing?", "Is it good?"] + self.args.ask = ["What is this", "'bard' thing?", "Is it good?"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) self.assertEqual(message.question, Question("""What is this @@ -87,7 +87,7 @@ Language is called 'brainfart'.""" Is it good?""")) def test_single_question_with_text_only_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -99,7 +99,7 @@ Is it good?""")) {self.source_file1_content}""")) def test_single_question_with_embedded_source_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) @@ -111,7 +111,7 @@ Is it good?""")) {self.source_file2_content}""")) def test_single_question_with_embedded_source_code_source(self) -> None: - self.args.question = ["What is this?"] + self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) -- 2.36.6 From f99cd3ed41b404d6e6197cc0239825e5177a2d10 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sat, 9 Sep 2023 18:28:10 +0200 Subject: [PATCH 150/170] question_cmd: fixed source code extraction and added a testcase --- chatmastermind/commands/question.py | 17 +++++-- chatmastermind/main.py | 2 +- chatmastermind/message.py | 2 +- tests/test_question_cmd.py | 79 +++++++++++++++++++++-------- 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 90b782b..756a051 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question +from ..message import Message, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -14,10 +14,10 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ question_parts = [] question_list = args.ask if args.ask is not None else [] - source_list = args.source if args.source is not None else [] - code_list = args.source_code if args.source_code is not None else [] + text_files = args.source_text if args.source_text is not None else [] + code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, source_list, code_list, fillvalue=None): + for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) if source is not None and len(source) > 0: @@ -28,7 +28,14 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: if code is not None and len(code) > 0: with open(code) as r: content = r.read().strip() - if len(content) > 0: + if len(content) == 0: + continue + # try to extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + # if there's none, add the whole file + else: question_parts.append(f"```\n{content}\n```") full_question = '\n\n'.join(question_parts) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 46bad44..1a375d0 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -170,7 +170,7 @@ def create_parser() -> argparse.ArgumentParser: question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions') question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them', action='store_true') - question_cmd_parser.add_argument('-s', '--source', nargs='+', help='Add content of a file to the query') + question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query') question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history') # 'hist' command parser diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 35de3b9..7107c13 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -414,7 +414,7 @@ class Message(): return '\n'.join(output) def __str__(self) -> str: - return self.to_str(False, False, False) + return self.to_str(True, True, False) def to_file(self, file_path: Optional[pathlib.Path]=None) -> None: # noqa: 11 """ diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index aa0dc25..40ea4d8 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -22,18 +22,19 @@ class TestMessageCreate(unittest.TestCase): db_path=Path(self.db_path.name)) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) - self.args.source = None + self.args.source_text = None self.args.source_code = None self.args.AI = None self.args.model = None self.args.output_tags = None - # create some files for sourcing + # File 1 : no source code block, only text self.source_file1 = tempfile.NamedTemporaryFile(delete=False) self.source_file1_content = """This is just text. No source code. Nope. Go look elsewhere!""" with open(self.source_file1.name, 'w') as f: f.write(self.source_file1_content) + # File 2 : one embedded source code block self.source_file2 = tempfile.NamedTemporaryFile(delete=False) self.source_file2_content = """This is just text. ``` @@ -42,12 +43,26 @@ This is embedded source code. And some text again.""" with open(self.source_file2.name, 'w') as f: f.write(self.source_file2_content) + # File 3 : all source code self.source_file3 = tempfile.NamedTemporaryFile(delete=False) self.source_file3_content = """This is all source code. Yes, really. Language is called 'brainfart'.""" with open(self.source_file3.name, 'w') as f: f.write(self.source_file3_content) + # File 4 : two source code blocks + self.source_file4 = tempfile.NamedTemporaryFile(delete=False) + self.source_file4_content = """This is just text. +``` +This is embedded source code. +``` +And some text again. +``` +This is embedded source code. +``` +Aaaand again some text.""" + with open(self.source_file4.name, 'w') as f: + f.write(self.source_file4_content) def tearDown(self) -> None: os.remove(self.source_file1.name) @@ -86,40 +101,62 @@ Language is called 'brainfart'.""" Is it good?""")) - def test_single_question_with_text_only_source(self) -> None: + def test_single_question_with_text_only_file(self) -> None: self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file1.name}"] + self.args.source_text = [f"{self.source_file1.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains no source code + # file contains no source code (only text) # -> don't expect any in the question self.assertEqual(len(message.question.source_code()), 0) self.assertEqual(message.question, Question(f"""What is this? {self.source_file1_content}""")) - def test_single_question_with_embedded_source_source(self) -> None: - self.args.ask = ["What is this?"] - self.args.source = [f"{self.source_file2.name}"] - message = create_message(self.chat, self.args) - self.assertIsInstance(message, Message) - # source file contains 1 source code block - # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 1) - self.assertEqual(message.question, Question(f"""What is this? - -{self.source_file2_content}""")) - - def test_single_question_with_embedded_source_code_source(self) -> None: + def test_single_question_with_text_file_and_embedded_code(self) -> None: self.args.ask = ["What is this?"] self.args.source_code = [f"{self.source_file2.name}"] message = create_message(self.chat, self.args) self.assertIsInstance(message, Message) - # source file contains 1 source code block + # file contains 1 source code block # -> expect it in the question - self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` +""")) + + def test_single_question_with_code_only_file(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file3.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file is complete source code + self.assertEqual(len(message.question.source_code()), 1) self.assertEqual(message.question, Question(f"""What is this? ``` -{self.source_file2_content} +{self.source_file3_content} ```""")) + + def test_single_question_with_text_file_and_multi_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.source_file4.name}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains 2 source code blocks + # -> expect them in the question + self.assertEqual(len(message.question.source_code()), 2) + self.assertEqual(message.question, Question("""What is this? + +``` +This is embedded source code. +``` + + +``` +This is embedded source code. +``` +""")) -- 2.36.6 From cc76da2ab36ae3cef44bd203018656d3a39501d0 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:39:00 +0200 Subject: [PATCH 151/170] chat: added 'update_messages()' function and test --- chatmastermind/chat.py | 16 ++++++++++++++++ tests/test_chat.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 4e8fb20..ddabb56 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -386,3 +386,19 @@ class ChatDB(Chat): msgs = iter(messages if messages else self.messages) while (m := next(msgs, None)): m.to_file() + + def update_messages(self, messages: list[Message], write: bool = True) -> None: + """ + Update existing messages. A message is determined as 'existing' if a message with + the same base filename (i. e. 'file_path.name') is already in the list. Only accepts + existing messages. + """ + if any(not message_in(m, self.messages) for m in messages): + raise ChatError("Can't update messages that are not in the internal list") + # remove old versions and add new ones + self.messages = [m for m in self.messages if not message_in(m, messages)] + self.messages += messages + self.sort() + # write the UPDATED messages if requested + if write: + self.write_messages(messages) diff --git a/tests/test_chat.py b/tests/test_chat.py index 8e4aa8c..ed630a4 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -440,3 +440,31 @@ class TestChatDB(unittest.TestCase): cache_dir_files = self.message_list(self.cache_path) self.assertEqual(len(cache_dir_files), 1) self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files) + + def test_chat_db_update_messages(self) -> None: + # create a new ChatDB instance + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name)) + + db_dir_files = self.message_list(self.db_path) + self.assertEqual(len(db_dir_files), 4) + cache_dir_files = self.message_list(self.cache_path) + self.assertEqual(len(cache_dir_files), 0) + + message = chat_db.messages[0] + message.answer = Answer("New answer") + # update message without writing + chat_db.update_messages([message], write=False) + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # re-read the message and check for old content + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1")) + # now check with writing (message should be overwritten) + chat_db.update_messages([message], write=True) + chat_db.read_db() + self.assertEqual(chat_db.messages[0].answer, Answer("New answer")) + # test without file_path -> expect error + message1 = Message(question=Question("Question 1"), + answer=Answer("Answer 1")) + with self.assertRaises(ChatError): + chat_db.update_messages([message1]) -- 2.36.6 From 864ab7aeb1c2980145b258edb4b8baf76dbcd3bf Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:18:14 +0200 Subject: [PATCH 152/170] chat: added check for existing files when creating new filenames --- chatmastermind/chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index ddabb56..7c4dd35 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -62,7 +62,10 @@ def make_file_path(dir_path: Path, Create a file_path for the given directory using the given file_suffix and ID generator function. """ - return dir_path / f"{next_fid():04d}{file_suffix}" + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + while file_path.exists(): + file_path = dir_path / f"{next_fid():04d}{file_suffix}" + return file_path def write_dir(dir_path: Path, -- 2.36.6 From faac42d3c277b06af0b5f36bac18189f31a410cb Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:52:07 +0200 Subject: [PATCH 153/170] question_cmd: fixed '--ask' command --- chatmastermind/ai.py | 6 ++++++ chatmastermind/ais/openai.py | 19 ++++++++++++++----- chatmastermind/commands/question.py | 15 ++++++++++----- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index e94de8e..b97b5f1 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -66,3 +66,9 @@ class AI(Protocol): and is not implemented for all AIs. """ raise NotImplementedError + + def print(self) -> None: + """ + Print some info about the current AI, like system message. + """ + pass diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 14ce33f..1db4d20 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -43,16 +43,20 @@ class OpenAI(AI): n=num_answers, frequency_penalty=self.config.frequency_penalty, presence_penalty=self.config.presence_penalty) - answers: list[Message] = [] - for choice in response['choices']: # type: ignore + question.answer = Answer(response['choices'][0]['message']['content']) + question.tags = otags + question.ai = self.name + question.model = self.config.model + answers: list[Message] = [question] + for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, ai=self.name, model=self.config.model)) - return AIResponse(answers, Tokens(response['usage']['prompt'], - response['usage']['completion'], - response['usage']['total'])) + return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], + response['usage']['completion_tokens'], + response['usage']['total_tokens'])) def models(self) -> list[str]: """ @@ -95,3 +99,8 @@ class OpenAI(AI): def tokens(self, data: Union[Message, Chat]) -> int: raise NotImplementedError + + def print(self) -> None: + print(f"MODEL: {self.config.model}") + print("=== SYSTEM ===") + print(self.config.system) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 756a051..fdabd62 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -63,15 +63,20 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: # create the correct AI instance ai: AI = create_ai(args, config) if args.ask: + ai.print() + chat.print(paged=False) response: AIResponse = ai.request(message, chat, args.num_answers, # FIXME args.output_tags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass + chat.update_messages([response.messages[0]]) + chat.add_to_cache(response.messages[1:]) + for idx, msg in enumerate(response.messages): + print(f"=== ANSWER {idx+1} ===") + print(msg.answer) + if response.tokens: + print("===============") + print(response.tokens) elif args.repeat: lmessage = chat.latest_message() assert lmessage -- 2.36.6 From 595ff8e294c945db38effda65d6445668db99e74 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:54:17 +0200 Subject: [PATCH 154/170] question_cmd: added message filtering by tags --- chatmastermind/commands/question.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index fdabd62..f439447 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,7 +3,7 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, Question, source_code +from ..message import Message, MessageFilter, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse @@ -52,8 +52,12 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. """ + mfilter = MessageFilter(tags_or=args.or_tags, + tags_and=args.and_tags, + tags_not=args.exclude_tags) chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) + db_path=Path(config.db), + mfilter=mfilter) # if it's a new question, create and store it immediately if args.ask or args.create: message = create_message(chat, args) @@ -77,14 +81,14 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: if response.tokens: print("===============") print(response.tokens) - elif args.repeat: + elif args.repeat is not None: lmessage = chat.latest_message() assert lmessage # TODO: repeat either the last question or the # one(s) given in 'args.repeat' (overwrite # existing ones if 'args.overwrite' is True) pass - elif args.process: + elif args.process is not None: # TODO: process either all questions without an # answer or the one(s) given in 'args.process' pass -- 2.36.6 From 2e08ccf6060ebac3f66cc2d3e0d0c45ee9e5c3e2 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 07:55:47 +0200 Subject: [PATCH 155/170] openai: stores AI.ID instead of AI.name in message --- chatmastermind/ais/openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index 1db4d20..a388a7a 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -45,14 +45,14 @@ class OpenAI(AI): presence_penalty=self.config.presence_penalty) question.answer = Answer(response['choices'][0]['message']['content']) question.tags = otags - question.ai = self.name + question.ai = self.ID question.model = self.config.model answers: list[Message] = [question] for choice in response['choices'][1:]: # type: ignore answers.append(Message(question=question.question, answer=Answer(choice['message']['content']), tags=otags, - ai=self.name, + ai=self.ID, model=self.config.model)) return AIResponse(answers, Tokens(response['usage']['prompt_tokens'], response['usage']['completion_tokens'], -- 2.36.6 From 66908f5fed330f625250c35ae12c4ba970d83daf Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:24:20 +0200 Subject: [PATCH 156/170] message: fixed matching with empty tag sets --- chatmastermind/message.py | 4 ++-- tests/test_chat.py | 22 ++++++++++++++++++++-- tests/test_message.py | 6 ++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index 7107c13..df59ed6 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -312,7 +312,7 @@ class Message(): mfilter.tags_not if mfilter else None) else: message = cls.__from_file_yaml(file_path) - if message and (not mfilter or (mfilter and message.match(mfilter))): + if message and (mfilter is None or message.match(mfilter)): return message else: return None @@ -508,7 +508,7 @@ class Message(): Return True if all attributes match, else False. """ mytags = self.tags or set() - if (((mfilter.tags_or or mfilter.tags_and or mfilter.tags_not) + if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None) and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503 or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503 or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503 diff --git a/tests/test_chat.py b/tests/test_chat.py index ed630a4..1916a2b 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -202,7 +202,25 @@ class TestChatDB(unittest.TestCase): self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0003.txt')) - def test_chat_db_filter(self) -> None: + def test_chat_db_from_dir_filter_tags(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or={Tag('tag1')})) + self.assertEqual(len(chat_db.messages), 1) + self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name)) + self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name)) + self.assertEqual(chat_db.messages[0].file_path, + pathlib.Path(self.db_path.name, '0001.txt')) + + def test_chat_db_from_dir_filter_tags_empty(self) -> None: + chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), + pathlib.Path(self.db_path.name), + mfilter=MessageFilter(tags_or=set(), + tags_and=set(), + tags_not=set())) + self.assertEqual(len(chat_db.messages), 0) + + def test_chat_db_from_dir_filter_answer(self) -> None: chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), mfilter=MessageFilter(answer_contains='Answer 2')) @@ -213,7 +231,7 @@ class TestChatDB(unittest.TestCase): pathlib.Path(self.db_path.name, '0002.yaml')) self.assertEqual(chat_db.messages[0].answer, 'Answer 2') - def test_chat_db_from_messges(self) -> None: + def test_chat_db_from_messages(self) -> None: chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name), pathlib.Path(self.db_path.name), messages=[self.message1, self.message2, diff --git a/tests/test_message.py b/tests/test_message.py index 57d5982..1f440df 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -300,6 +300,12 @@ This is a question. MessageFilter(tags_or={Tag('tag1')})) self.assertIsNone(message) + def test_from_file_txt_empty_tags_dont_match(self) -> None: + message = Message.from_file(self.file_path_min, + MessageFilter(tags_or=set(), + tags_and=set())) + self.assertIsNone(message) + def test_from_file_txt_no_tags_match_tags_not(self) -> None: message = Message.from_file(self.file_path_min, MessageFilter(tags_not={Tag('tag1')})) -- 2.36.6 From b840ebd7923b6e31f2f7070c0d10d74bd2343a51 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 19:56:50 +0200 Subject: [PATCH 157/170] message: to_file() now uses intermediate temporary file --- chatmastermind/message.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/chatmastermind/message.py b/chatmastermind/message.py index df59ed6..64929a3 100644 --- a/chatmastermind/message.py +++ b/chatmastermind/message.py @@ -3,6 +3,8 @@ Module implementing message related functions and classes. """ import pathlib import yaml +import tempfile +import shutil from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable from dataclasses import dataclass, asdict, field from .tags import Tag, TagLine, TagError, match_tags, rename_tags @@ -445,16 +447,18 @@ class Message(): * Answer.txt_header * Answer """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) if self.tags: - fd.write(f'{TagLine.from_set(self.tags)}\n') + temp_fd.write(f'{TagLine.from_set(self.tags)}\n') if self.ai: - fd.write(f'{AILine.from_ai(self.ai)}\n') + temp_fd.write(f'{AILine.from_ai(self.ai)}\n') if self.model: - fd.write(f'{ModelLine.from_model(self.model)}\n') - fd.write(f'{Question.txt_header}\n{self.question}\n') + temp_fd.write(f'{ModelLine.from_model(self.model)}\n') + temp_fd.write(f'{Question.txt_header}\n{self.question}\n') if self.answer: - fd.write(f'{Answer.txt_header}\n{self.answer}\n') + temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n') + shutil.move(temp_file_path, file_path) def __to_file_yaml(self, file_path: pathlib.Path) -> None: """ @@ -466,7 +470,8 @@ class Message(): * Message.ai_yaml_key: str [Optional] * Message.model_yaml_key: str [Optional] """ - with open(file_path, "w") as fd: + with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd: + temp_file_path = pathlib.Path(temp_fd.name) data: YamlDict = {Question.yaml_key: str(self.question)} if self.answer: data[Answer.yaml_key] = str(self.answer) @@ -476,7 +481,8 @@ class Message(): data[self.model_yaml_key] = self.model if self.tags: data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags]) - yaml.dump(data, fd, sort_keys=False) + yaml.dump(data, temp_fd, sort_keys=False) + shutil.move(temp_file_path, file_path) def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]: """ -- 2.36.6 From 22fa187e5f8f886bddcf61fd4ccbb0825cedf044 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:25:33 +0200 Subject: [PATCH 158/170] question_cmd: when no tags are specified, no tags are selected --- chatmastermind/commands/question.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index f439447..4936d8f 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -52,9 +52,9 @@ def question_cmd(args: argparse.Namespace, config: Config) -> None: """ Handler for the 'question' command. """ - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags) + mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(), + tags_and=args.and_tags if args.and_tags is not None else set(), + tags_not=args.exclude_tags if args.exclude_tags is not None else set()) chat = ChatDB.from_dir(cache_path=Path('.'), db_path=Path(config.db), mfilter=mfilter) -- 2.36.6 From 481f9ecf7cf178fce8dd55ff8af854b0db0835b6 Mon Sep 17 00:00:00 2001 From: juk0de Date: Sun, 10 Sep 2023 08:37:06 +0200 Subject: [PATCH 159/170] configuration: improved config file format --- chatmastermind/configuration.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 398fa03..08f6cbe 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -17,6 +17,18 @@ class ConfigError(Exception): pass +def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode: + """ + Changes the YAML dump style to multiline syntax for multiline strings. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar('tag:yaml.org,2002:str', data) + + +yaml.add_representer(str, str_presenter) + + @dataclass class AIConfig: """ @@ -48,13 +60,13 @@ class OpenAIConfig(AIConfig): # a default configuration ID: str = 'default' api_key: str = '0123456789' - system: str = 'You are an assistant' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 max_tokens: int = 4000 top_p: float = 1.0 frequency_penalty: float = 0.0 presence_penalty: float = 0.0 + system: str = 'You are an assistant' @classmethod def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst: @@ -62,14 +74,14 @@ class OpenAIConfig(AIConfig): Create OpenAIConfig from a dict. """ res = cls( - system=str(source['system']), api_key=str(source['api_key']), model=str(source['model']), max_tokens=int(source['max_tokens']), temperature=float(source['temperature']), top_p=float(source['top_p']), frequency_penalty=float(source['frequency_penalty']), - presence_penalty=float(source['presence_penalty']) + presence_penalty=float(source['presence_penalty']), + system=str(source['system']) ) # overwrite default ID if provided if 'ID' in source: @@ -148,6 +160,8 @@ class Config: def as_dict(self) -> dict[str, Any]: res = asdict(self) + # add the AI name manually (as first element) + # (not done by 'asdict' because it's a class variable) for ID, conf in res['ais'].items(): - conf.update({'name': self.ais[ID].name}) + res['ais'][ID] = {**{'name': self.ais[ID].name}, **conf} return res -- 2.36.6 From 33023d29f9de4fde3e12bc49d34aee88c89dca2f Mon Sep 17 00:00:00 2001 From: juk0de Date: Mon, 11 Sep 2023 07:38:49 +0200 Subject: [PATCH 160/170] configuration: made 'default' AI ID optional --- chatmastermind/ai_factory.py | 18 ++++++++++++------ chatmastermind/configuration.py | 3 +-- tests/test_ai_factory.py | 4 ++-- tests/test_configuration.py | 14 +++++++------- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index bc4583c..420b287 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -4,25 +4,31 @@ Creates different AI instances, based on the given configuration. import argparse from typing import cast -from .configuration import Config, OpenAIConfig, default_ai_ID +from .configuration import Config, AIConfig, OpenAIConfig from .ai import AI, AIError from .ais.openai import OpenAI -def create_ai(args: argparse.Namespace, config: Config) -> AI: +def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 """ Creates an AI subclass instance from the given arguments - and configuration file. + and configuration file. If AI has not been set in the + arguments, it searches for the ID 'default'. If that + is not found, it uses the first AI in the list. """ + ai_conf: AIConfig if args.AI: try: ai_conf = config.ais[args.AI] except KeyError: raise AIError(f"AI ID '{args.AI}' does not exist in this configuration") - elif default_ai_ID in config.ais: - ai_conf = config.ais[default_ai_ID] + elif 'default' in config.ais: + ai_conf = config.ais['default'] else: - raise AIError("No AI name given and no default exists") + try: + ai_conf = next(iter(config.ais.values())) + except StopIteration: + raise AIError("No AI found in this configuration") if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 08f6cbe..5397f4a 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -9,7 +9,6 @@ OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig') supported_ais: list[str] = ['openai'] -default_ai_ID: str = 'default' default_config_path = '.config.yaml' @@ -58,7 +57,7 @@ class OpenAIConfig(AIConfig): # all members have default values, so we can easily create # a default configuration - ID: str = 'default' + ID: str = 'myopenai' api_key: str = '0123456789' model: str = 'gpt-3.5-turbo-16k' temperature: float = 1.0 diff --git a/tests/test_ai_factory.py b/tests/test_ai_factory.py index d00b319..9cb94d3 100644 --- a/tests/test_ai_factory.py +++ b/tests/test_ai_factory.py @@ -10,7 +10,7 @@ from chatmastermind.ais.openai import OpenAI class TestCreateAI(unittest.TestCase): def setUp(self) -> None: self.args = MagicMock(spec=argparse.Namespace) - self.args.AI = 'default' + self.args.AI = 'myopenai' self.args.model = None self.args.max_tokens = None self.args.temperature = None @@ -18,7 +18,7 @@ class TestCreateAI(unittest.TestCase): def test_create_ai_from_args(self) -> None: # Create an AI with the default configuration config = Config() - self.args.AI = 'default' + self.args.AI = 'myopenai' ai = create_ai(self.args, config) self.assertIsInstance(ai, OpenAI) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index f3f9a98..ba8a5aa 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -59,7 +59,7 @@ class TestConfig(unittest.TestCase): source_dict = { 'db': './test_db/', 'ais': { - 'default': { + 'myopenai': { 'name': 'openai', 'system': 'Custom system', 'api_key': '9876543210', @@ -75,10 +75,10 @@ class TestConfig(unittest.TestCase): config = Config.from_dict(source_dict) self.assertEqual(config.db, './test_db/') self.assertEqual(len(config.ais), 1) - self.assertEqual(config.ais['default'].name, 'openai') - self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system') + self.assertEqual(config.ais['myopenai'].name, 'openai') + self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system') # check that 'ID' has been added - self.assertEqual(config.ais['default'].ID, 'default') + self.assertEqual(config.ais['myopenai'].ID, 'myopenai') def test_create_default_should_create_default_config(self) -> None: Config.create_default(Path(self.test_file.name)) @@ -117,8 +117,8 @@ class TestConfig(unittest.TestCase): config = Config( db='./test_db/', ais={ - 'default': OpenAIConfig( - ID='default', + 'myopenai': OpenAIConfig( + ID='myopenai', system='Custom system', api_key='9876543210', model='custom_model', @@ -135,7 +135,7 @@ class TestConfig(unittest.TestCase): saved_config = yaml.load(f, Loader=yaml.FullLoader) self.assertEqual(saved_config['db'], './test_db/') self.assertEqual(len(saved_config['ais']), 1) - self.assertEqual(saved_config['ais']['default']['system'], 'Custom system') + self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system') def test_from_file_error_unknown_ai(self) -> None: source_dict = { -- 2.36.6 From 17de0b99678381fa7e0fac9285d71bc26c649a67 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Mon, 11 Sep 2023 13:17:59 +0200 Subject: [PATCH 161/170] Remove old code. --- chatmastermind/main.py | 105 +---------------------------------------- 1 file changed, 1 insertion(+), 104 deletions(-) diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 1a375d0..99aca09 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -18,110 +18,7 @@ from .commands.print import print_cmd def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]: config = Config.from_file(parsed_args.config) - return get_tags_unique(config, prefix) - - -def tags_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'tags' command. - """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - if args.list: - tags_freq = chat.tags_frequency(args.prefix, args.contain) - for tag, freq in tags_freq.items(): - print(f"- {tag}: {freq}") - # TODO: add renaming - - -def config_cmd(args: argparse.Namespace) -> None: - """ - Handler for the 'config' command. - """ - if args.create: - Config.create_default(Path(args.create)) - - -def question_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'question' command. - """ - chat = ChatDB.from_dir(cache_path=Path('.'), - db_path=Path(config.db)) - # if it's a new question, create and store it immediately - if args.ask or args.create: - # FIXME: add sources to the question - message = Message(question=Question(args.question), - tags=args.ouput_tags, # FIXME - ai=args.ai, - model=args.model) - chat.add_to_cache([message]) - if args.create: - return - - # create the correct AI instance - ai: AI = create_ai(args, config) - if args.ask: - response: AIResponse = ai.request(message, - chat, - args.num_answers, # FIXME - args.otags) # FIXME - assert response - # TODO: - # * add answer to the message above (and create - # more messages for any additional answers) - pass - elif args.repeat: - lmessage = chat.latest_message() - assert lmessage - # TODO: repeat either the last question or the - # one(s) given in 'args.repeat' (overwrite - # existing ones if 'args.overwrite' is True) - pass - elif args.process: - # TODO: process either all questions without an - # answer or the one(s) given in 'args.process' - pass - - -def hist_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'hist' command. - """ - - mfilter = MessageFilter(tags_or=args.or_tags, - tags_and=args.and_tags, - tags_not=args.exclude_tags, - question_contains=args.question, - answer_contains=args.answer) - chat = ChatDB.from_dir(Path('.'), - Path(config.db), - mfilter=mfilter) - chat.print(args.source_code_only, - args.with_tags, - args.with_files) - - -def print_cmd(args: argparse.Namespace, config: Config) -> None: - """ - Handler for the 'print' command. - """ - fname = Path(args.file) - try: - message = Message.from_file(fname) - if message: - print(message.to_str(source_code_only=args.source_code_only)) - except MessageError: - print(f"File is not a valid message: {args.file}") - sys.exit(1) - if args.source_code_only: - display_source_code(data['answer']) - elif args.answer: - print(data['answer'].strip()) - elif args.question: - print(data['question'].strip()) - else: - print(dump_data(data).strip()) + return list(Message.tags_from_dir(Path(config.db), prefix=prefix)) def create_parser() -> argparse.ArgumentParser: -- 2.36.6 From 2b62cb8c4be08b541cf8eee71ca7c731af7100b5 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Sat, 9 Sep 2023 19:24:45 +0200 Subject: [PATCH 162/170] Remove the `-*terminal_width()` to save space on screen. --- chatmastermind/chat.py | 1 - tests/test_chat.py | 14 +------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/chatmastermind/chat.py b/chatmastermind/chat.py index 7c4dd35..dd18293 100644 --- a/chatmastermind/chat.py +++ b/chatmastermind/chat.py @@ -204,7 +204,6 @@ class Chat: output.append(message.to_str(source_code_only=True)) continue output.append(message.to_str(with_tags, with_files)) - output.append('\n' + ('-' * terminal_width()) + '\n') if paged: print_paged('\n'.join(output)) else: diff --git a/tests/test_chat.py b/tests/test_chat.py index 1916a2b..f34cb24 100644 --- a/tests/test_chat.py +++ b/tests/test_chat.py @@ -6,7 +6,7 @@ from io import StringIO from unittest.mock import patch from chatmastermind.tags import TagLine from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter -from chatmastermind.chat import Chat, ChatDB, terminal_width, ChatError +from chatmastermind.chat import Chat, ChatDB, ChatError class TestChat(unittest.TestCase): @@ -92,16 +92,10 @@ class TestChat(unittest.TestCase): Question 1 {Answer.txt_header} Answer 1 - -{'-'*terminal_width()} - {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 - -{'-'*terminal_width()} - """ self.assertEqual(mock_stdout.getvalue(), expected_output) @@ -115,18 +109,12 @@ FILE: 0001.txt Question 1 {Answer.txt_header} Answer 1 - -{'-'*terminal_width()} - {TagLine.prefix} btag2 FILE: 0002.txt {Question.txt_header} Question 2 {Answer.txt_header} Answer 2 - -{'-'*terminal_width()} - """ self.assertEqual(mock_stdout.getvalue(), expected_output) -- 2.36.6 From f96e82bdd7c96fe12de38956be501b715c4d6a9c Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Tue, 12 Sep 2023 16:34:17 +0200 Subject: [PATCH 163/170] Implement the config -l and config -m commands. --- chatmastermind/ai.py | 6 ++++++ chatmastermind/commands/config.py | 9 +++++++++ chatmastermind/configuration.py | 1 + 3 files changed, 16 insertions(+) diff --git a/chatmastermind/ai.py b/chatmastermind/ai.py index b97b5f1..622aa4f 100644 --- a/chatmastermind/ai.py +++ b/chatmastermind/ai.py @@ -59,6 +59,12 @@ class AI(Protocol): """ raise NotImplementedError + def print_models(self) -> None: + """ + Print all models supported by this AI. + """ + raise NotImplementedError + def tokens(self, data: Union[Message, Chat]) -> int: """ Computes the nr. of AI language tokens for the given message diff --git a/chatmastermind/commands/config.py b/chatmastermind/commands/config.py index 262164c..3714573 100644 --- a/chatmastermind/commands/config.py +++ b/chatmastermind/commands/config.py @@ -1,6 +1,8 @@ import argparse from pathlib import Path from ..configuration import Config +from ..ai import AI +from ..ai_factory import create_ai def config_cmd(args: argparse.Namespace) -> None: @@ -9,3 +11,10 @@ def config_cmd(args: argparse.Namespace) -> None: """ if args.create: Config.create_default(Path(args.create)) + elif args.list_models or args.print_model: + config: Config = Config.from_file(args.config) + ai: AI = create_ai(args, config) + if args.list_models: + ai.print_models() + else: + print(ai.config.model) diff --git a/chatmastermind/configuration.py b/chatmastermind/configuration.py index 5397f4a..1415eb2 100644 --- a/chatmastermind/configuration.py +++ b/chatmastermind/configuration.py @@ -39,6 +39,7 @@ class AIConfig: name: ClassVar[str] # a user-defined ID for an AI configuration entry ID: str + model: str = 'n/a' # the name must not be changed def __setattr__(self, name: str, value: Any) -> None: -- 2.36.6 From 544bf0bf069973a5b65e1e537a8863a36a66a7e5 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Tue, 12 Sep 2023 16:34:39 +0200 Subject: [PATCH 164/170] Improve README.md --- README.md | 116 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index d55102a..00f4720 100644 --- a/README.md +++ b/README.md @@ -37,63 +37,95 @@ cmm [global options] command [command options] ### Global Options -- `-c`, `--config`: Config file name (defaults to `.config.yaml`). - -### Commands - -- `ask`: Ask a question. -- `hist`: Print chat history. -- `tag`: Manage tags. -- `config`: Manage configuration. -- `print`: Print files. +- `-C`, `--config`: Config file name (defaults to `.config.yaml`). ### Command Options -#### `ask` Command Options +#### Question -- `-q`, `--question`: Question to ask (required). -- `-m`, `--max-tokens`: Max tokens to use. -- `-T`, `--temperature`: Temperature to use. -- `-M`, `--model`: Model to use. -- `-n`, `--number`: Number of answers to produce (default is 3). -- `-s`, `--source`: Add content of a file to the query. -- `-S`, `--only-source-code`: Add pure source code to the chat history. -- `-t`, `--tags`: List of tag names. -- `-e`, `--extags`: List of tag names to exclude. -- `-o`, `--output-tags`: List of output tag names (default is the input tags). -- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries. +The `question` command is used to ask, create, and process questions. -#### `hist` Command Options +```bash +cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]... +``` -- `-d`, `--dump`: Print chat history as Python structure. -- `-w`, `--with-tags`: Print chat history with tags. -- `-W`, `--with-files`: Print chat history with filenames. -- `-S`, `--only-source-code`: Print only source code. -- `-t`, `--tags`: List of tag names. -- `-e`, `--extags`: List of tag names to exclude. -- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries. +* `-t, --or-tags OTAGS` : List of tags (one must match) +* `-k, --and-tags ATAGS` : List of tags (all must match) +* `-x, --exclude-tags XTAGS` : List of tags to exclude +* `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags) +* `-A, --AI AI` : AI ID to use +* `-M, --model MODEL` : Model to use +* `-n, --num-answers NUM` : Number of answers to request +* `-m, --max-tokens MAX` : Max. number of tokens +* `-T, --temperature TEMP` : Temperature value +* `-a, --ask ASK` : Ask a question +* `-c, --create CREATE` : Create a question +* `-r, --repeat REPEAT` : Repeat a question +* `-p, --process PROCESS` : Process existing questions +* `-O, --overwrite` : Overwrite existing messages when repeating them +* `-s, --source-text SOURCE` : Add content of a file to the query +* `-S, --source-code SOURCE` : Add source code file content to the chat history -#### `tag` Command Options +#### Hist -- `-l`, `--list`: List all tags and their frequency. +The `hist` command is used to print the chat history. -#### `config` Command Options +```bash +cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION] +``` -- `-l`, `--list-models`: List all available models. -- `-m`, `--print-model`: Print the currently configured model. -- `-M`, `--model`: Set model in the config file. +* `-t, --or-tags OTAGS` : List of tags (one must match) +* `-k, --and-tags ATAGS` : List of tags (all must match) +* `-x, --exclude-tags XTAGS` : List of tags to exclude +* `-w, --with-tags` : Print chat history with tags +* `-W, --with-files` : Print chat history with filenames +* `-S, --source-code-only` : Print only source code +* `-A, --answer ANSWER` : Search for answer substring +* `-Q, --question QUESTION` : Search for question substring -#### `print` Command Options +#### Tags -- `-f`, `--file`: File to print (required). -- `-S`, `--only-source-code`: Print only source code. +The `tags` command is used to manage tags. + +```bash +cmm tags (-l | -p PREFIX | -c CONTENT) +``` + +* `-l, --list` : List all tags and their frequency +* `-p, --prefix PREFIX` : Filter tags by prefix +* `-c, --contain CONTENT` : Filter tags by contained substring + +#### Config + +The `config` command is used to manage the configuration. + +```bash +cmm config (-l | -m | -c CREATE) +``` + +* `-l, --list-models` : List all available models +* `-m, --print-model` : Print the currently configured model +* `-c, --create CREATE` : Create config with default settings in the given file + +#### Print + +The `print` command is used to print message files. + +```bash +cmm print -f FILE [-q | -a | -S] +``` + +* `-f, --file FILE` : File to print +* `-q, --question` : Print only question +* `-a, --answer` : Print only answer +* `-S, --only-source-code` : Print only source code ### Examples 1. Ask a question: ```bash -cmm ask -q "What is the meaning of life?" -t philosophy -e religion +cmm question -a "What is the meaning of life?" -t philosophy -x religion ``` 2. Display the chat history: @@ -105,19 +137,19 @@ cmm hist 3. Filter chat history by tags: ```bash -cmm hist -t tag1 tag2 +cmm hist --or-tags tag1 tag2 ``` 4. Exclude chat history by tags: ```bash -cmm hist -e tag3 tag4 +cmm hist --exclude-tags tag3 tag4 ``` 5. List all tags and their frequency: ```bash -cmm tag -l +cmm tags -l ``` 6. Print the contents of a file: -- 2.36.6 From 1ec3d6fcdaba199594dc847c6251e4c55d7e5e87 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Tue, 12 Sep 2023 16:37:50 +0200 Subject: [PATCH 165/170] Make it possible to specify the AI in config command. --- chatmastermind/ai_factory.py | 8 ++++---- chatmastermind/ais/openai.py | 7 ++++++- chatmastermind/main.py | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index 420b287..a3cf9c3 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 is not found, it uses the first AI in the list. """ ai_conf: AIConfig - if args.AI: + if 'AI' in args and args.AI: try: ai_conf = config.ais[args.AI] except KeyError: @@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) - if args.model: + if 'model' in args and args.model: ai.config.model = args.model - if args.max_tokens: + if 'max_tokens' in args and args.max_tokens: ai.config.max_tokens = args.max_tokens - if args.temperature: + if 'temperature' in args and args.temperature: ai.config.temperature = args.temperature return ai else: diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index a388a7a..0e7ad41 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -62,7 +62,12 @@ class OpenAI(AI): """ Return all models supported by this AI. """ - raise NotImplementedError + ret = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + ret.append(engine['id']) + ret.sort() + return ret def print_models(self) -> None: """ diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 99aca09..7e18185 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -100,6 +100,7 @@ def create_parser() -> argparse.ArgumentParser: help="Manage configuration", aliases=['c']) config_cmd_parser.set_defaults(func=config_cmd) + config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use') config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) config_group.add_argument('-l', '--list-models', help="List all available models", action='store_true') -- 2.36.6 From a7345cbc419b302aa0a3bbccd8b16d580e0c4d90 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 13 Sep 2023 07:52:05 +0200 Subject: [PATCH 166/170] ai_factory: fixed argument parsing bug --- chatmastermind/ai_factory.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index a3cf9c3..36a987b 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 is not found, it uses the first AI in the list. """ ai_conf: AIConfig - if 'AI' in args and args.AI: + if hasattr(args, 'AI') and args.AI: try: ai_conf = config.ais[args.AI] except KeyError: @@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) - if 'model' in args and args.model: + if hasattr(args, 'model') and args.model: ai.config.model = args.model - if 'max_tokens' in args and args.max_tokens: + if hasattr(args, 'max_tokens') and args.max_tokens: ai.config.max_tokens = args.max_tokens - if 'temperature' in args and args.temperature: + if hasattr(args, 'temperature') and args.temperature: ai.config.temperature = args.temperature return ai else: -- 2.36.6 From b5af751193bfdbb02778a4441e13c0a02045dd23 Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 13 Sep 2023 08:49:06 +0200 Subject: [PATCH 167/170] openai: added test module --- tests/test_ais_openai.py | 81 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 tests/test_ais_openai.py diff --git a/tests/test_ais_openai.py b/tests/test_ais_openai.py new file mode 100644 index 0000000..b53a14d --- /dev/null +++ b/tests/test_ais_openai.py @@ -0,0 +1,81 @@ +import unittest +from unittest import mock +from chatmastermind.ais.openai import OpenAI +from chatmastermind.message import Message, Question, Answer +from chatmastermind.chat import Chat +from chatmastermind.ai import AIResponse, Tokens +from chatmastermind.configuration import OpenAIConfig + + +class OpenAITest(unittest.TestCase): + + @mock.patch('openai.ChatCompletion.create') + def test_request(self, mock_create: mock.MagicMock) -> None: + # Create a test instance of OpenAI + config = OpenAIConfig() + openai = OpenAI(config) + + # Set up the mock response from openai.ChatCompletion.create + mock_response = { + 'choices': [ + { + 'message': { + 'content': 'Answer 1' + } + }, + { + 'message': { + 'content': 'Answer 2' + } + } + ], + 'usage': { + 'prompt_tokens': 10, + 'completion_tokens': 20, + 'total_tokens': 30 + } + } + mock_create.return_value = mock_response + + # Create test data + question = Message(Question('Question')) + chat = Chat([ + Message(Question('Question 1'), answer=Answer('Answer 1')), + Message(Question('Question 2'), answer=Answer('Answer 2')), + # add message without an answer -> expect to be skipped + Message(Question('Question 3')) + ]) + + # Make the request + response = openai.request(question, chat, num_answers=2) + + # Assert the AIResponse + self.assertIsInstance(response, AIResponse) + self.assertEqual(len(response.messages), 2) + self.assertEqual(response.messages[0].answer, 'Answer 1') + self.assertEqual(response.messages[1].answer, 'Answer 2') + self.assertIsNotNone(response.tokens) + self.assertIsInstance(response.tokens, Tokens) + assert response.tokens + self.assertEqual(response.tokens.prompt, 10) + self.assertEqual(response.tokens.completion, 20) + self.assertEqual(response.tokens.total, 30) + + # Assert the mock call to openai.ChatCompletion.create + mock_create.assert_called_once_with( + model=f'{config.model}', + messages=[ + {'role': 'system', 'content': f'{config.system}'}, + {'role': 'user', 'content': 'Question 1'}, + {'role': 'assistant', 'content': 'Answer 1'}, + {'role': 'user', 'content': 'Question 2'}, + {'role': 'assistant', 'content': 'Answer 2'}, + {'role': 'user', 'content': 'Question'} + ], + temperature=config.temperature, + max_tokens=config.max_tokens, + top_p=config.top_p, + n=2, + frequency_penalty=config.frequency_penalty, + presence_penalty=config.presence_penalty + ) -- 2.36.6 From 26e3d38afbc05b994eb1740f5f3c0a7143fe1f6e Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Wed, 13 Sep 2023 10:53:12 +0200 Subject: [PATCH 168/170] Add the Gitea web hooks. --- hooks/gitea_cmm_hook.php | 56 ++++++++++++++++++++++++++++++++++++++++ hooks/push_hook.sh | 7 +++++ 2 files changed, 63 insertions(+) create mode 100644 hooks/gitea_cmm_hook.php create mode 100755 hooks/push_hook.sh diff --git a/hooks/gitea_cmm_hook.php b/hooks/gitea_cmm_hook.php new file mode 100644 index 0000000..6b37eb6 --- /dev/null +++ b/hooks/gitea_cmm_hook.php @@ -0,0 +1,56 @@ + diff --git a/hooks/push_hook.sh b/hooks/push_hook.sh new file mode 100755 index 0000000..9406c4c --- /dev/null +++ b/hooks/push_hook.sh @@ -0,0 +1,7 @@ +#!/usr/bin/bash + +. /home/kaizen/.bashrc +set -e +cd /home/kaizen/repos/ChatMastermind +git pull +pytest -- 2.36.6 From 7f4a16894ebd5b7b3a0562618911fa697a5b8792 Mon Sep 17 00:00:00 2001 From: Oleksandr Kozachuk Date: Wed, 13 Sep 2023 11:08:02 +0200 Subject: [PATCH 169/170] Add pre-commit checks into push webhook. --- hooks/push_hook.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/hooks/push_hook.sh b/hooks/push_hook.sh index 9406c4c..6d6b4ff 100755 --- a/hooks/push_hook.sh +++ b/hooks/push_hook.sh @@ -4,4 +4,5 @@ set -e cd /home/kaizen/repos/ChatMastermind git pull +pre-commit run -a pytest -- 2.36.6 From 17a0264025489b978a6fc450cbb29b1b77467f4b Mon Sep 17 00:00:00 2001 From: juk0de Date: Wed, 13 Sep 2023 14:56:40 +0200 Subject: [PATCH 170/170] question_cmd: now also accepts Messages as source files --- chatmastermind/commands/question.py | 66 ++++++++++++++++++++--------- tests/test_question_cmd.py | 35 ++++++++++++++- 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/chatmastermind/commands/question.py b/chatmastermind/commands/question.py index 4936d8f..d143792 100644 --- a/chatmastermind/commands/question.py +++ b/chatmastermind/commands/question.py @@ -3,11 +3,52 @@ from pathlib import Path from itertools import zip_longest from ..configuration import Config from ..chat import ChatDB -from ..message import Message, MessageFilter, Question, source_code +from ..message import Message, MessageFilter, MessageError, Question, source_code from ..ai_factory import create_ai from ..ai import AI, AIResponse +def add_file_as_text(question_parts: list[str], file: str) -> None: + """ + Add the given file as plain text to the question part list. + If the file is a Message, add the answer. + """ + file_path = Path(file) + content: str + try: + message = Message.from_file(file_path) + if message and message.answer: + content = message.answer + except MessageError: + with open(file) as r: + content = r.read().strip() + if len(content) > 0: + question_parts.append(content) + + +def add_file_as_code(question_parts: list[str], file: str) -> None: + """ + Add all source code from the given file. If no code segments can be extracted, + the whole content is added as source code segment. If the file is a Message, + extract the source code from the answer. + """ + file_path = Path(file) + content: str + try: + message = Message.from_file(file_path) + if message and message.answer: + content = message.answer + except MessageError: + with open(file) as r: + content = r.read().strip() + # extract and add source code + code_parts = source_code(content, include_delims=True) + if len(code_parts) > 0: + question_parts += code_parts + else: + question_parts.append(f"```\n{content}\n```") + + def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: """ Creates (and writes) a new message from the given arguments. @@ -17,26 +58,13 @@ def create_message(chat: ChatDB, args: argparse.Namespace) -> Message: text_files = args.source_text if args.source_text is not None else [] code_files = args.source_code if args.source_code is not None else [] - for question, source, code in zip_longest(question_list, text_files, code_files, fillvalue=None): + for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None): if question is not None and len(question.strip()) > 0: question_parts.append(question) - if source is not None and len(source) > 0: - with open(source) as r: - content = r.read().strip() - if len(content) > 0: - question_parts.append(content) - if code is not None and len(code) > 0: - with open(code) as r: - content = r.read().strip() - if len(content) == 0: - continue - # try to extract and add source code - code_parts = source_code(content, include_delims=True) - if len(code_parts) > 0: - question_parts += code_parts - # if there's none, add the whole file - else: - question_parts.append(f"```\n{content}\n```") + if text_file is not None and len(text_file) > 0: + add_file_as_text(question_parts, text_file) + if code_file is not None and len(code_file) > 0: + add_file_as_code(question_parts, code_file) full_question = '\n\n'.join(question_parts) diff --git a/tests/test_question_cmd.py b/tests/test_question_cmd.py index 40ea4d8..b94560f 100644 --- a/tests/test_question_cmd.py +++ b/tests/test_question_cmd.py @@ -5,7 +5,7 @@ import tempfile from pathlib import Path from unittest.mock import MagicMock from chatmastermind.commands.question import create_message -from chatmastermind.message import Message, Question +from chatmastermind.message import Message, Question, Answer from chatmastermind.chat import ChatDB @@ -20,6 +20,12 @@ class TestMessageCreate(unittest.TestCase): self.cache_path = tempfile.TemporaryDirectory() self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name), db_path=Path(self.db_path.name)) + # create some messages + self.message_text = Message(Question("What is this?"), + Answer("It is pure text")) + self.message_code = Message(Question("What is this?"), + Answer("Text\n```\nIt is embedded code\n```\ntext")) + self.chat.add_to_db([self.message_text, self.message_code]) # create arguments mock self.args = MagicMock(spec=argparse.Namespace) self.args.source_text = None @@ -159,4 +165,31 @@ This is embedded source code. ``` This is embedded source code. ``` +""")) + + def test_single_question_with_text_only_message(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_text = [f"{self.chat.messages[0].file_path}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # file contains no source code (only text) + # -> don't expect any in the question + self.assertEqual(len(message.question.source_code()), 0) + self.assertEqual(message.question, Question(f"""What is this? + +{self.message_text.answer}""")) + + def test_single_question_with_message_and_embedded_code(self) -> None: + self.args.ask = ["What is this?"] + self.args.source_code = [f"{self.chat.messages[1].file_path}"] + message = create_message(self.chat, self.args) + self.assertIsInstance(message, Message) + # answer contains 1 source code block + # -> expect it in the question + self.assertEqual(len(message.question.source_code()), 1) + self.assertEqual(message.question, Question("""What is this? + +``` +It is embedded code +``` """)) -- 2.36.6