Refactor process_tags
This commit is contained in:
parent
b23a9f663f
commit
4ee777118d
@ -6,8 +6,8 @@ import yaml
|
||||
import sys
|
||||
import argcomplete
|
||||
import argparse
|
||||
from .utils import terminal_width, pp, tags_completer, process_tags, display_chat
|
||||
from .storage import save_answers, create_chat
|
||||
from .utils import terminal_width, pp, process_tags, display_chat
|
||||
from .storage import save_answers, create_chat, get_tags
|
||||
from .api_client import ai, openai_api_key
|
||||
|
||||
|
||||
@ -23,7 +23,8 @@ def process_and_display_chat(args: argparse.Namespace,
|
||||
) -> tuple[list[dict[str, str]], str, list[str]]:
|
||||
tags = args.tags or []
|
||||
extags = args.extags or []
|
||||
process_tags(config, tags, extags)
|
||||
otags = args.output_tags or []
|
||||
process_tags(tags, extags, otags)
|
||||
|
||||
question_parts = []
|
||||
question_list = args.question if args.question is not None else []
|
||||
@ -60,6 +61,12 @@ def handle_question(args: argparse.Namespace,
|
||||
print(f"Usage: {usage}")
|
||||
|
||||
|
||||
def tags_completer(prefix, parsed_args, **kwargs):
|
||||
with open(parsed_args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return get_tags(config, prefix)
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
default_config = '.config.yaml'
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
@ -55,3 +55,18 @@ def create_chat(question: Optional[str],
|
||||
if question:
|
||||
append_message(chat, 'user', question)
|
||||
return chat
|
||||
|
||||
|
||||
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
|
||||
result = []
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for tag in data.get('tags', []):
|
||||
if prefix and len(prefix) > 0:
|
||||
if tag.startswith(prefix):
|
||||
result.append(tag)
|
||||
else:
|
||||
result.append(tag)
|
||||
return list(set(result))
|
||||
|
||||
@ -1,6 +1,4 @@
|
||||
import shutil
|
||||
import yaml
|
||||
import pathlib
|
||||
from pprint import PrettyPrinter
|
||||
from typing import List, Dict
|
||||
|
||||
@ -13,11 +11,19 @@ def pp(*args, **kwargs) -> None:
|
||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def process_tags(config: dict, tags: list, extags: list) -> None:
|
||||
print(f"Tags: {', '.join(tags)}")
|
||||
if len(extags) > 0:
|
||||
print(f"Excluding tags: {', '.join(extags)}")
|
||||
print()
|
||||
def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None:
|
||||
printed_messages = []
|
||||
|
||||
if tags:
|
||||
printed_messages.append(f"Tags: {', '.join(tags)}")
|
||||
if extags:
|
||||
printed_messages.append(f"Excluding tags: {', '.join(extags)}")
|
||||
if otags:
|
||||
printed_messages.append(f"Output tags: {', '.join(otags)}")
|
||||
|
||||
if printed_messages:
|
||||
print("\n".join(printed_messages))
|
||||
print()
|
||||
|
||||
|
||||
def append_message(chat: List[Dict[str, str]],
|
||||
@ -46,20 +52,3 @@ def display_chat(chat, dump=False) -> None:
|
||||
print(message['content'])
|
||||
else:
|
||||
print(f"{message['role'].upper()}: {message['content']}")
|
||||
|
||||
|
||||
def tags_completer(prefix, parsed_args, **kwargs):
|
||||
with open(parsed_args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
result = []
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
for tag in data.get('tags', []):
|
||||
if prefix and len(prefix) > 0:
|
||||
if tag.startswith(prefix):
|
||||
result.append(tag)
|
||||
else:
|
||||
result.append(tag)
|
||||
return list(set(result))
|
||||
|
||||
@ -113,9 +113,9 @@ class TestHandleQuestion(unittest.TestCase):
|
||||
open_mock = MagicMock()
|
||||
with patch("chatmastermind.storage.open", open_mock):
|
||||
handle_question(self.args, self.config, True)
|
||||
mock_process_tags.assert_called_once_with(self.config,
|
||||
self.args.tags,
|
||||
self.args.extags)
|
||||
mock_process_tags.assert_called_once_with(self.args.tags,
|
||||
self.args.extags,
|
||||
[])
|
||||
mock_create_chat.assert_called_once_with(self.question,
|
||||
self.args.tags,
|
||||
self.args.extags,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user