diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..9b38853 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 4c00309..27e00f1 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -56,7 +56,7 @@ def handle_question(args: argparse.Namespace, chat, question, tags = process_and_display_chat(args, config, dump) otags = args.output_tags or [] answers, usage = ai(chat, config, args.number) - save_answers(question, answers, tags, otags) + save_answers(question, answers, tags, otags, config) print("-" * terminal_width()) print(f"Usage: {usage}") diff --git a/chatmastermind/storage.py b/chatmastermind/storage.py index 2d1d373..fb7bd8d 100644 --- a/chatmastermind/storage.py +++ b/chatmastermind/storage.py @@ -8,15 +8,25 @@ from typing import List, Dict, Any, Optional def save_answers(question: str, answers: list[str], tags: list[str], - otags: Optional[list[str]] + otags: Optional[list[str]], + config: Dict[str, Any] ) -> None: wtags = otags or tags - for num, answer in enumerate(answers, start=1): - title = f'-- ANSWER {num} ' + num, inum = 0, 0 + next_fname = pathlib.Path(config['db']) / '.next' + try: + with open(next_fname, 'r') as f: + num = int(f.read()) + except Exception: + pass + for answer in answers: + num += 1 + inum += 1 + title = f'-- ANSWER {inum} ' title_end = '-' * (terminal_width() - len(title)) print(f'{title}{title_end}') print(answer) - with open(f"{num:02d}.yaml", "w") as fd: + with open(f"{num:04d}.yaml", "w") as fd: with io.StringIO() as f: yaml.dump({'question': question}, f, @@ -32,6 +42,8 @@ def save_answers(question: str, yaml.dump({'tags': wtags}, fd, default_flow_style=False) + with open(next_fname, 'w') as f: + f.write(f'{num}') def create_chat(question: Optional[str], diff --git a/tests/test_main.py b/tests/test_main.py index 6c35fe9..d3c7755 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,6 +1,6 @@ import unittest import io -import os +import pathlib import yaml import argparse from chatmastermind.utils import terminal_width @@ -98,6 +98,7 @@ class TestHandleQuestion(unittest.TestCase): number=3 ) self.config = { + 'db': 'test_files', 'setting1': 'value1', 'setting2': 'value2' } @@ -132,42 +133,33 @@ class TestHandleQuestion(unittest.TestCase): expected_calls.append(((answer,),)) expected_calls.append((("-" * terminal_width(),),)) expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),)) - open_mock.assert_has_calls( - [mock.call(f"{num:02d}.yaml", "w") for num in range(1, 4)] + [ - mock.call().__enter__(), mock.call().__exit__(None, None, None)] * 3, - any_order=True) self.assertEqual(mock_print.call_args_list, expected_calls) + open_expected_calls = list([mock.call(f"{num:04d}.yaml", "w") for num in range(2, 5)]) + open_mock.assert_has_calls(open_expected_calls, any_order=True) class TestSaveAnswers(unittest.TestCase): + @mock.patch('builtins.open') + @mock.patch('chatmastermind.storage.print') + def test_save_answers(self, print_mock, open_mock): + question = "Test question?" + answers = ["Answer 1", "Answer 2"] + tags = ["tag1", "tag2"] + otags = ["otag1", "otag2"] + config = {'db': 'test_db'} - def setUp(self): - self.question = "What is AI?" - self.answers = ["AI is Artificial Intelligence", - "AI is a simulation of human intelligence"] - self.tags = ["ai", "definition"] + with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \ + mock.patch('chatmastermind.storage.yaml.dump'), \ + mock.patch('io.StringIO') as stringio_mock: + stringio_instance = stringio_mock.return_value + stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"] + save_answers(question, answers, tags, otags, config) - @patch('sys.stdout', new_callable=io.StringIO) - def assert_stdout(self, expected_output: str, mock_stdout: io.StringIO): - save_answers(self.question, self.answers, self.tags, None) - self.assertEqual(mock_stdout.getvalue(), expected_output) - - def test_save_answers(self): - try: - self.assert_stdout(f"-- ANSWER 1 {'-'*(terminal_width()-12)}\n" - "AI is Artificial Intelligence\n" - f"-- ANSWER 2 {'-'*(terminal_width()-12)}\n" - "AI is a simulation of human intelligence\n") - for idx, answer in enumerate(self.answers, start=1): - with open(f"{idx:02d}.yaml", "r") as file: - data = yaml.safe_load(file) - self.assertEqual(data["question"], self.question) - self.assertEqual(data["answer"], answer) - self.assertEqual(data["tags"], self.tags) - finally: - for idx in range(1, len(self.answers) + 1): - if os.path.exists(f"{idx:02d}.yaml"): - os.remove(f"{idx:02d}.yaml") + open_calls = [ + mock.call(pathlib.Path('test_db/.next'), 'r'), + mock.call(pathlib.Path('test_db/.next'), 'w'), + ] + open_mock.assert_has_calls(open_calls, any_order=True) class TestAI(unittest.TestCase):