Refactor process_tags

This commit is contained in:
OK 2023-04-07 17:56:02 +02:00
parent b23a9f663f
commit 4ee777118d
4 changed files with 41 additions and 30 deletions

View File

@ -6,8 +6,8 @@ import yaml
import sys import sys
import argcomplete import argcomplete
import argparse import argparse
from .utils import terminal_width, pp, tags_completer, process_tags, display_chat from .utils import terminal_width, pp, process_tags, display_chat
from .storage import save_answers, create_chat from .storage import save_answers, create_chat, get_tags
from .api_client import ai, openai_api_key from .api_client import ai, openai_api_key
@ -23,7 +23,8 @@ def process_and_display_chat(args: argparse.Namespace,
) -> tuple[list[dict[str, str]], str, list[str]]: ) -> tuple[list[dict[str, str]], str, list[str]]:
tags = args.tags or [] tags = args.tags or []
extags = args.extags or [] extags = args.extags or []
process_tags(config, tags, extags) otags = args.output_tags or []
process_tags(tags, extags, otags)
question_parts = [] question_parts = []
question_list = args.question if args.question is not None else [] question_list = args.question if args.question is not None else []
@ -60,6 +61,12 @@ def handle_question(args: argparse.Namespace,
print(f"Usage: {usage}") print(f"Usage: {usage}")
def tags_completer(prefix, parsed_args, **kwargs):
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
return get_tags(config, prefix)
def create_parser() -> argparse.ArgumentParser: def create_parser() -> argparse.ArgumentParser:
default_config = '.config.yaml' default_config = '.config.yaml'
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(

View File

@ -55,3 +55,18 @@ def create_chat(question: Optional[str],
if question: if question:
append_message(chat, 'user', question) append_message(chat, 'user', question)
return chat return chat
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
result = []
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))

View File

@ -1,6 +1,4 @@
import shutil import shutil
import yaml
import pathlib
from pprint import PrettyPrinter from pprint import PrettyPrinter
from typing import List, Dict from typing import List, Dict
@ -13,10 +11,18 @@ def pp(*args, **kwargs) -> None:
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs) return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
def process_tags(config: dict, tags: list, extags: list) -> None: def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None:
print(f"Tags: {', '.join(tags)}") printed_messages = []
if len(extags) > 0:
print(f"Excluding tags: {', '.join(extags)}") if tags:
printed_messages.append(f"Tags: {', '.join(tags)}")
if extags:
printed_messages.append(f"Excluding tags: {', '.join(extags)}")
if otags:
printed_messages.append(f"Output tags: {', '.join(otags)}")
if printed_messages:
print("\n".join(printed_messages))
print() print()
@ -46,20 +52,3 @@ def display_chat(chat, dump=False) -> None:
print(message['content']) print(message['content'])
else: else:
print(f"{message['role'].upper()}: {message['content']}") print(f"{message['role'].upper()}: {message['content']}")
def tags_completer(prefix, parsed_args, **kwargs):
with open(parsed_args.config, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
result = []
for file in sorted(pathlib.Path(config['db']).iterdir()):
if file.suffix == '.yaml':
with open(file, 'r') as f:
data = yaml.load(f, Loader=yaml.FullLoader)
for tag in data.get('tags', []):
if prefix and len(prefix) > 0:
if tag.startswith(prefix):
result.append(tag)
else:
result.append(tag)
return list(set(result))

View File

@ -113,9 +113,9 @@ class TestHandleQuestion(unittest.TestCase):
open_mock = MagicMock() open_mock = MagicMock()
with patch("chatmastermind.storage.open", open_mock): with patch("chatmastermind.storage.open", open_mock):
handle_question(self.args, self.config, True) handle_question(self.args, self.config, True)
mock_process_tags.assert_called_once_with(self.config, mock_process_tags.assert_called_once_with(self.args.tags,
self.args.tags, self.args.extags,
self.args.extags) [])
mock_create_chat.assert_called_once_with(self.question, mock_create_chat.assert_called_once_with(self.question,
self.args.tags, self.args.tags,
self.args.extags, self.args.extags,