Compare commits
176 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
17a0264025 | ||
| 7f4a16894e | |||
|
|
26e3d38afb | ||
|
|
b5af751193 | ||
|
|
a7345cbc41 | ||
|
|
310cb9421e | ||
| 1ec3d6fcda | |||
| 544bf0bf06 | |||
| f96e82bdd7 | |||
| 2b62cb8c4b | |||
|
|
a895c1fc6a | ||
| ddfcc71510 | |||
| 17de0b9967 | |||
|
|
33023d29f9 | ||
|
|
481f9ecf7c | ||
|
|
22fa187e5f | ||
|
|
b840ebd792 | ||
|
|
66908f5fed | ||
|
|
2e08ccf606 | ||
|
|
595ff8e294 | ||
|
|
faac42d3c2 | ||
|
|
864ab7aeb1 | ||
|
|
cc76da2ab3 | ||
|
|
f99cd3ed41 | ||
| 6f3ea98425 | |||
| 54ece6efeb | |||
| 86eebc39ea | |||
|
|
3eca53998b | ||
|
|
c4f7bcc94e | ||
|
|
c52713c833 | ||
|
|
ecb6994783 | ||
|
|
61e710a4b1 | ||
|
|
21d39c6c66 | ||
|
|
6a4cc7a65d | ||
|
|
d6bb5800b1 | ||
|
|
034e4093f1 | ||
|
|
7d15452242 | ||
|
|
823d3bf7dc | ||
|
|
4bd144c4d7 | ||
|
|
e186afbef0 | ||
|
|
5e4ec70072 | ||
|
|
4c378dde85 | ||
|
|
8923a13352 | ||
|
|
e1414835c8 | ||
|
|
abb7fdacb6 | ||
|
|
2e2228bd60 | ||
|
|
713b55482a | ||
|
|
d35de86c67 | ||
|
|
aba3eb783d | ||
|
|
8e63831701 | ||
|
|
c318b99671 | ||
|
|
48c8e951e1 | ||
|
|
b22a4b07ed | ||
|
|
33565d351d | ||
|
|
6737fa98c7 | ||
|
|
815a21893c | ||
|
|
64893949a4 | ||
|
|
a093f9b867 | ||
|
|
dc3f3dc168 | ||
|
|
74c39070d6 | ||
|
|
fde0ae4652 | ||
|
|
238dbbee60 | ||
|
|
17f7b2fb45 | ||
|
|
9c2598a4b8 | ||
|
|
acec5f1d55 | ||
|
|
c0f50bace5 | ||
|
|
30ccec2462 | ||
|
|
09da312657 | ||
|
|
33567df15f | ||
|
|
264979a60d | ||
|
|
061e5f8682 | ||
|
|
2d456e68f1 | ||
|
|
8bd659e888 | ||
|
|
d4021eeb11 | ||
|
|
c143c001f9 | ||
|
|
59b851650a | ||
|
|
6f71a2ff69 | ||
|
|
eca44b14cb | ||
|
|
b48667bfa0 | ||
|
|
533ee1c1a9 | ||
|
|
cf50818f28 | ||
|
|
dd3d3ffc82 | ||
|
|
1e3bfdd67f | ||
|
|
53582a7123 | ||
| 39b518a8a6 | |||
| d22877a0f1 | |||
| 7cf62c54ef | |||
| 3ef1339cc0 | |||
|
|
5fb5dde550 | ||
|
|
c0b7d17587 | ||
|
|
76f2373397 | ||
|
|
eaa399bcb9 | ||
|
|
b1a23394fc | ||
| ed567afbea | |||
|
|
2df9dd6427 | ||
|
|
74a26b8c2f | ||
| 6e447018d5 | |||
|
|
893917e455 | ||
|
|
ba5aa1fbc7 | ||
|
|
eb2fcba99d | ||
|
|
b7e3ca7ca7 | ||
|
|
aa322de718 | ||
|
|
bf1cbff6a2 | ||
|
|
f93a57c00d | ||
|
|
b0504aedbe | ||
|
|
eb0d97ddc8 | ||
|
|
7e25a08d6e | ||
|
|
63040b3688 | ||
|
|
6e2d5009c1 | ||
|
|
44cd1fab45 | ||
|
|
4b0f40bccd | ||
|
|
fa292fb73a | ||
|
|
f9d749cdd8 | ||
|
|
ba56caf013 | ||
|
|
d80c3962bd | ||
|
|
ddfe29b951 | ||
|
|
d93598a74f | ||
|
|
7f612bfc17 | ||
|
|
93290da5b5 | ||
|
|
9f4897a5b8 | ||
|
|
214a6919db | ||
|
|
b83cbb719b | ||
|
|
8e1cdee3bf | ||
|
|
73d2a9ea3b | ||
|
|
169f1bb458 | ||
|
|
7f91a2b567 | ||
|
|
fc1b8006a0 | ||
|
|
aa89270876 | ||
|
|
0d6a6dd604 | ||
|
|
580c86e948 | ||
|
|
879831d7f5 | ||
|
|
dfc1261931 | ||
|
|
173a46a9b5 | ||
|
|
604e5ccf73 | ||
|
|
ef46f5efc9 | ||
|
|
b13a68836a | ||
|
|
a5c91adc41 | ||
| 380b7c1b67 | |||
|
|
e8343fde01 | ||
|
|
ee8deed320 | ||
|
|
dc13213c4d | ||
|
|
4303fb414f | ||
|
|
ba41794f4e | ||
|
|
a5075b14a0 | ||
|
|
e8eba0b755 | ||
|
|
1e15a52e26 | ||
|
|
c4a7c07a0c | ||
|
|
22bebc16ed | ||
|
|
f7ba0c000f | ||
| b6eb7d9af8 | |||
|
|
f371a6146e | ||
| 6ed459be6f | |||
| 1fb9144192 | |||
| 4b2f634b79 | |||
| e4d055b900 | |||
|
|
bc5e6228a6 | ||
|
|
056bf4c6b5 | ||
|
|
93a8b0081a | ||
|
|
5119b3a874 | ||
|
|
5a435c5f8f | ||
|
|
f90e7bcd47 | ||
|
|
6406d2f5b5 | ||
| df91ca863a | |||
| bc9baff0dc | |||
|
|
7a92ebe539 | ||
| 9b6b13993c | |||
| c5c4a6628f | |||
|
|
f8ed0e3636 | ||
| caf5244d52 | |||
| ca3a53e68b | |||
| 820d938060 | |||
| 8bb2a002a6 | |||
| 01de75bef3 | |||
| 1bb553b223 | |||
| bb8aa2f817 | |||
| 0d31b88567 |
1
.gitignore
vendored
@ -131,3 +131,4 @@ dmypy.json
|
||||
.config.yaml
|
||||
db
|
||||
noweb
|
||||
Session.vim
|
||||
|
||||
171
README.md
@ -4,9 +4,11 @@ ChatMastermind is a Python application that automates conversation with AI, stor
|
||||
|
||||
The project uses the OpenAI API to generate responses and stores the data in YAML files. It also allows you to filter chat history based on tags and supports autocompletion for tags.
|
||||
|
||||
Official repository URL: https://kaizenkodo.no/gitea/kaizenkodo/ChatMastermind.git
|
||||
|
||||
## Requirements
|
||||
|
||||
- Python 3.6 or higher
|
||||
- Python 3.9 or higher
|
||||
- openai
|
||||
- PyYAML
|
||||
- argcomplete
|
||||
@ -27,65 +29,133 @@ pip install .
|
||||
|
||||
## Usage
|
||||
|
||||
The `cmm` script has global options, a list of commands, and options per command:
|
||||
|
||||
```bash
|
||||
cmm [-h] [-p PRINT | -q QUESTION | -D | -d | -l] [-c CONFIG] [-m MAX_TOKENS] [-T TEMPERATURE] [-M MODEL] [-n NUMBER] [-t [TAGS [TAGS ...]]] [-e [EXTAGS [EXTAGS ...]]] [-o [OTAGS [OTAGS ...]]] [-a] [-w] [-W]
|
||||
cmm [global options] command [command options]
|
||||
```
|
||||
|
||||
### Arguments
|
||||
### Global Options
|
||||
|
||||
- `-p`, `--print`: YAML file to print.
|
||||
- `-q`, `--question`: Question to ask.
|
||||
- `-D`, `--chat-dump`: Print chat history as a Python structure.
|
||||
- `-d`, `--chat`: Print chat history as readable text.
|
||||
- `-a`, `--match-all-tags`: All given tags must match when selecting chat history entries.
|
||||
- `-w`, `--with-tags`: Print chat history with tags.
|
||||
- `-W`, `--with-tags`: Print chat history with filenames.
|
||||
- `-l`, `--list-tags`: List all tags and their frequency.
|
||||
- `-c`, `--config`: Config file name (defaults to `.config.yaml`).
|
||||
- `-m`, `--max-tokens`: Max tokens to use.
|
||||
- `-T`, `--temperature`: Temperature to use.
|
||||
- `-M`, `--model`: Model to use.
|
||||
- `-n`, `--number`: Number of answers to produce (default is 3).
|
||||
- `-t`, `--tags`: List of tag names.
|
||||
- `-e`, `--extags`: List of tag names to exclude.
|
||||
- `-o`, `--output-tags`: List of output tag names (default is the input tags).
|
||||
- `-C`, `--config`: Config file name (defaults to `.config.yaml`).
|
||||
|
||||
### Command Options
|
||||
|
||||
#### Question
|
||||
|
||||
The `question` command is used to ask, create, and process questions.
|
||||
|
||||
```bash
|
||||
cmm question [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-o OUTTAGS]... [-A AI] [-M MODEL] [-n NUM] [-m MAX] [-T TEMP] (-a ASK | -c CREATE | -r REPEAT | -p PROCESS) [-O] [-s SOURCE]... [-S SOURCE]...
|
||||
```
|
||||
|
||||
* `-t, --or-tags OTAGS` : List of tags (one must match)
|
||||
* `-k, --and-tags ATAGS` : List of tags (all must match)
|
||||
* `-x, --exclude-tags XTAGS` : List of tags to exclude
|
||||
* `-o, --output-tags OUTTAGS` : List of output tags (default: use input tags)
|
||||
* `-A, --AI AI` : AI ID to use
|
||||
* `-M, --model MODEL` : Model to use
|
||||
* `-n, --num-answers NUM` : Number of answers to request
|
||||
* `-m, --max-tokens MAX` : Max. number of tokens
|
||||
* `-T, --temperature TEMP` : Temperature value
|
||||
* `-a, --ask ASK` : Ask a question
|
||||
* `-c, --create CREATE` : Create a question
|
||||
* `-r, --repeat REPEAT` : Repeat a question
|
||||
* `-p, --process PROCESS` : Process existing questions
|
||||
* `-O, --overwrite` : Overwrite existing messages when repeating them
|
||||
* `-s, --source-text SOURCE` : Add content of a file to the query
|
||||
* `-S, --source-code SOURCE` : Add source code file content to the chat history
|
||||
|
||||
#### Hist
|
||||
|
||||
The `hist` command is used to print the chat history.
|
||||
|
||||
```bash
|
||||
cmm hist [-t OTAGS]... [-k ATAGS]... [-x XTAGS]... [-w] [-W] [-S] [-A ANSWER] [-Q QUESTION]
|
||||
```
|
||||
|
||||
* `-t, --or-tags OTAGS` : List of tags (one must match)
|
||||
* `-k, --and-tags ATAGS` : List of tags (all must match)
|
||||
* `-x, --exclude-tags XTAGS` : List of tags to exclude
|
||||
* `-w, --with-tags` : Print chat history with tags
|
||||
* `-W, --with-files` : Print chat history with filenames
|
||||
* `-S, --source-code-only` : Print only source code
|
||||
* `-A, --answer ANSWER` : Search for answer substring
|
||||
* `-Q, --question QUESTION` : Search for question substring
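The three tag options can be read as set operations on a message's tags. The following is a minimal sketch of that reading, not the project's actual filter code: `matches_tags` is a hypothetical helper, and it assumes that when several options are given, all of them must hold.

```python
from typing import Optional


def matches_tags(tags: set[str],
                 or_tags: Optional[set[str]] = None,
                 and_tags: Optional[set[str]] = None,
                 exclude_tags: Optional[set[str]] = None) -> bool:
    """Hypothetical helper: True if 'tags' passes every filter that is given."""
    # --or-tags: at least one of the listed tags must be present
    if or_tags and not (tags & or_tags):
        return False
    # --and-tags: every listed tag must be present
    if and_tags and not and_tags <= tags:
        return False
    # --exclude-tags: none of the listed tags may be present
    if exclude_tags and (tags & exclude_tags):
        return False
    return True


message_tags = {"philosophy", "ethics"}
print(matches_tags(message_tags, or_tags={"philosophy", "physics"}))   # True
print(matches_tags(message_tags, and_tags={"philosophy", "physics"}))  # False
print(matches_tags(message_tags, exclude_tags={"ethics"}))             # False
```

A filter that is not given (or is empty) is treated here as "no restriction", which is also an assumption of this sketch.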
|
||||
|
||||
#### Tags
|
||||
|
||||
The `tags` command is used to manage tags.
|
||||
|
||||
```bash
|
||||
cmm tags (-l | -p PREFIX | -c CONTENT)
|
||||
```
|
||||
|
||||
* `-l, --list` : List all tags and their frequency
|
||||
* `-p, --prefix PREFIX` : Filter tags by prefix
|
||||
* `-c, --contain CONTENT` : Filter tags by contained substring
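As an illustration of what "tags and their frequency" could mean, here is a small sketch that counts tags over a list of messages, optionally restricted by prefix or substring. The function name `tag_frequency` and the representation of a message as a plain set of tag strings are assumptions made for this example only.

```python
from collections import Counter
from typing import Optional


def tag_frequency(messages: list[set[str]],
                  prefix: Optional[str] = None,
                  contain: Optional[str] = None) -> dict[str, int]:
    """Hypothetical helper: count how many messages carry each tag."""
    counts = Counter()
    for tags in messages:
        for tag in tags:
            # skip tags that do not start with the requested prefix
            if prefix and not tag.startswith(prefix):
                continue
            # skip tags that do not contain the requested substring
            if contain and contain not in tag:
                continue
            counts[tag] += 1
    # return the tags in alphabetical order
    return dict(sorted(counts.items()))


history = [{"python", "testing"}, {"python", "packaging"}, {"shell"}]
print(tag_frequency(history))
print(tag_frequency(history, prefix="p"))
print(tag_frequency(history, contain="ing"))
```

The command line treats `-l`, `-p` and `-c` as mutually exclusive; the sketch accepts both filters at once purely for brevity.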
|
||||
|
||||
#### Config
|
||||
|
||||
The `config` command is used to manage the configuration.
|
||||
|
||||
```bash
|
||||
cmm config (-l | -m | -c CREATE)
|
||||
```
|
||||
|
||||
* `-l, --list-models` : List all available models
|
||||
* `-m, --print-model` : Print the currently configured model
|
||||
* `-c, --create CREATE` : Create config with default settings in the given file
|
||||
|
||||
#### Print
|
||||
|
||||
The `print` command is used to print message files.
|
||||
|
||||
```bash
|
||||
cmm print -f FILE [-q | -a | -S]
|
||||
```
|
||||
|
||||
* `-f, --file FILE` : File to print
|
||||
* `-q, --question` : Print only question
|
||||
* `-a, --answer` : Print only answer
|
||||
* `-S, --only-source-code` : Print only source code
|
||||
|
||||
### Examples
|
||||
|
||||
1. Print the contents of a YAML file:
|
||||
1. Ask a question:
|
||||
|
||||
```bash
|
||||
cmm -p example.yaml
|
||||
cmm question -a "What is the meaning of life?" -t philosophy -x religion
|
||||
```
|
||||
|
||||
2. Ask a question:
|
||||
2. Display the chat history:
|
||||
|
||||
```bash
|
||||
cmm -q "What is the meaning of life?" -t philosophy -e religion
|
||||
cmm hist
|
||||
```
|
||||
|
||||
3. Display the chat history as a Python structure:
|
||||
3. Filter chat history by tags:
|
||||
|
||||
```bash
|
||||
cmm -D
|
||||
cmm hist --or-tags tag1 tag2
|
||||
```
|
||||
|
||||
4. Display the chat history as readable text:
|
||||
4. Exclude chat history by tags:
|
||||
|
||||
```bash
|
||||
cmm -d
|
||||
cmm hist --exclude-tags tag3 tag4
|
||||
```
|
||||
|
||||
5. Filter chat history by tags:
|
||||
5. List all tags and their frequency:
|
||||
|
||||
```bash
|
||||
cmm -d -t tag1 tag2
|
||||
cmm tags -l
|
||||
```
|
||||
|
||||
6. Exclude chat history by tags:
|
||||
6. Print the contents of a file:
|
||||
|
||||
```bash
|
||||
cmm -d -e tag3 tag4
|
||||
cmm print -f example.yaml
|
||||
```
|
||||
|
||||
## Configuration
|
||||
@ -113,6 +183,45 @@ eval "$(register-python-argcomplete cmm)"
|
||||
|
||||
After adding this line, restart your shell or run `source <your-shell-config-file>` to enable autocompletion for the `cmm` script.
|
||||
|
||||
## Contributing
|
||||
|
||||
### Enable commit hooks
|
||||
```
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
```
|
||||
### Execute tests before opening a PR
|
||||
```
|
||||
pytest
|
||||
```
|
||||
### Consider using `pyenv` / `pyenv-virtualenv`
|
||||
Short installation instructions:
|
||||
* install `pyenv`:
|
||||
```
|
||||
cd ~
|
||||
git clone https://github.com/pyenv/pyenv .pyenv
|
||||
cd ~/.pyenv && src/configure && make -C src
|
||||
```
|
||||
* make sure that `~/.pyenv/shims` and `~/.pyenv/bin` are the first entries in your `PATH`, e. g. by setting it in `~/.bashrc`
|
||||
* add the following to your `~/.bashrc` (after setting `PATH`): `eval "$(pyenv init -)"`
|
||||
* create a new terminal or source the changes (e. g. `source ~/.bashrc`)
|
||||
* install `virtualenv`
|
||||
```
|
||||
git clone https://github.com/pyenv/pyenv-virtualenv.git $(pyenv root)/plugins/pyenv-virtualenv
|
||||
```
|
||||
* add the following to your `~/.bashrc` (after the commands above): `eval "$(pyenv virtualenv-init -)"`
|
||||
* create a new terminal or source the changes (e. g. `source ~/.bashrc`)
|
||||
* go back to the `ChatMastermind` repo and create a virtual environment with the latest `Python`, e. g. `3.11.4`:
|
||||
```
|
||||
cd <CMM_REPO_PATH>
|
||||
pyenv install 3.11.4
|
||||
pyenv virtualenv 3.11.4 py311
|
||||
pyenv activate py311
|
||||
```
|
||||
* see also the [official pyenv documentation](https://github.com/pyenv/pyenv#readme)
|
||||
|
||||
## License
|
||||
|
||||
This project is licensed under the terms of the WTFPL License.
|
||||
|
||||
|
||||
|
||||
80
chatmastermind/ai.py
Normal file
@ -0,0 +1,80 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol, Optional, Union
|
||||
from .configuration import AIConfig
|
||||
from .tags import Tag
|
||||
from .message import Message
|
||||
from .chat import Chat
|
||||
|
||||
|
||||
class AIError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tokens:
|
||||
prompt: int = 0
|
||||
completion: int = 0
|
||||
total: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIResponse:
|
||||
"""
|
||||
The response to an AI request. Consists of one or more messages
|
||||
(each containing the question and a single answer) and the nr.
|
||||
of used tokens.
|
||||
"""
|
||||
messages: list[Message]
|
||||
tokens: Optional[Tokens] = None
|
||||
|
||||
|
||||
class AI(Protocol):
|
||||
"""
|
||||
The base class for AI clients.
|
||||
"""
|
||||
|
||||
ID: str
|
||||
name: str
|
||||
config: AIConfig
|
||||
|
||||
def request(self,
|
||||
question: Message,
|
||||
chat: Chat,
|
||||
num_answers: int = 1,
|
||||
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||
"""
|
||||
Make an AI request. Parameters:
|
||||
* question: the question to ask
|
||||
* chat: the chat history to be added as context
|
||||
* num_answers: nr. of requested answers (corresponds
|
||||
to the nr. of messages in the 'AIResponse')
|
||||
* otags: the output tags, i. e. the tags that all
|
||||
returned messages should contain
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def models(self) -> list[str]:
|
||||
"""
|
||||
Return all models supported by this AI.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def print_models(self) -> None:
|
||||
"""
|
||||
Print all models supported by this AI.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||
"""
|
||||
Computes the nr. of AI language tokens for the given message
|
||||
or chat. Note that the computation may not be 100% accurate
|
||||
and is not implemented for all AIs.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def print(self) -> None:
|
||||
"""
|
||||
Print some info about the current AI, such as the system message.
|
||||
"""
|
||||
pass
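Because `AI` is declared as a `typing.Protocol`, a client does not strictly need to inherit from it; any class with matching attributes and methods satisfies it for a static type checker. The sketch below shows this with a reduced version of the protocol. `EchoAI`, `ask` and the simplified `Message` are hypothetical stand-ins written for this example and are not part of the diff.

```python
from dataclasses import dataclass
from typing import Optional, Protocol


@dataclass
class Message:
    """Simplified stand-in for chatmastermind.message.Message."""
    question: str
    answer: Optional[str] = None
    tags: Optional[set[str]] = None


@dataclass
class Tokens:
    prompt: int = 0
    completion: int = 0
    total: int = 0


@dataclass
class AIResponse:
    messages: list[Message]
    tokens: Optional[Tokens] = None


class AI(Protocol):
    """Reduced version of the protocol (no chat history, no config)."""
    ID: str
    name: str

    def request(self, question: Message, num_answers: int = 1,
                otags: Optional[set[str]] = None) -> AIResponse:
        ...


class EchoAI:
    """Hypothetical offline client: answers by echoing the question."""
    ID = "echo"
    name = "echo"

    def request(self, question: Message, num_answers: int = 1,
                otags: Optional[set[str]] = None) -> AIResponse:
        # one message per requested answer, each tagged with the output tags
        messages = [Message(question.question,
                            f"echo {i}: {question.question}",
                            otags)
                    for i in range(num_answers)]
        return AIResponse(messages, Tokens())


def ask(ai: AI, text: str, n: int = 1) -> list[Optional[str]]:
    """Works with any object that structurally matches 'AI'."""
    response = ai.request(Message(text), num_answers=n)
    return [m.answer for m in response.messages]


print(ask(EchoAI(), "hello", n=2))
```

A dummy client of this kind may be convenient in tests, since it needs neither network access nor an API key.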
|
||||
43
chatmastermind/ai_factory.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""
|
||||
Creates different AI instances, based on the given configuration.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from typing import cast
|
||||
from .configuration import Config, AIConfig, OpenAIConfig
|
||||
from .ai import AI, AIError
|
||||
from .ais.openai import OpenAI
|
||||
|
||||
|
||||
def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11
|
||||
"""
|
||||
Creates an AI subclass instance from the given arguments
|
||||
and configuration file. If AI has not been set in the
|
||||
arguments, it searches for the ID 'default'. If that
|
||||
is not found, it uses the first AI in the list.
|
||||
"""
|
||||
ai_conf: AIConfig
|
||||
if hasattr(args, 'AI') and args.AI:
|
||||
try:
|
||||
ai_conf = config.ais[args.AI]
|
||||
except KeyError:
|
||||
raise AIError(f"AI ID '{args.AI}' does not exist in this configuration")
|
||||
elif 'default' in config.ais:
|
||||
ai_conf = config.ais['default']
|
||||
else:
|
||||
try:
|
||||
ai_conf = next(iter(config.ais.values()))
|
||||
except StopIteration:
|
||||
raise AIError("No AI found in this configuration")
|
||||
|
||||
if ai_conf.name == 'openai':
|
||||
ai = OpenAI(cast(OpenAIConfig, ai_conf))
|
||||
if hasattr(args, 'model') and args.model:
|
||||
ai.config.model = args.model
|
||||
if hasattr(args, 'max_tokens') and args.max_tokens:
|
||||
ai.config.max_tokens = args.max_tokens
|
||||
if hasattr(args, 'temperature') and args.temperature:
|
||||
ai.config.temperature = args.temperature
|
||||
return ai
|
||||
else:
|
||||
raise AIError(f"AI '{ai_conf.name}' is not supported")
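The lookup order described in the `create_ai` docstring (explicit ID, then `'default'`, then the first configured entry) can be isolated into a few lines. This is a sketch under simplifying assumptions: `select_ai_id` is a hypothetical helper, and the configuration is reduced to a plain dictionary keyed by AI ID instead of the project's `Config` object.

```python
from typing import Optional


class AIError(Exception):
    pass


def select_ai_id(ais: dict[str, dict], requested: Optional[str] = None) -> str:
    """Hypothetical helper: pick the AI ID using the fallback order above."""
    # 1. an explicitly requested ID must exist in the configuration
    if requested:
        if requested not in ais:
            raise AIError(f"AI ID '{requested}' does not exist in this configuration")
        return requested
    # 2. otherwise prefer the entry called 'default'
    if 'default' in ais:
        return 'default'
    # 3. otherwise fall back to the first configured entry
    try:
        return next(iter(ais))
    except StopIteration:
        raise AIError("No AI found in this configuration") from None


config = {'work': {'name': 'openai'}, 'default': {'name': 'openai'}}
print(select_ai_id(config, 'work'))           # work
print(select_ai_id(config))                   # default
print(select_ai_id({'work': {}}))             # work
```

The "first entry" fallback relies on dictionaries preserving insertion order, which Python guarantees since version 3.7.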
|
||||
0
chatmastermind/ais/__init__.py
Normal file
111
chatmastermind/ais/openai.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""
|
||||
Implements the OpenAI client classes and functions.
|
||||
"""
|
||||
import openai
|
||||
from typing import Optional, Union
|
||||
from ..tags import Tag
|
||||
from ..message import Message, Answer
|
||||
from ..chat import Chat
|
||||
from ..ai import AI, AIResponse, Tokens
|
||||
from ..configuration import OpenAIConfig
|
||||
|
||||
ChatType = list[dict[str, str]]
|
||||
|
||||
|
||||
class OpenAI(AI):
|
||||
"""
|
||||
The OpenAI AI client.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OpenAIConfig) -> None:
|
||||
self.ID = config.ID
|
||||
self.name = config.name
|
||||
self.config = config
|
||||
openai.api_key = config.api_key
|
||||
|
||||
def request(self,
|
||||
question: Message,
|
||||
chat: Chat,
|
||||
num_answers: int = 1,
|
||||
otags: Optional[set[Tag]] = None) -> AIResponse:
|
||||
"""
|
||||
Make an AI request, asking the given question with the given
|
||||
chat history. The nr. of requested answers corresponds to the
|
||||
nr. of messages in the 'AIResponse'.
|
||||
"""
|
||||
oai_chat = self.openai_chat(chat, self.config.system, question)
|
||||
response = openai.ChatCompletion.create(
|
||||
model=self.config.model,
|
||||
messages=oai_chat,
|
||||
temperature=self.config.temperature,
|
||||
max_tokens=self.config.max_tokens,
|
||||
top_p=self.config.top_p,
|
||||
n=num_answers,
|
||||
frequency_penalty=self.config.frequency_penalty,
|
||||
presence_penalty=self.config.presence_penalty)
|
||||
question.answer = Answer(response['choices'][0]['message']['content'])
|
||||
question.tags = otags
|
||||
question.ai = self.ID
|
||||
question.model = self.config.model
|
||||
answers: list[Message] = [question]
|
||||
for choice in response['choices'][1:]: # type: ignore
|
||||
answers.append(Message(question=question.question,
|
||||
answer=Answer(choice['message']['content']),
|
||||
tags=otags,
|
||||
ai=self.ID,
|
||||
model=self.config.model))
|
||||
return AIResponse(answers, Tokens(response['usage']['prompt_tokens'],
|
||||
response['usage']['completion_tokens'],
|
||||
response['usage']['total_tokens']))
|
||||
|
||||
def models(self) -> list[str]:
|
||||
"""
|
||||
Return all models supported by this AI.
|
||||
"""
|
||||
ret = []
|
||||
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
|
||||
if engine['ready']:
|
||||
ret.append(engine['id'])
|
||||
ret.sort()
|
||||
return ret
|
||||
|
||||
def print_models(self) -> None:
|
||||
"""
|
||||
Print all models supported by the current AI.
|
||||
"""
|
||||
not_ready = []
|
||||
for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']):
|
||||
if engine['ready']:
|
||||
print(engine['id'])
|
||||
else:
|
||||
not_ready.append(engine['id'])
|
||||
if len(not_ready) > 0:
|
||||
print('\nNot ready: ' + ', '.join(not_ready))
|
||||
|
||||
def openai_chat(self, chat: Chat, system: str,
|
||||
question: Optional[Message] = None) -> ChatType:
|
||||
"""
|
||||
Create a chat history with system message in OpenAI format.
|
||||
Optionally append a new question.
|
||||
"""
|
||||
oai_chat: ChatType = []
|
||||
|
||||
def append(role: str, content: str) -> None:
|
||||
oai_chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||
|
||||
append('system', system)
|
||||
for message in chat.messages:
|
||||
if message.answer:
|
||||
append('user', message.question)
|
||||
append('assistant', message.answer)
|
||||
if question:
|
||||
append('user', question.question)
|
||||
return oai_chat
|
||||
|
||||
def tokens(self, data: Union[Message, Chat]) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def print(self) -> None:
|
||||
print(f"MODEL: {self.config.model}")
|
||||
print("=== SYSTEM ===")
|
||||
print(self.config.system)
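To make the format produced by `openai_chat` concrete, the sketch below rebuilds the same role/content list outside the class. `to_openai_chat` and the two-field `Message` are simplified stand-ins invented for this example; the quote replacement done by the original `append` helper is left out.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Message:
    """Simplified stand-in for chatmastermind.message.Message."""
    question: str
    answer: Optional[str] = None


def to_openai_chat(history: list[Message], system: str,
                   question: Optional[Message] = None) -> list[dict[str, str]]:
    """Hypothetical helper: build the message list for a chat completion."""
    # the system message always comes first
    chat = [{'role': 'system', 'content': system}]
    for message in history:
        # messages without an answer are skipped, as in the method above
        if message.answer:
            chat.append({'role': 'user', 'content': message.question})
            chat.append({'role': 'assistant', 'content': message.answer})
    # the new question, if any, is appended last
    if question:
        chat.append({'role': 'user', 'content': question.question})
    return chat


history = [Message("What is YAML?", "A data serialization format."),
           Message("Unanswered question")]
for entry in to_openai_chat(history, "You are helpful.", Message("And JSON?")):
    print(entry)
```

Each answered message contributes a user/assistant pair, so the resulting list has one system entry, two entries per answered message, and one final user entry when a new question is given.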
|
||||
@ -1,24 +0,0 @@
|
||||
import openai
|
||||
|
||||
|
||||
def openai_api_key(api_key: str) -> None:
|
||||
openai.api_key = api_key
|
||||
|
||||
|
||||
def ai(chat: list[dict[str, str]],
|
||||
config: dict,
|
||||
number: int
|
||||
) -> tuple[list[str], dict[str, int]]:
|
||||
response = openai.ChatCompletion.create(
|
||||
model=config['openai']['model'],
|
||||
messages=chat,
|
||||
temperature=config['openai']['temperature'],
|
||||
max_tokens=config['openai']['max_tokens'],
|
||||
top_p=config['openai']['top_p'],
|
||||
n=number,
|
||||
frequency_penalty=config['openai']['frequency_penalty'],
|
||||
presence_penalty=config['openai']['presence_penalty'])
|
||||
result = []
|
||||
for choice in response['choices']: # type: ignore
|
||||
result.append(choice['message']['content'].strip())
|
||||
return result, dict(response['usage']) # type: ignore
|
||||
406
chatmastermind/chat.py
Normal file
@ -0,0 +1,406 @@
|
||||
"""
|
||||
Module implementing various chat classes and functions for managing a chat history.
|
||||
"""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from pprint import PrettyPrinter
|
||||
from pydoc import pager
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Type, Optional, ClassVar, Any, Callable
|
||||
from .message import Message, MessageFilter, MessageError, message_in
|
||||
from .tags import Tag
|
||||
|
||||
ChatInst = TypeVar('ChatInst', bound='Chat')
|
||||
ChatDBInst = TypeVar('ChatDBInst', bound='ChatDB')
|
||||
|
||||
|
||||
class ChatError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def terminal_width() -> int:
|
||||
return shutil.get_terminal_size().columns
|
||||
|
||||
|
||||
def pp(*args: Any, **kwargs: Any) -> None:
|
||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def print_paged(text: str) -> None:
|
||||
pager(text)
|
||||
|
||||
|
||||
def read_dir(dir_path: Path,
|
||||
glob: Optional[str] = None,
|
||||
mfilter: Optional[MessageFilter] = None) -> list[Message]:
|
||||
"""
|
||||
Reads the messages from the given folder.
|
||||
Parameters:
|
||||
* 'dir_path': source directory
|
||||
* 'glob': if specified, files will be filtered using 'path.glob()',
|
||||
otherwise it uses 'path.iterdir()'.
|
||||
* 'mfilter': use with 'Message.from_file()' to filter messages
|
||||
when reading them.
|
||||
"""
|
||||
messages: list[Message] = []
|
||||
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
||||
for file_path in sorted(file_iter):
|
||||
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
|
||||
try:
|
||||
message = Message.from_file(file_path, mfilter)
|
||||
if message:
|
||||
messages.append(message)
|
||||
except MessageError as e:
|
||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||
return messages
|
||||
|
||||
|
||||
def make_file_path(dir_path: Path,
|
||||
file_suffix: str,
|
||||
next_fid: Callable[[], int]) -> Path:
|
||||
"""
|
||||
Create a file_path for the given directory using the
|
||||
given file_suffix and ID generator function.
|
||||
"""
|
||||
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
|
||||
while file_path.exists():
|
||||
file_path = dir_path / f"{next_fid():04d}{file_suffix}"
|
||||
return file_path
|
||||
|
||||
|
||||
def write_dir(dir_path: Path,
|
||||
messages: list[Message],
|
||||
file_suffix: str,
|
||||
next_fid: Callable[[], int]) -> None:
|
||||
"""
|
||||
Write all messages to the given directory. If a message has no file_path,
|
||||
a new one will be created. If message.file_path exists, it will be modified
|
||||
to point to the given directory.
|
||||
Parameters:
|
||||
* 'dir_path': destination directory
|
||||
* 'messages': list of messages to write
|
||||
* 'file_suffix': suffix for the message files ['.txt'|'.yaml']
|
||||
* 'next_fid': callable that returns the next file ID
|
||||
"""
|
||||
for message in messages:
|
||||
file_path = message.file_path
|
||||
# message has no file_path: create one
|
||||
if not file_path:
|
||||
file_path = make_file_path(dir_path, file_suffix, next_fid)
|
||||
# file_path does not point to given directory: modify it
|
||||
elif not file_path.parent.samefile(dir_path):
|
||||
file_path = dir_path / file_path.name
|
||||
message.to_file(file_path)
|
||||
|
||||
|
||||
def clear_dir(dir_path: Path,
|
||||
glob: Optional[str] = None) -> None:
|
||||
"""
|
||||
Deletes all Message files in the given directory.
|
||||
"""
|
||||
file_iter = dir_path.glob(glob) if glob else dir_path.iterdir()
|
||||
for file_path in file_iter:
|
||||
if file_path.is_file() and file_path.suffix in Message.file_suffixes:
|
||||
file_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chat:
|
||||
"""
|
||||
A class containing a complete chat history.
|
||||
"""
|
||||
|
||||
messages: list[Message]
|
||||
|
||||
def filter(self, mfilter: MessageFilter) -> None:
|
||||
"""
|
||||
Use 'Message.match(mfilter)' to remove all messages that
|
||||
don't fulfill the filter requirements.
|
||||
"""
|
||||
self.messages = [m for m in self.messages if m.match(mfilter)]
|
||||
|
||||
def sort(self, reverse: bool = False) -> None:
|
||||
"""
|
||||
Sort the messages according to 'Message.msg_id()'.
|
||||
"""
|
||||
try:
|
||||
# the message may not have an ID if it doesn't have a file_path
|
||||
self.messages.sort(key=lambda m: m.msg_id(), reverse=reverse)
|
||||
except MessageError:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Delete all messages.
|
||||
"""
|
||||
self.messages = []
|
||||
|
||||
def add_messages(self, messages: list[Message]) -> None:
|
||||
"""
|
||||
Add new messages and sort them if possible.
|
||||
"""
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
def latest_message(self) -> Optional[Message]:
|
||||
"""
|
||||
Returns the last added message (according to the file ID).
|
||||
"""
|
||||
if len(self.messages) > 0:
|
||||
self.sort()
|
||||
return self.messages[-1]
|
||||
else:
|
||||
return None
|
||||
|
||||
def find_messages(self, msg_names: list[str]) -> list[Message]:
|
||||
"""
|
||||
Search and return the messages with the given names. Names can either be filenames
|
||||
(incl. suffixes) or full paths. Messages that can't be found are ignored (i. e. the
|
||||
caller should check the result if they require all messages).
|
||||
"""
|
||||
return [m for m in self.messages
|
||||
if any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
|
||||
|
||||
def remove_messages(self, msg_names: list[str]) -> None:
|
||||
"""
|
||||
Remove the messages with the given names. Names can either be filenames
|
||||
(incl. the suffix) or full paths.
|
||||
"""
|
||||
self.messages = [m for m in self.messages
|
||||
if not any((m.file_path and (m.file_path == Path(mn) or m.file_path.name == mn)) for mn in msg_names)]
|
||||
self.sort()
|
||||
|
||||
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
"""
|
||||
Get the tags of all messages, optionally filtered by prefix or substring.
|
||||
"""
|
||||
tags: set[Tag] = set()
|
||||
for m in self.messages:
|
||||
tags |= m.filter_tags(prefix, contain)
|
||||
return set(sorted(tags))
|
||||
|
||||
def tags_frequency(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> dict[Tag, int]:
|
||||
"""
|
||||
Get the frequency of all tags of all messages, optionally filtered by prefix or substring.
|
||||
"""
|
||||
tags: list[Tag] = []
|
||||
for m in self.messages:
|
||||
tags += [tag for tag in m.filter_tags(prefix, contain)]
|
||||
return {tag: tags.count(tag) for tag in sorted(tags)}
|
||||
|
||||
def tokens(self) -> int:
|
||||
"""
|
||||
Returns the nr. of AI language tokens used by all messages in this chat.
|
||||
If unknown, 0 is returned.
|
||||
"""
|
||||
return sum(m.tokens() for m in self.messages)
|
||||
|
||||
def print(self, source_code_only: bool = False,
|
||||
with_tags: bool = False, with_files: bool = False,
|
||||
paged: bool = True) -> None:
|
||||
output: list[str] = []
|
||||
for message in self.messages:
|
||||
if source_code_only:
|
||||
output.append(message.to_str(source_code_only=True))
|
||||
continue
|
||||
output.append(message.to_str(with_tags, with_files))
|
||||
if paged:
|
||||
print_paged('\n'.join(output))
|
||||
else:
|
||||
print(*output, sep='\n')
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatDB(Chat):
|
||||
"""
|
||||
A 'Chat' class that is bound to a given directory structure. Supports reading
|
||||
and writing messages from / to that structure. Such a structure consists of
|
||||
two directories: a 'cache directory', where all messages are temporarily
|
||||
stored, and a 'DB' directory, where selected messages can be stored
|
||||
persistently.
|
||||
"""
|
||||
|
||||
default_file_suffix: ClassVar[str] = '.txt'
|
||||
|
||||
cache_path: Path
|
||||
db_path: Path
|
||||
# a MessageFilter that all messages must match (if given)
|
||||
mfilter: Optional[MessageFilter] = None
|
||||
file_suffix: str = default_file_suffix
|
||||
# the glob pattern for all messages
|
||||
glob: Optional[str] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# contains the latest message ID
|
||||
self.next_fname = self.db_path / '.next'
|
||||
# make all paths absolute
|
||||
self.cache_path = self.cache_path.absolute()
|
||||
self.db_path = self.db_path.absolute()
|
||||
|
||||
@classmethod
|
||||
def from_dir(cls: Type[ChatDBInst],
|
||||
cache_path: Path,
|
||||
db_path: Path,
|
||||
glob: Optional[str] = None,
|
||||
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
||||
"""
|
||||
Create a 'ChatDB' instance from the given directory structure.
|
||||
Reads all messages from 'db_path' into the local message list.
|
||||
Parameters:
|
||||
* 'cache_path': path to the directory for temporary messages
|
||||
* 'db_path': path to the directory for persistent messages
|
||||
* 'glob': if specified, files will be filtered using 'path.glob()',
|
||||
otherwise it uses 'path.iterdir()'.
|
||||
* 'mfilter': use with 'Message.from_file()' to filter messages
|
||||
when reading them.
|
||||
"""
|
||||
messages = read_dir(db_path, glob, mfilter)
|
||||
return cls(messages, cache_path, db_path, mfilter,
|
||||
cls.default_file_suffix, glob)
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls: Type[ChatDBInst],
|
||||
cache_path: Path,
|
||||
db_path: Path,
|
||||
messages: list[Message],
|
||||
mfilter: Optional[MessageFilter] = None) -> ChatDBInst:
|
||||
"""
|
||||
Create a ChatDB instance from the given message list.
|
||||
"""
|
||||
return cls(messages, cache_path, db_path, mfilter)
|
||||
|
||||
def get_next_fid(self) -> int:
|
||||
try:
|
||||
with open(self.next_fname, 'r') as f:
|
||||
next_fid = int(f.read()) + 1
|
||||
self.set_next_fid(next_fid)
|
||||
return next_fid
|
||||
except Exception:
|
||||
self.set_next_fid(1)
|
||||
return 1
|
||||
|
||||
def set_next_fid(self, fid: int) -> None:
|
||||
with open(self.next_fname, 'w') as f:
|
||||
f.write(f'{fid}')
|
||||
|
||||
def read_db(self) -> None:
|
||||
"""
|
||||
Reads new messages from the DB directory. New ones are added to the internal list,
|
||||
existing ones are replaced. A message is determined as 'existing' if a message with
|
||||
the same base filename (i. e. 'file_path.name') is already in the list.
|
||||
"""
|
||||
new_messages = read_dir(self.db_path, self.glob, self.mfilter)
|
||||
# remove all messages from self.messages that are in the new list
|
||||
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
|
||||
# copy the messages from the temporary list to self.messages and sort them
|
||||
self.messages += new_messages
|
||||
self.sort()
|
||||
|
||||
def read_cache(self) -> None:
|
||||
"""
|
||||
Reads new messages from the cache directory. New ones are added to the internal list,
|
||||
existing ones are replaced. A message is determined as 'existing' if a message with
|
||||
the same base filename (i.e. 'file_path.name') is already in the list.
|
||||
"""
|
||||
new_messages = read_dir(self.cache_path, self.glob, self.mfilter)
|
||||
# remove all messages from self.messages that are in the new list
|
||||
self.messages = [m for m in self.messages if not message_in(m, new_messages)]
|
||||
# copy the messages from the temporary list to self.messages and sort them
|
||||
self.messages += new_messages
|
||||
self.sort()
|
||||
|
||||
def write_db(self, messages: Optional[list[Message]] = None) -> None:
|
||||
"""
|
||||
Write messages to the DB directory. If a message has no file_path, a new one
|
||||
will be created. If message.file_path exists, it will be modified to point
|
||||
to the DB directory.
|
||||
"""
|
||||
write_dir(self.db_path,
|
||||
messages if messages else self.messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
|
||||
def write_cache(self, messages: Optional[list[Message]] = None) -> None:
|
||||
"""
|
||||
Write messages to the cache directory. If a message has no file_path, a new one
|
||||
will be created. If message.file_path exists, it will be modified to point to
|
||||
the cache directory.
|
||||
"""
|
||||
write_dir(self.cache_path,
|
||||
messages if messages else self.messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""
|
||||
Deletes all Message files from the cache dir and removes those messages from
|
||||
the internal list.
|
||||
"""
|
||||
clear_dir(self.cache_path, self.glob)
|
||||
# only keep messages from DB dir (or those that have not yet been written)
|
||||
self.messages = [m for m in self.messages if not m.file_path or m.file_path.parent.samefile(self.db_path)]
|
||||
|
||||
def add_to_db(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
Add the given new messages and set the file_path to the DB directory.
|
||||
Only accepts messages without a file_path.
|
||||
"""
|
||||
if any(m.file_path is not None for m in messages):
|
||||
raise ChatError("Can't add new messages with existing file_path")
|
||||
if write:
|
||||
write_dir(self.db_path,
|
||||
messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.db_path, self.default_file_suffix, self.get_next_fid)
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
def add_to_cache(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
Add the given new messages and set the file_path to the cache directory.
|
||||
Only accepts messages without a file_path.
|
||||
"""
|
||||
if any(m.file_path is not None for m in messages):
|
||||
raise ChatError("Can't add new messages with existing file_path")
|
||||
if write:
|
||||
write_dir(self.cache_path,
|
||||
messages,
|
||||
self.file_suffix,
|
||||
self.get_next_fid)
|
||||
else:
|
||||
for m in messages:
|
||||
m.file_path = make_file_path(self.cache_path, self.default_file_suffix, self.get_next_fid)
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
|
||||
def write_messages(self, messages: Optional[list[Message]] = None) -> None:
|
||||
"""
|
||||
Write either the given messages or the internal ones to their current file_path.
|
||||
If messages are given, they all must have a valid file_path. When writing the
|
||||
internal messages, the ones with a valid file_path are written, the others
|
||||
are ignored.
|
||||
"""
|
||||
if messages and any(m.file_path is None for m in messages):
|
||||
raise ChatError("Can't write files without a valid file_path")
|
||||
msgs = iter(messages if messages else self.messages)
|
||||
while (m := next(msgs, None)):
|
||||
m.to_file()
|
||||
|
||||
def update_messages(self, messages: list[Message], write: bool = True) -> None:
|
||||
"""
|
||||
Update existing messages. A message is considered 'existing' if a message with
|
||||
the same base filename (i.e. 'file_path.name') is already in the list. Only accepts
|
||||
existing messages.
|
||||
"""
|
||||
if any(not message_in(m, self.messages) for m in messages):
|
||||
raise ChatError("Can't update messages that are not in the internal list")
|
||||
# remove old versions and add new ones
|
||||
self.messages = [m for m in self.messages if not message_in(m, messages)]
|
||||
self.messages += messages
|
||||
self.sort()
|
||||
# write the UPDATED messages if requested
|
||||
if write:
|
||||
self.write_messages(messages)
|
||||
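The replace-by-filename behaviour shared by read_db(), read_cache() and update_messages() can be sketched in isolation. This is a minimal illustration under stated assumptions, not the project's code: `Msg` and `merge` are hypothetical stand-ins, and only `message_in` mirrors the helper of the same name shown in this diff.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Optional


@dataclass
class Msg:
    """Hypothetical stand-in for Message: only the fields needed here."""
    text: str
    file_path: Optional[Path] = None


def message_in(message: Msg, messages: Iterable[Msg]) -> bool:
    """Return True if any entry of 'messages' has the same base filename."""
    if not message.file_path:
        return False
    return any(m.file_path is not None and m.file_path.name == message.file_path.name
               for m in messages)


def merge(current: list[Msg], new: list[Msg]) -> list[Msg]:
    """Drop entries of 'current' superseded by 'new', then append 'new'."""
    kept = [m for m in current if not message_in(m, new)]
    return kept + new


# The directory differs, but the base filename matches, so the old entry is replaced.
current = [Msg('old answer', Path('db/0001.txt')), Msg('unsaved')]
new = [Msg('new answer', Path('cache/0001.txt'))]
merged = merge(current, new)
```

Messages without a file_path never match, so entries that have not been written yet survive a merge.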
20
chatmastermind/commands/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from ..configuration import Config
|
||||
from ..ai import AI
|
||||
from ..ai_factory import create_ai
|
||||
|
||||
|
||||
def config_cmd(args: argparse.Namespace) -> None:
|
||||
"""
|
||||
Handler for the 'config' command.
|
||||
"""
|
||||
if args.create:
|
||||
Config.create_default(Path(args.create))
|
||||
elif args.list_models or args.print_model:
|
||||
config: Config = Config.from_file(args.config)
|
||||
ai: AI = create_ai(args, config)
|
||||
if args.list_models:
|
||||
ai.print_models()
|
||||
else:
|
||||
print(ai.config.model)
|
||||
23
chatmastermind/commands/hist.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from ..configuration import Config
|
||||
from ..chat import ChatDB
|
||||
from ..message import MessageFilter
|
||||
|
||||
|
||||
def hist_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'hist' command.
|
||||
"""
|
||||
|
||||
mfilter = MessageFilter(tags_or=args.or_tags,
|
||||
tags_and=args.and_tags,
|
||||
tags_not=args.exclude_tags,
|
||||
question_contains=args.question,
|
||||
answer_contains=args.answer)
|
||||
chat = ChatDB.from_dir(Path('.'),
|
||||
Path(config.db),
|
||||
mfilter=mfilter)
|
||||
chat.print(args.source_code_only,
|
||||
args.with_tags,
|
||||
args.with_files)
|
||||
27
chatmastermind/commands/print.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import sys
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from ..configuration import Config
|
||||
from ..message import Message, MessageError
|
||||
|
||||
|
||||
def print_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'print' command.
|
||||
"""
|
||||
fname = Path(args.file)
|
||||
try:
|
||||
message = Message.from_file(fname)
|
||||
if message:
|
||||
if args.question:
|
||||
print(message.question)
|
||||
elif args.answer:
|
||||
print(message.answer)
|
||||
elif message.answer and args.only_source_code:
|
||||
for code in message.answer.source_code():
|
||||
print(code)
|
||||
else:
|
||||
print(message.to_str())
|
||||
except MessageError:
|
||||
print(f"File is not a valid message: {args.file}")
|
||||
sys.exit(1)
|
||||
122
chatmastermind/commands/question.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from itertools import zip_longest
|
||||
from ..configuration import Config
|
||||
from ..chat import ChatDB
|
||||
from ..message import Message, MessageFilter, MessageError, Question, source_code
|
||||
from ..ai_factory import create_ai
|
||||
from ..ai import AI, AIResponse
|
||||
|
||||
|
||||
def add_file_as_text(question_parts: list[str], file: str) -> None:
|
||||
"""
|
||||
Add the given file as plain text to the question part list.
|
||||
If the file is a Message, add the answer.
|
||||
"""
|
||||
file_path = Path(file)
|
||||
content: str = ''  # stays empty if the file is a message without an answer
|
||||
try:
|
||||
message = Message.from_file(file_path)
|
||||
if message and message.answer:
|
||||
content = message.answer
|
||||
except MessageError:
|
||||
with open(file) as r:
|
||||
content = r.read().strip()
|
||||
if len(content) > 0:
|
||||
question_parts.append(content)
|
||||
|
||||
|
||||
def add_file_as_code(question_parts: list[str], file: str) -> None:
|
||||
"""
|
||||
Add all source code from the given file. If no code segments can be extracted,
|
||||
the whole content is added as source code segment. If the file is a Message,
|
||||
extract the source code from the answer.
|
||||
"""
|
||||
file_path = Path(file)
|
||||
content: str = ''  # stays empty if the file is a message without an answer
|
||||
try:
|
||||
message = Message.from_file(file_path)
|
||||
if message and message.answer:
|
||||
content = message.answer
|
||||
except MessageError:
|
||||
with open(file) as r:
|
||||
content = r.read().strip()
|
||||
# extract and add source code
|
||||
code_parts = source_code(content, include_delims=True)
|
||||
if len(code_parts) > 0:
|
||||
question_parts += code_parts
|
||||
else:
|
||||
question_parts.append(f"```\n{content}\n```")
|
||||
|
||||
|
||||
def create_message(chat: ChatDB, args: argparse.Namespace) -> Message:
|
||||
"""
|
||||
Creates (and writes) a new message from the given arguments.
|
||||
"""
|
||||
question_parts = []
|
||||
question_list = args.ask if args.ask is not None else []
|
||||
text_files = args.source_text if args.source_text is not None else []
|
||||
code_files = args.source_code if args.source_code is not None else []
|
||||
|
||||
for question, text_file, code_file in zip_longest(question_list, text_files, code_files, fillvalue=None):
|
||||
if question is not None and len(question.strip()) > 0:
|
||||
question_parts.append(question)
|
||||
if text_file is not None and len(text_file) > 0:
|
||||
add_file_as_text(question_parts, text_file)
|
||||
if code_file is not None and len(code_file) > 0:
|
||||
add_file_as_code(question_parts, code_file)
|
||||
|
||||
full_question = '\n\n'.join(question_parts)
|
||||
|
||||
message = Message(question=Question(full_question),
|
||||
tags=args.output_tags, # FIXME
|
||||
ai=args.AI,
|
||||
model=args.model)
|
||||
chat.add_to_cache([message])
|
||||
return message
|
||||
|
||||
|
||||
def question_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'question' command.
|
||||
"""
|
||||
mfilter = MessageFilter(tags_or=args.or_tags if args.or_tags is not None else set(),
|
||||
tags_and=args.and_tags if args.and_tags is not None else set(),
|
||||
tags_not=args.exclude_tags if args.exclude_tags is not None else set())
|
||||
chat = ChatDB.from_dir(cache_path=Path('.'),
|
||||
db_path=Path(config.db),
|
||||
mfilter=mfilter)
|
||||
# if it's a new question, create and store it immediately
|
||||
if args.ask or args.create:
|
||||
message = create_message(chat, args)
|
||||
if args.create:
|
||||
return
|
||||
|
||||
# create the correct AI instance
|
||||
ai: AI = create_ai(args, config)
|
||||
if args.ask:
|
||||
ai.print()
|
||||
chat.print(paged=False)
|
||||
response: AIResponse = ai.request(message,
|
||||
chat,
|
||||
args.num_answers, # FIXME
|
||||
args.output_tags) # FIXME
|
||||
chat.update_messages([response.messages[0]])
|
||||
chat.add_to_cache(response.messages[1:])
|
||||
for idx, msg in enumerate(response.messages):
|
||||
print(f"=== ANSWER {idx+1} ===")
|
||||
print(msg.answer)
|
||||
if response.tokens:
|
||||
print("===============")
|
||||
print(response.tokens)
|
||||
elif args.repeat is not None:
|
||||
lmessage = chat.latest_message()
|
||||
assert lmessage
|
||||
# TODO: repeat either the last question or the
|
||||
# one(s) given in 'args.repeat' (overwrite
|
||||
# existing ones if 'args.overwrite' is True)
|
||||
pass
|
||||
elif args.process is not None:
|
||||
# TODO: process either all questions without an
|
||||
# answer or the one(s) given in 'args.process'
|
||||
pass
|
||||
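The way create_message() interleaves questions, text files and code files with zip_longest can be sketched as follows. This is an illustration only: `interleave` and `fence` are hypothetical names, and file contents are passed in as strings here instead of being read from disk or from a Message.

```python
from itertools import zip_longest

# Built at runtime so this example does not contain a literal code fence.
fence = '`' * 3


def interleave(questions: list[str], texts: list[str], codes: list[str]) -> str:
    """Join the i-th question, text and code of each list, in that order."""
    parts: list[str] = []
    # zip_longest pads the shorter lists with None, so no input is dropped.
    for question, text, code in zip_longest(questions, texts, codes, fillvalue=None):
        if question is not None and question.strip():
            parts.append(question)
        if text:
            parts.append(text)
        if code:
            parts.append(f"{fence}\n{code}\n{fence}")
    return '\n\n'.join(parts)


result = interleave(['Q1', 'Q2'], ['T1'], ['C1', 'C2', 'C3'])
```

With these inputs the parts appear in the order Q1, T1, C1, Q2, C2, C3, so each question stays next to the sources given at the same position on the command line.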
17
chatmastermind/commands/tags.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from ..configuration import Config
|
||||
from ..chat import ChatDB
|
||||
|
||||
|
||||
def tags_cmd(args: argparse.Namespace, config: Config) -> None:
|
||||
"""
|
||||
Handler for the 'tags' command.
|
||||
"""
|
||||
chat = ChatDB.from_dir(cache_path=Path('.'),
|
||||
db_path=Path(config.db))
|
||||
if args.list:
|
||||
tags_freq = chat.tags_frequency(args.prefix, args.contain)
|
||||
for tag, freq in tags_freq.items():
|
||||
print(f"- {tag}: {freq}")
|
||||
# TODO: add renaming
|
||||
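The 'tags' handler relies on ChatDB.tags_frequency(), which is not part of this diff. The sketch below shows one plausible way a prefix/substring filtered frequency count could work. The free function `tags_frequency` and its parameters are assumptions made for illustration, not the project's implementation.

```python
from collections import Counter
from typing import Iterable, Optional


def tags_frequency(tag_sets: Iterable[set[str]],
                   prefix: Optional[str] = None,
                   contain: Optional[str] = None) -> dict[str, int]:
    """Count how often each tag occurs, optionally filtered (hypothetical helper)."""
    counter: Counter[str] = Counter()
    for tags in tag_sets:
        for tag in tags:
            if prefix and not tag.startswith(prefix):
                continue
            if contain and contain not in tag:
                continue
            counter[tag] += 1
    # Sort by tag name so the listing is stable.
    return dict(sorted(counter.items()))


tag_sets = [{'python', 'pytest'}, {'python', 'rust'}]
by_prefix = tags_frequency(tag_sets, prefix='py')
by_substring = tags_frequency(tag_sets, contain='us')
```

The result can be printed the same way the handler does, one `- tag: freq` line per entry.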
167
chatmastermind/configuration.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar, Any, Optional, ClassVar
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
ConfigInst = TypeVar('ConfigInst', bound='Config')
|
||||
AIConfigInst = TypeVar('AIConfigInst', bound='AIConfig')
|
||||
OpenAIConfigInst = TypeVar('OpenAIConfigInst', bound='OpenAIConfig')
|
||||
|
||||
|
||||
supported_ais: list[str] = ['openai']
|
||||
default_config_path = '.config.yaml'
|
||||
|
||||
|
||||
class ConfigError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
|
||||
"""
|
||||
Changes the YAML dump style to multiline syntax for multiline strings.
|
||||
"""
|
||||
if len(data.splitlines()) > 1:
|
||||
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
|
||||
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
|
||||
|
||||
|
||||
yaml.add_representer(str, str_presenter)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AIConfig:
|
||||
"""
|
||||
The base class of all AI configurations.
|
||||
"""
|
||||
# the name of the AI the config class represents
|
||||
# -> it's a class variable and thus not part of the
|
||||
# dataclass constructor
|
||||
name: ClassVar[str]
|
||||
# a user-defined ID for an AI configuration entry
|
||||
ID: str
|
||||
model: str = 'n/a'
|
||||
|
||||
# the name must not be changed
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name == 'name':
|
||||
raise AttributeError(f"'{name}' is not allowed to be changed")
|
||||
else:
|
||||
super().__setattr__(name, value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIConfig(AIConfig):
|
||||
"""
|
||||
The OpenAI section of the configuration file.
|
||||
"""
|
||||
name: ClassVar[str] = 'openai'
|
||||
|
||||
# all members have default values, so we can easily create
|
||||
# a default configuration
|
||||
ID: str = 'myopenai'
|
||||
api_key: str = '0123456789'
|
||||
model: str = 'gpt-3.5-turbo-16k'
|
||||
temperature: float = 1.0
|
||||
max_tokens: int = 4000
|
||||
top_p: float = 1.0
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
system: str = 'You are an assistant'
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[OpenAIConfigInst], source: dict[str, Any]) -> OpenAIConfigInst:
|
||||
"""
|
||||
Create OpenAIConfig from a dict.
|
||||
"""
|
||||
res = cls(
|
||||
api_key=str(source['api_key']),
|
||||
model=str(source['model']),
|
||||
max_tokens=int(source['max_tokens']),
|
||||
temperature=float(source['temperature']),
|
||||
top_p=float(source['top_p']),
|
||||
frequency_penalty=float(source['frequency_penalty']),
|
||||
presence_penalty=float(source['presence_penalty']),
|
||||
system=str(source['system'])
|
||||
)
|
||||
# overwrite default ID if provided
|
||||
if 'ID' in source:
|
||||
res.ID = source['ID']
|
||||
return res
|
||||
|
||||
|
||||
def ai_config_instance(name: str, conf_dict: Optional[dict[str, Any]] = None) -> AIConfig:
|
||||
"""
|
||||
Creates an AIConfig instance of the given name.
|
||||
"""
|
||||
if name.lower() == 'openai':
|
||||
if conf_dict is None:
|
||||
return OpenAIConfig()
|
||||
else:
|
||||
return OpenAIConfig.from_dict(conf_dict)
|
||||
else:
|
||||
raise ConfigError(f"Unknown AI '{name}'")
|
||||
|
||||
|
||||
def create_default_ai_configs() -> dict[str, AIConfig]:
|
||||
"""
|
||||
Create a dict containing default configurations for all supported AIs.
|
||||
"""
|
||||
return {ai_config_instance(name).ID: ai_config_instance(name) for name in supported_ais}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
"""
|
||||
The configuration file structure.
|
||||
"""
|
||||
# all members have default values, so we can easily create
|
||||
# a default configuration
|
||||
db: str = './db/'
|
||||
ais: dict[str, AIConfig] = field(default_factory=create_default_ai_configs)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[ConfigInst], source: dict[str, Any]) -> ConfigInst:
|
||||
"""
|
||||
Create Config from a dict (with the same format as the config file).
|
||||
"""
|
||||
# create the correct AI type instances
|
||||
ais: dict[str, AIConfig] = {}
|
||||
for ID, conf in source['ais'].items():
|
||||
# add the AI ID to the config (for easy internal access)
|
||||
conf['ID'] = ID
|
||||
ai_conf = ai_config_instance(conf['name'], conf)
|
||||
ais[ID] = ai_conf
|
||||
return cls(
|
||||
db=str(source['db']),
|
||||
ais=ais
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_default(cls, file_path: Path) -> None:
|
||||
"""
|
||||
Creates a default Config in the given file.
|
||||
"""
|
||||
conf = Config()
|
||||
conf.to_file(file_path)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls: Type[ConfigInst], path: str) -> ConfigInst:
|
||||
with open(path, 'r') as f:
|
||||
source = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return cls.from_dict(source)
|
||||
|
||||
def to_file(self, file_path: Path) -> None:
|
||||
# remove the AI ID from the config (for a cleaner format)
|
||||
data = self.as_dict()
|
||||
for conf in data['ais'].values():
|
||||
del conf['ID']
|
||||
with open(file_path, 'w') as f:
|
||||
yaml.dump(data, f, sort_keys=False)
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
res = asdict(self)
|
||||
# add the AI name manually (as first element)
|
||||
# (not done by 'asdict' because it's a class variable)
|
||||
for ID, conf in res['ais'].items():
|
||||
res['ais'][ID] = {'name': self.ais[ID].name, **conf}
|
||||
return res
|
||||
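The reason as_dict() has to add 'name' by hand is that dataclasses skip ClassVar attributes when collecting fields. A minimal sketch, where `DemoConfig` is a hypothetical class modelled on AIConfig and not part of the project:

```python
from dataclasses import asdict, dataclass
from typing import ClassVar


@dataclass
class DemoConfig:
    # ClassVar attributes are not dataclass fields: they are neither
    # constructor parameters nor part of the asdict() output.
    name: ClassVar[str] = 'demo'
    ID: str = 'mydemo'
    model: str = 'n/a'


conf = DemoConfig()
plain = asdict(conf)
# Put 'name' first, then the regular fields in declaration order.
with_name = {'name': conf.name, **plain}
```

Because dicts preserve insertion order, dumping `with_name` with `sort_keys=False` would keep 'name' as the first key of each AI entry.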
@@ -2,116 +2,123 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# vim: set fileencoding=utf-8 :
|
||||
|
||||
import yaml
|
||||
import sys
|
||||
import argcomplete
|
||||
import argparse
|
||||
import pathlib
|
||||
from .utils import terminal_width, process_tags, display_chat, display_source_code, display_tags_frequency
|
||||
from .storage import save_answers, create_chat, get_tags, get_tags_unique, read_file, dump_data
|
||||
from .api_client import ai, openai_api_key
|
||||
from itertools import zip_longest
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from .configuration import Config, default_config_path
|
||||
from .message import Message
|
||||
from .commands.question import question_cmd
|
||||
from .commands.tags import tags_cmd
|
||||
from .commands.config import config_cmd
|
||||
from .commands.hist import hist_cmd
|
||||
from .commands.print import print_cmd
|
||||
|
||||
|
||||
def run_print_command(args: argparse.Namespace, config: dict) -> None:
|
||||
fname = pathlib.Path(args.print)
|
||||
if fname.suffix == '.yaml':
|
||||
with open(args.print, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
elif fname.suffix == '.txt':
|
||||
data = read_file(fname)
|
||||
else:
|
||||
print(f"Unknown file type: {args.print}")
|
||||
sys.exit(1)
|
||||
if args.only_source_code:
|
||||
display_source_code(data['answer'])
|
||||
else:
|
||||
print(dump_data(data).strip())
|
||||
|
||||
|
||||
def process_and_display_chat(args: argparse.Namespace,
|
||||
config: dict,
|
||||
dump: bool = False
|
||||
) -> tuple[list[dict[str, str]], str, list[str]]:
|
||||
tags = args.tags or []
|
||||
extags = args.extags or []
|
||||
otags = args.output_tags or []
|
||||
|
||||
if not args.only_source_code:
|
||||
process_tags(tags, extags, otags)
|
||||
|
||||
question_parts = []
|
||||
question_list = args.question if args.question is not None else []
|
||||
source_list = args.source if args.source is not None else []
|
||||
|
||||
for question, source in zip_longest(question_list, source_list, fillvalue=None):
|
||||
if question is not None and source is not None:
|
||||
with open(source) as r:
|
||||
question_parts.append(f"{question}\n\n```\n{r.read().strip()}\n```")
|
||||
elif question is not None:
|
||||
question_parts.append(question)
|
||||
elif source is not None:
|
||||
with open(source) as r:
|
||||
question_parts.append(f"```\n{r.read().strip()}\n```")
|
||||
|
||||
full_question = '\n\n'.join(question_parts)
|
||||
chat = create_chat(full_question, tags, extags, config,
|
||||
args.match_all_tags, args.with_tags,
|
||||
args.with_file)
|
||||
display_chat(chat, dump, args.only_source_code)
|
||||
return chat, full_question, tags
|
||||
|
||||
def process_and_display_tags(args: argparse.Namespace,
|
||||
config: dict,
|
||||
dump: bool=False
|
||||
) -> None:
|
||||
display_tags_frequency(get_tags(config, None), dump)
|
||||
|
||||
|
||||
def handle_question(args: argparse.Namespace,
|
||||
config: dict,
|
||||
dump: bool = False
|
||||
) -> None:
|
||||
chat, question, tags = process_and_display_chat(args, config, dump)
|
||||
otags = args.output_tags or []
|
||||
answers, usage = ai(chat, config, args.number)
|
||||
save_answers(question, answers, tags, otags, config)
|
||||
print("-" * terminal_width())
|
||||
print(f"Usage: {usage}")
|
||||
|
||||
|
||||
def tags_completer(prefix, parsed_args, **kwargs):
|
||||
with open(parsed_args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
return get_tags_unique(config, prefix)
|
||||
def tags_completer(prefix: str, parsed_args: Any, **kwargs: Any) -> list[str]:
|
||||
config = Config.from_file(parsed_args.config)
|
||||
return list(Message.tags_from_dir(Path(config.db), prefix=prefix))
|
||||
|
||||
|
||||
def create_parser() -> argparse.ArgumentParser:
|
||||
default_config = '.config.yaml'
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ChatMastermind is a Python application that automates conversation with AI")
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument('-p', '--print', help='File to print')
|
||||
group.add_argument('-q', '--question', nargs='*', help='Question to ask')
|
||||
group.add_argument('-D', '--chat-dump', help="Print chat history as Python structure", action='store_true')
|
||||
group.add_argument('-d', '--chat', help="Print chat history as readable text", action='store_true')
|
||||
group.add_argument('-l', '--list-tags', help="List all tags and their frequency", action='store_true')
|
||||
parser.add_argument('-c', '--config', help='Config file name.', default=default_config)
|
||||
parser.add_argument('-m', '--max-tokens', help='Max tokens to use', type=int)
|
||||
parser.add_argument('-T', '--temperature', help='Temperature to use', type=float)
|
||||
parser.add_argument('-M', '--model', help='Model to use')
|
||||
parser.add_argument('-n', '--number', help='Number of answers to produce', type=int, default=1)
|
||||
parser.add_argument('-s', '--source', nargs='*', help='Source add content of a file to the query')
|
||||
parser.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
|
||||
parser.add_argument('-w', '--with-tags', help="Print chat history with tags.", action='store_true')
|
||||
parser.add_argument('-W', '--with-file', help="Print chat history with filename.", action='store_true')
|
||||
parser.add_argument('-a', '--match-all-tags', help="All given tags must match when selecting chat history entries.", action='store_true')
|
||||
tags_arg = parser.add_argument('-t', '--tags', nargs='*', help='List of tag names', metavar='TAGS')
|
||||
tags_arg.completer = tags_completer # type: ignore
|
||||
extags_arg = parser.add_argument('-e', '--extags', nargs='*', help='List of tag names to exclude', metavar='EXTAGS')
|
||||
extags_arg.completer = tags_completer # type: ignore
|
||||
otags_arg = parser.add_argument('-o', '--output-tags', nargs='*', help='List of output tag names, default is input', metavar='OTAGS')
|
||||
otags_arg.completer = tags_completer # type: ignore
|
||||
parser.add_argument('-C', '--config', help='Config file name.', default=default_config_path)
|
||||
|
||||
# subcommand-parser
|
||||
cmdparser = parser.add_subparsers(dest='command',
|
||||
title='commands',
|
||||
description='supported commands',
|
||||
required=True)
|
||||
|
||||
# a parent parser for all commands that support tag selection
|
||||
tag_parser = argparse.ArgumentParser(add_help=False)
|
||||
tag_arg = tag_parser.add_argument('-t', '--or-tags', nargs='+',
|
||||
help='List of tags (one must match)', metavar='OTAGS')
|
||||
tag_arg.completer = tags_completer # type: ignore
|
||||
atag_arg = tag_parser.add_argument('-k', '--and-tags', nargs='+',
|
||||
help='List of tags (all must match)', metavar='ATAGS')
|
||||
atag_arg.completer = tags_completer # type: ignore
|
||||
etag_arg = tag_parser.add_argument('-x', '--exclude-tags', nargs='+',
|
||||
help='List of tags to exclude', metavar='XTAGS')
|
||||
etag_arg.completer = tags_completer # type: ignore
|
||||
otag_arg = tag_parser.add_argument('-o', '--output-tags', nargs='+',
|
||||
help='List of output tags (default: use input tags)', metavar='OUTTAGS')
|
||||
otag_arg.completer = tags_completer # type: ignore
|
||||
|
||||
# a parent parser for all commands that support AI configuration
|
||||
ai_parser = argparse.ArgumentParser(add_help=False)
|
||||
ai_parser.add_argument('-A', '--AI', help='AI ID to use')
|
||||
ai_parser.add_argument('-M', '--model', help='Model to use')
|
||||
ai_parser.add_argument('-n', '--num-answers', help='Number of answers to request', type=int, default=1)
|
||||
ai_parser.add_argument('-m', '--max-tokens', help='Max. nr. of tokens', type=int)
|
||||
ai_parser.add_argument('-T', '--temperature', help='Temperature value', type=float)
|
||||
|
||||
# 'question' command parser
|
||||
question_cmd_parser = cmdparser.add_parser('question', parents=[tag_parser, ai_parser],
|
||||
help="ask, create and process questions.",
|
||||
aliases=['q'])
|
||||
question_cmd_parser.set_defaults(func=question_cmd)
|
||||
question_group = question_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||
question_group.add_argument('-a', '--ask', nargs='+', help='Ask a question')
|
||||
question_group.add_argument('-c', '--create', nargs='+', help='Create a question')
|
||||
question_group.add_argument('-r', '--repeat', nargs='*', help='Repeat a question')
|
||||
question_group.add_argument('-p', '--process', nargs='*', help='Process existing questions')
|
||||
question_cmd_parser.add_argument('-O', '--overwrite', help='Overwrite existing messages when repeating them',
|
||||
action='store_true')
|
||||
question_cmd_parser.add_argument('-s', '--source-text', nargs='+', help='Add content of a file to the query')
|
||||
question_cmd_parser.add_argument('-S', '--source-code', nargs='+', help='Add source code file content to the chat history')
|
||||
|
||||
# 'hist' command parser
|
||||
hist_cmd_parser = cmdparser.add_parser('hist', parents=[tag_parser],
|
||||
help="Print chat history.",
|
||||
aliases=['h'])
|
||||
hist_cmd_parser.set_defaults(func=hist_cmd)
|
||||
hist_cmd_parser.add_argument('-w', '--with-tags', help="Print chat history with tags.",
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-W', '--with-files', help="Print chat history with filenames.",
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-S', '--source-code-only', help='Print only source code',
|
||||
action='store_true')
|
||||
hist_cmd_parser.add_argument('-A', '--answer', help='Search for answer substring')
|
||||
hist_cmd_parser.add_argument('-Q', '--question', help='Search for question substring')
|
||||
|
||||
# 'tags' command parser
|
||||
tags_cmd_parser = cmdparser.add_parser('tags',
|
||||
help="Manage tags.",
|
||||
aliases=['t'])
|
||||
tags_cmd_parser.set_defaults(func=tags_cmd)
|
||||
tags_group = tags_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||
tags_group.add_argument('-l', '--list', help="List all tags and their frequency",
|
||||
action='store_true')
|
||||
tags_cmd_parser.add_argument('-p', '--prefix', help="Filter tags by prefix")
|
||||
tags_cmd_parser.add_argument('-c', '--contain', help="Filter tags by contained substring")
|
||||
|
||||
# 'config' command parser
|
||||
config_cmd_parser = cmdparser.add_parser('config',
|
||||
help="Manage configuration",
|
||||
aliases=['c'])
|
||||
config_cmd_parser.set_defaults(func=config_cmd)
|
||||
config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use')
|
||||
config_group = config_cmd_parser.add_mutually_exclusive_group(required=True)
|
||||
config_group.add_argument('-l', '--list-models', help="List all available models",
|
||||
action='store_true')
|
||||
config_group.add_argument('-m', '--print-model', help="Print the currently configured model",
|
||||
action='store_true')
|
||||
config_group.add_argument('-c', '--create', help="Create config with default settings in the given file")
|
||||
|
||||
# 'print' command parser
|
||||
print_cmd_parser = cmdparser.add_parser('print',
|
||||
help="Print message files.",
|
||||
aliases=['p'])
|
||||
print_cmd_parser.set_defaults(func=print_cmd)
|
||||
print_cmd_parser.add_argument('-f', '--file', help='File to print', required=True)
|
||||
print_cmd_modes = print_cmd_parser.add_mutually_exclusive_group()
|
||||
print_cmd_modes.add_argument('-q', '--question', help='Print only question', action='store_true')
|
||||
print_cmd_modes.add_argument('-a', '--answer', help='Print only answer', action='store_true')
|
||||
print_cmd_modes.add_argument('-S', '--only-source-code', help='Print only source code', action='store_true')
|
||||
|
||||
argcomplete.autocomplete(parser)
|
||||
return parser
|
||||
|
||||
@@ -119,31 +126,13 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
def main() -> int:
|
||||
parser = create_parser()
|
||||
args = parser.parse_args()
|
||||
command = parser.parse_args()
|
||||
|
||||
with open(args.config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
|
||||
openai_api_key(config['openai']['api_key'])
|
||||
|
||||
if args.max_tokens:
|
||||
config['openai']['max_tokens'] = args.max_tokens
|
||||
|
||||
if args.temperature:
|
||||
config['openai']['temperature'] = args.temperature
|
||||
|
||||
if args.model:
|
||||
config['openai']['model'] = args.model
|
||||
|
||||
if args.print:
|
||||
run_print_command(args, config)
|
||||
elif args.question:
|
||||
handle_question(args, config)
|
||||
elif args.chat_dump:
|
||||
process_and_display_chat(args, config, dump=True)
|
||||
elif args.chat:
|
||||
process_and_display_chat(args, config)
|
||||
elif args.list_tags:
|
||||
process_and_display_tags(args, config)
|
||||
if command.func == config_cmd:
|
||||
command.func(command)
|
||||
else:
|
||||
config = Config.from_file(args.config)
|
||||
command.func(command, config)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
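The new command structure (sub-parsers, shared parent parsers and dispatch via set_defaults(func=...)) can be tried out with a reduced parser. This is a sketch under stated assumptions: the handlers below are hypothetical and return strings instead of printing, and only two of the commands are modelled.

```python
import argparse


def hist_cmd(args: argparse.Namespace) -> str:
    """Hypothetical handler: reports the parsed tag selection."""
    return f"hist or_tags={args.or_tags}"


def tags_cmd(args: argparse.Namespace) -> str:
    """Hypothetical handler: reports the parsed prefix."""
    return f"tags prefix={args.prefix}"


def create_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog='demo')
    cmdparser = parser.add_subparsers(dest='command', title='commands', required=True)

    # Parent parser: options shared by every command that selects by tag.
    tag_parser = argparse.ArgumentParser(add_help=False)
    tag_parser.add_argument('-t', '--or-tags', nargs='+')

    hist = cmdparser.add_parser('hist', parents=[tag_parser], aliases=['h'])
    hist.set_defaults(func=hist_cmd)

    tags = cmdparser.add_parser('tags', aliases=['t'])
    tags.add_argument('-p', '--prefix')
    tags.set_defaults(func=tags_cmd)
    return parser


# The alias 'h' selects the 'hist' parser; its default 'func' is the handler.
args = create_parser().parse_args(['h', '-t', 'a', 'b'])
result = args.func(args)
```

Each sub-parser carries its own handler, so the caller only needs `args.func(args)` and no chain of `if`/`elif` checks on option names.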
561
chatmastermind/message.py
Normal file
@@ -0,0 +1,561 @@
|
||||
"""
|
||||
Module implementing message related functions and classes.
|
||||
"""
|
||||
import pathlib
|
||||
import yaml
|
||||
import tempfile
|
||||
import shutil
|
||||
from typing import Type, TypeVar, ClassVar, Optional, Any, Union, Final, Literal, Iterable
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from .tags import Tag, TagLine, TagError, match_tags, rename_tags
|
||||
|
||||
QuestionInst = TypeVar('QuestionInst', bound='Question')
|
||||
AnswerInst = TypeVar('AnswerInst', bound='Answer')
|
||||
MessageInst = TypeVar('MessageInst', bound='Message')
|
||||
AILineInst = TypeVar('AILineInst', bound='AILine')
|
||||
ModelLineInst = TypeVar('ModelLineInst', bound='ModelLine')
|
||||
YamlDict = dict[str, Union[QuestionInst, AnswerInst, set[Tag]]]
|
||||
|
||||
|
||||
class MessageError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def str_presenter(dumper: yaml.Dumper, data: str) -> yaml.ScalarNode:
|
||||
"""
|
||||
Changes the YAML dump style to multiline syntax for multiline strings.
|
||||
"""
|
||||
if len(data.splitlines()) > 1:
|
||||
return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|')
|
||||
return dumper.represent_scalar('tag:yaml.org,2002:str', data)
|
||||
|
||||
|
||||
yaml.add_representer(str, str_presenter)
|
||||
|
||||
|
||||
def source_code(text: str, include_delims: bool = False) -> list[str]:
|
||||
"""
|
||||
Extract all source code sections from the given text, i.e. all lines
|
||||
surrounded by lines starting with '```'. If 'include_delims' is True,
|
||||
the surrounding lines are included, otherwise they are omitted. The
|
||||
result list contains every source code section as a single string.
|
||||
The order in the list represents the order of the sections in the text.
|
||||
"""
|
||||
code_sections: list[str] = []
|
||||
code_lines: list[str] = []
|
||||
in_code_block = False
|
||||
|
||||
for line in text.split('\n'):
|
||||
if line.strip().startswith('```'):
|
||||
if include_delims:
|
||||
code_lines.append(line)
|
||||
if in_code_block:
|
||||
code_sections.append('\n'.join(code_lines) + '\n')
|
||||
code_lines.clear()
|
||||
in_code_block = not in_code_block
|
||||
elif in_code_block:
|
||||
code_lines.append(line)
|
||||
|
||||
return code_sections
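
As a rough illustration of what `source_code()` extracts, the sketch below repeats the same fence-toggling loop in standalone form. The names `extract_code_sections` and `FENCE` are assumptions made for this example only (the constant merely avoids writing a literal fence inside the snippet); the loop itself follows the function above.

```python
FENCE = '`' * 3  # three backticks, built indirectly so this example nests cleanly


def extract_code_sections(text: str, include_delims: bool = False) -> list[str]:
    """Collect every fenced section of 'text' as one string per section."""
    sections: list[str] = []
    lines: list[str] = []
    in_block = False
    for line in text.split('\n'):
        if line.strip().startswith(FENCE):
            if include_delims:
                lines.append(line)
            if in_block:
                # closing fence: emit the collected section and start over
                sections.append('\n'.join(lines) + '\n')
                lines.clear()
            in_block = not in_block
        elif in_block:
            lines.append(line)
    return sections


sample = '\n'.join([
    'Some prose.',
    FENCE + 'python',
    'print("hi")',
    FENCE,
    'More prose.',
    FENCE,
    'x = 1',
    'y = 2',
    FENCE,
])
sections = extract_code_sections(sample)
with_delims = extract_code_sections(sample, include_delims=True)
```

Each section comes back as a single newline-terminated string, in document order; an unterminated fence at the end of the text yields no section.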
|
||||
|
||||
|
||||
def message_in(message: MessageInst, messages: Iterable[MessageInst]) -> bool:
|
||||
"""
|
||||
Searches the given message list for a message with the same file
|
||||
name as the given one (i. e. it compares Message.file_path.name).
|
||||
If the given message has no file_path, False is returned.
|
||||
"""
|
||||
if not message.file_path:
|
||||
return False
|
||||
for m in messages:
|
||||
if m.file_path and m.file_path.name == message.file_path.name:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class MessageFilter:
|
||||
"""
|
||||
Various filters for a Message.
|
||||
"""
|
||||
tags_or: Optional[set[Tag]] = None
|
||||
tags_and: Optional[set[Tag]] = None
|
||||
tags_not: Optional[set[Tag]] = None
|
||||
ai: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
question_contains: Optional[str] = None
|
||||
answer_contains: Optional[str] = None
|
||||
answer_state: Optional[Literal['available', 'missing']] = None
|
||||
ai_state: Optional[Literal['available', 'missing']] = None
|
||||
model_state: Optional[Literal['available', 'missing']] = None
|
||||
|
||||
|
||||
class AILine(str):
|
||||
"""
|
||||
A line that represents the AI name in a '.txt' file.
|
||||
"""
|
||||
prefix: Final[str] = 'AI:'
|
||||
|
||||
def __new__(cls: Type[AILineInst], string: str) -> AILineInst:
|
||||
if not string.startswith(cls.prefix):
|
||||
raise MessageError(f"AILine '{string}' is missing prefix '{cls.prefix}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
def ai(self) -> str:
|
||||
return self[len(self.prefix):].strip()
|
||||
|
||||
@classmethod
|
||||
def from_ai(cls: Type[AILineInst], ai: str) -> AILineInst:
|
||||
return cls(' '.join([cls.prefix, ai]))
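
The pattern used by `AILine` (and by `ModelLine` below) is a `str` subclass that validates its content in `__new__`. A minimal sketch of that pattern follows; `PrefixedLine`, `value()` and `from_value()` are hypothetical names chosen for the example, and `ValueError` stands in for `MessageError` so the snippet is self-contained.

```python
class PrefixedLine(str):
    """Hypothetical stand-in for AILine/ModelLine: a str that must start with a prefix."""
    prefix = 'AI:'

    def __new__(cls, string: str) -> 'PrefixedLine':
        # str is immutable, so validation has to happen in __new__, not __init__
        if not string.startswith(cls.prefix):
            raise ValueError(f"line '{string}' is missing prefix '{cls.prefix}'")
        return super().__new__(cls, string)

    def value(self) -> str:
        """Return the payload after the prefix, without surrounding whitespace."""
        return self[len(self.prefix):].strip()

    @classmethod
    def from_value(cls, value: str) -> 'PrefixedLine':
        """Build a complete line from the bare payload."""
        return cls(' '.join([cls.prefix, value]))


line = PrefixedLine.from_value('openai')
try:
    PrefixedLine('openai')  # no prefix, so construction is refused
    rejected = False
except ValueError:
    rejected = True
```

Because the instance is still a `str`, it can be written to a file or compared with plain strings without conversion.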
|
||||
|
||||
|
||||
class ModelLine(str):
|
||||
"""
|
||||
A line that represents the model name in a '.txt' file.
|
||||
"""
|
||||
prefix: Final[str] = 'MODEL:'
|
||||
|
||||
def __new__(cls: Type[ModelLineInst], string: str) -> ModelLineInst:
|
||||
if not string.startswith(cls.prefix):
|
||||
raise MessageError(f"ModelLine '{string}' is missing prefix '{cls.prefix}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
def model(self) -> str:
|
||||
return self[len(self.prefix):].strip()
|
||||
|
||||
@classmethod
|
||||
def from_model(cls: Type[ModelLineInst], model: str) -> ModelLineInst:
|
||||
return cls(' '.join([cls.prefix, model]))
|
||||
|
||||
|
||||
class Answer(str):
|
||||
"""
|
||||
A single answer with a defined header.
|
||||
"""
|
||||
tokens: int = 0 # tokens used by this answer
|
||||
txt_header: ClassVar[str] = '==== ANSWER ===='
|
||||
yaml_key: ClassVar[str] = 'answer'
|
||||
|
||||
def __new__(cls: Type[AnswerInst], string: str) -> AnswerInst:
|
||||
"""
|
||||
Make sure the answer string does not contain the header as a whole line.
|
||||
"""
|
||||
if cls.txt_header in string.split('\n'):
|
||||
raise MessageError(f"Answer '{string}' contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_list(cls: Type[AnswerInst], strings: list[str]) -> AnswerInst:
|
||||
"""
|
||||
Build Answer from a list of strings. Make sure strings do not contain the header.
|
||||
"""
|
||||
if cls.txt_header in strings:
|
||||
raise MessageError(f"Answer contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, '\n'.join(strings).strip())
|
||||
return instance
|
||||
|
||||
def source_code(self, include_delims: bool = False) -> list[str]:
|
||||
"""
|
||||
Extract and return all source code sections.
|
||||
"""
|
||||
return source_code(self, include_delims)
|
||||
|
||||
|
||||
class Question(str):
|
||||
"""
|
||||
A single question with a defined header.
|
||||
"""
|
||||
tokens: int = 0 # tokens used by this question
|
||||
txt_header: ClassVar[str] = '=== QUESTION ==='
|
||||
yaml_key: ClassVar[str] = 'question'
|
||||
|
||||
def __new__(cls: Type[QuestionInst], string: str) -> QuestionInst:
|
||||
"""
|
||||
Make sure the question string does not contain the header as a whole line
|
||||
(nor the header of 'Answer', so it's always clear where the answer starts).
|
||||
"""
|
||||
string_lines = string.split('\n')
|
||||
if cls.txt_header in string_lines:
|
||||
raise MessageError(f"Question '{string}' contains the header '{cls.txt_header}'")
|
||||
if Answer.txt_header in string_lines:
|
||||
raise MessageError(f"Question '{string}' contains the header '{Answer.txt_header}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_list(cls: Type[QuestionInst], strings: list[str]) -> QuestionInst:
|
||||
"""
|
||||
Build Question from a list of strings. Make sure strings do not contain the header.
|
||||
"""
|
||||
if cls.txt_header in strings:
|
||||
raise MessageError(f"Question contains the header '{cls.txt_header}'")
|
||||
instance = super().__new__(cls, '\n'.join(strings).strip())
|
||||
return instance
|
||||
|
||||
def source_code(self, include_delims: bool = False) -> list[str]:
|
||||
"""
|
||||
Extract and return all source code sections.
|
||||
"""
|
||||
return source_code(self, include_delims)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message():
|
||||
"""
|
||||
Single message. Consists of a question and optionally an answer, a set of tags
|
||||
and a file path.
|
||||
"""
|
||||
question: Question
|
||||
answer: Optional[Answer] = None
|
||||
# metadata, ignored when comparing messages
|
||||
tags: Optional[set[Tag]] = field(default=None, compare=False)
|
||||
ai: Optional[str] = field(default=None, compare=False)
|
||||
model: Optional[str] = field(default=None, compare=False)
|
||||
file_path: Optional[pathlib.Path] = field(default=None, compare=False)
|
||||
# class variables
|
||||
file_suffixes: ClassVar[list[str]] = ['.txt', '.yaml']
|
||||
tags_yaml_key: ClassVar[str] = 'tags'
|
||||
file_yaml_key: ClassVar[str] = 'file_path'
|
||||
ai_yaml_key: ClassVar[str] = 'ai'
|
||||
model_yaml_key: ClassVar[str] = 'model'
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""
|
||||
The hash value is computed based on immutable members.
|
||||
"""
|
||||
return hash((self.question, self.answer))
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[MessageInst], data: dict[str, Any]) -> MessageInst:
|
||||
"""
|
||||
Create a Message from the given dict.
|
||||
"""
|
||||
return cls(question=data[Question.yaml_key],
|
||||
answer=data.get(Answer.yaml_key, None),
|
||||
tags=set(data.get(cls.tags_yaml_key, [])),
|
||||
ai=data.get(cls.ai_yaml_key, None),
|
||||
model=data.get(cls.model_yaml_key, None),
|
||||
file_path=data.get(cls.file_yaml_key, None))
|
||||
|
||||
@classmethod
|
||||
def tags_from_file(cls: Type[MessageInst],
|
||||
file_path: pathlib.Path,
|
||||
prefix: Optional[str] = None,
|
||||
contain: Optional[str] = None) -> set[Tag]:
|
||||
"""
|
||||
Return only the tags from the given Message file,
|
||||
optionally filtered based on prefix or contained string.
|
||||
"""
|
||||
tags: set[Tag] = set()
|
||||
if not file_path.exists():
|
||||
raise MessageError(f"Message file '{file_path}' does not exist")
|
||||
if file_path.suffix not in cls.file_suffixes:
|
||||
raise MessageError(f"File type '{file_path.suffix}' is not supported")
|
||||
# for TXT, it's enough to read the TagLine
|
||||
if file_path.suffix == '.txt':
|
||||
with open(file_path, "r") as fd:
|
||||
try:
|
||||
tags = TagLine(fd.readline()).tags(prefix, contain)
|
||||
except TagError:
|
||||
pass # message without tags
|
||||
else: # '.yaml'
|
||||
msg_tags: set[Tag] = set()  # stays empty if the message cannot be loaded
try:
|
||||
message = cls.from_file(file_path)
|
||||
if message:
|
||||
msg_tags = message.filter_tags(prefix=prefix, contain=contain)
|
||||
except MessageError as e:
|
||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||
if msg_tags:
|
||||
tags = msg_tags
|
||||
return tags
|
||||
|
||||
@classmethod
|
||||
def tags_from_dir(cls: Type[MessageInst],
|
||||
path: pathlib.Path,
|
||||
glob: Optional[str] = None,
|
||||
prefix: Optional[str] = None,
|
||||
contain: Optional[str] = None) -> set[Tag]:
|
||||
|
||||
"""
|
||||
Return only the tags from message files in the given directory.
|
||||
The files can be filtered using 'glob', the tags by using 'prefix'
|
||||
and 'contain'.
|
||||
"""
|
||||
tags: set[Tag] = set()
|
||||
file_iter = path.glob(glob) if glob else path.iterdir()
|
||||
for file_path in sorted(file_iter):
|
||||
if file_path.is_file():
|
||||
try:
|
||||
tags |= cls.tags_from_file(file_path, prefix, contain)
|
||||
except MessageError as e:
|
||||
print(f"Error processing message in '{file_path}': {str(e)}")
|
||||
return tags
|
||||
|
||||
@classmethod
|
||||
def from_file(cls: Type[MessageInst], file_path: pathlib.Path,
|
||||
mfilter: Optional[MessageFilter] = None) -> Optional[MessageInst]:
|
||||
"""
|
||||
Create a Message from the given file. Returns 'None' if the message does
|
||||
not fulfill the filter requirements. For TXT files, the tags are matched
|
||||
before building the whole message. The other filters are applied afterwards.
|
||||
"""
|
||||
if not file_path.exists():
|
||||
raise MessageError(f"Message file '{file_path}' does not exist")
|
||||
if file_path.suffix not in cls.file_suffixes:
|
||||
raise MessageError(f"File type '{file_path.suffix}' is not supported")
|
||||
|
||||
if file_path.suffix == '.txt':
|
||||
message = cls.__from_file_txt(file_path,
|
||||
mfilter.tags_or if mfilter else None,
|
||||
mfilter.tags_and if mfilter else None,
|
||||
mfilter.tags_not if mfilter else None)
|
||||
else:
|
||||
message = cls.__from_file_yaml(file_path)
|
||||
if message and (mfilter is None or message.match(mfilter)):
|
||||
return message
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def __from_file_txt(cls: Type[MessageInst], file_path: pathlib.Path, # noqa: 11
|
||||
tags_or: Optional[set[Tag]] = None,
|
||||
tags_and: Optional[set[Tag]] = None,
|
||||
tags_not: Optional[set[Tag]] = None) -> Optional[MessageInst]:
|
||||
"""
|
||||
Create a Message from the given TXT file. Expects the following file structures:
|
||||
For '.txt':
|
||||
* TagLine [Optional]
|
||||
* AI [Optional]
|
||||
* Model [Optional]
|
||||
* Question.txt_header
|
||||
* Question
|
||||
* Answer.txt_header [Optional]
|
||||
* Answer [Optional]
|
||||
|
||||
Returns 'None' if the message does not fulfill the tag requirements.
|
||||
"""
|
||||
tags: set[Tag] = set()
|
||||
question: Question
|
||||
answer: Optional[Answer] = None
|
||||
ai: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
with open(file_path, "r") as fd:
|
||||
# TagLine (Optional)
|
||||
try:
|
||||
pos = fd.tell()
|
||||
tags = TagLine(fd.readline()).tags()
|
||||
except TagError:
|
||||
fd.seek(pos)
|
||||
if tags_or or tags_and or tags_not:
|
||||
# match with an empty set if the file has no tags
|
||||
if not match_tags(tags, tags_or, tags_and, tags_not):
|
||||
return None
|
||||
# AILine (Optional)
|
||||
try:
|
||||
pos = fd.tell()
|
||||
ai = AILine(fd.readline()).ai()
|
||||
except MessageError:
|
||||
fd.seek(pos)
|
||||
# ModelLine (Optional)
|
||||
try:
|
||||
pos = fd.tell()
|
||||
model = ModelLine(fd.readline()).model()
|
||||
except MessageError:
|
||||
fd.seek(pos)
|
||||
# Question and Answer
|
||||
text = fd.read().strip().split('\n')
|
||||
try:
|
||||
question_idx = text.index(Question.txt_header) + 1
|
||||
except ValueError:
|
||||
raise MessageError(f"Question header '{Question.txt_header}' not found in '{file_path}'")
|
||||
try:
|
||||
answer_idx = text.index(Answer.txt_header)
|
||||
question = Question.from_list(text[question_idx:answer_idx])
|
||||
answer = Answer.from_list(text[answer_idx + 1:])
|
||||
except ValueError:
|
||||
question = Question.from_list(text[question_idx:])
|
||||
return cls(question, answer, tags, ai, model, file_path)
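
The optional header lines are read with a peek-and-rewind idiom: remember the position, read one line, and seek back if it is not the expected line. The sketch below shows that idiom on an in-memory file. It is a simplified illustration, not the method above: `read_optional_line` and `parse_txt` are hypothetical helpers, tags are split on whitespace only, and a plain dict replaces `Message`.

```python
import io
from typing import Any, Optional


def read_optional_line(fd: io.TextIOBase, prefix: str) -> Optional[str]:
    """Return the text after 'prefix' if the next line starts with it, else rewind."""
    pos = fd.tell()
    line = fd.readline()
    if line.startswith(prefix):
        return line[len(prefix):].strip()
    fd.seek(pos)  # not the expected line: leave it for the next reader
    return None


def parse_txt(fd: io.TextIOBase) -> dict[str, Any]:
    """Parse the TXT layout: optional TAGS/AI/MODEL lines, then question and answer."""
    tag_str = read_optional_line(fd, 'TAGS:')
    ai = read_optional_line(fd, 'AI:')
    model = read_optional_line(fd, 'MODEL:')
    text = fd.read().strip().split('\n')
    question_idx = text.index('=== QUESTION ===') + 1  # ValueError if missing
    answer: Optional[str] = None
    try:
        answer_idx = text.index('==== ANSWER ====')
        question = '\n'.join(text[question_idx:answer_idx]).strip()
        answer = '\n'.join(text[answer_idx + 1:]).strip()
    except ValueError:
        # no answer header: everything after the question header is the question
        question = '\n'.join(text[question_idx:]).strip()
    return {'tags': set(tag_str.split()) if tag_str else set(),
            'ai': ai, 'model': model, 'question': question, 'answer': answer}


full = parse_txt(io.StringIO(
    'TAGS: tag1 tag2\nMODEL: gpt-4\n=== QUESTION ===\nWhat is 2+2?\n==== ANSWER ====\n4\n'))
bare = parse_txt(io.StringIO('=== QUESTION ===\nOnly a question\n'))
```

In the first input the `AI:` line is absent, so the reader rewinds and the same line is then consumed as the `MODEL:` line.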
|
||||
|
||||
@classmethod
|
||||
def __from_file_yaml(cls: Type[MessageInst], file_path: pathlib.Path) -> MessageInst:
|
||||
"""
|
||||
Create a Message from the given YAML file. Expects the following file structures:
|
||||
* Question.yaml_key: single or multiline string
|
||||
* Answer.yaml_key: single or multiline string [Optional]
|
||||
* Message.tags_yaml_key: list of strings [Optional]
|
||||
* Message.ai_yaml_key: str [Optional]
|
||||
* Message.model_yaml_key: str [Optional]
|
||||
"""
|
||||
with open(file_path, "r") as fd:
|
||||
data = yaml.load(fd, Loader=yaml.FullLoader)
|
||||
data[cls.file_yaml_key] = file_path
|
||||
return cls.from_dict(data)
|
||||
|
||||
def to_str(self, with_tags: bool = False, with_file: bool = False, source_code_only: bool = False) -> str:
|
||||
"""
|
||||
Return the current Message as a string.
|
||||
"""
|
||||
output: list[str] = []
|
||||
if source_code_only:
|
||||
# use the source code from answer only
|
||||
if self.answer:
|
||||
output.extend(self.answer.source_code(include_delims=True))
|
||||
return '\n'.join(output) if len(output) > 0 else ''
|
||||
if with_tags:
|
||||
output.append(self.tags_str())
|
||||
if with_file:
|
||||
output.append('FILE: ' + str(self.file_path))
|
||||
output.append(Question.txt_header)
|
||||
output.append(self.question)
|
||||
if self.answer:
|
||||
output.append(Answer.txt_header)
|
||||
output.append(self.answer)
|
||||
return '\n'.join(output)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_str(True, True, False)
|
||||
|
||||
def to_file(self, file_path: Optional[pathlib.Path] = None) -> None: # noqa: 11
|
||||
"""
|
||||
Write a Message to the given file. Type is determined based on the suffix.
|
||||
Currently supported suffixes: ['.txt', '.yaml']
|
||||
"""
|
||||
if file_path:
|
||||
self.file_path = file_path
|
||||
if not self.file_path:
|
||||
raise MessageError("Got no valid path to write message")
|
||||
if self.file_path.suffix not in self.file_suffixes:
|
||||
raise MessageError(f"File type '{self.file_path.suffix}' is not supported")
|
||||
# TXT
|
||||
if self.file_path.suffix == '.txt':
|
||||
return self.__to_file_txt(self.file_path)
|
||||
elif self.file_path.suffix == '.yaml':
|
||||
return self.__to_file_yaml(self.file_path)
|
||||
|
||||
def __to_file_txt(self, file_path: pathlib.Path) -> None:
|
||||
"""
|
||||
Write a Message to the given file in TXT format.
|
||||
Creates the following file structures:
|
||||
* TagLine
|
||||
* AI [Optional]
|
||||
* Model [Optional]
|
||||
* Question.txt_header
|
||||
* Question
|
||||
* Answer.txt_header
|
||||
* Answer
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
|
||||
temp_file_path = pathlib.Path(temp_fd.name)
|
||||
if self.tags:
|
||||
temp_fd.write(f'{TagLine.from_set(self.tags)}\n')
|
||||
if self.ai:
|
||||
temp_fd.write(f'{AILine.from_ai(self.ai)}\n')
|
||||
if self.model:
|
||||
temp_fd.write(f'{ModelLine.from_model(self.model)}\n')
|
||||
temp_fd.write(f'{Question.txt_header}\n{self.question}\n')
|
||||
if self.answer:
|
||||
temp_fd.write(f'{Answer.txt_header}\n{self.answer}\n')
|
||||
shutil.move(temp_file_path, file_path)
|
||||
|
||||
def __to_file_yaml(self, file_path: pathlib.Path) -> None:
|
||||
"""
|
||||
Write a Message to the given file in YAML format.
|
||||
Creates the following file structures:
|
||||
* Question.yaml_key: single or multiline string
|
||||
* Answer.yaml_key: single or multiline string
|
||||
* Message.tags_yaml_key: list of strings
|
||||
* Message.ai_yaml_key: str [Optional]
|
||||
* Message.model_yaml_key: str [Optional]
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name, mode="w", delete=False) as temp_fd:
|
||||
temp_file_path = pathlib.Path(temp_fd.name)
|
||||
data: YamlDict = {Question.yaml_key: str(self.question)}
|
||||
if self.answer:
|
||||
data[Answer.yaml_key] = str(self.answer)
|
||||
if self.ai:
|
||||
data[self.ai_yaml_key] = self.ai
|
||||
if self.model:
|
||||
data[self.model_yaml_key] = self.model
|
||||
if self.tags:
|
||||
data[self.tags_yaml_key] = sorted([str(tag) for tag in self.tags])
|
||||
yaml.dump(data, temp_fd, sort_keys=False)
|
||||
shutil.move(temp_file_path, file_path)
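
Both writers follow the same write-then-move pattern: the content goes to a temporary file in the target directory and is only moved over the real file once it is complete. A minimal standalone sketch of that pattern is shown below; `write_text_atomically` is a hypothetical helper name, and the move is placed after the `with` block on the assumption that the temporary file should be closed before it is renamed. On POSIX systems a rename within one directory typically replaces the target in a single step; other platforms may behave differently.

```python
import pathlib
import shutil
import tempfile


def write_text_atomically(file_path: pathlib.Path, text: str) -> None:
    """Write 'text' to a temp file next to 'file_path', then move it into place."""
    with tempfile.NamedTemporaryFile(dir=file_path.parent, prefix=file_path.name,
                                     mode='w', delete=False) as temp_fd:
        temp_path = pathlib.Path(temp_fd.name)
        temp_fd.write(text)
    # the temp file is closed (and flushed) here; now move it over the target
    shutil.move(temp_path, file_path)


with tempfile.TemporaryDirectory() as tmp_dir:
    target = pathlib.Path(tmp_dir) / '0001.txt'
    write_text_atomically(target, 'first\n')
    write_text_atomically(target, 'second\n')  # replaces the existing file
    content = target.read_text()
    leftovers = sorted(p.name for p in pathlib.Path(tmp_dir).iterdir())
```

Creating the temporary file in the same directory keeps source and target on one filesystem, so the move does not have to fall back to copying.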
|
||||
|
||||
def filter_tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
"""
|
||||
Filter tags based on their prefix (i. e. the tag starts with a given string)
|
||||
or some contained string.
|
||||
"""
|
||||
if not self.tags:
|
||||
return set()
|
||||
res_tags = self.tags.copy()
|
||||
if prefix and len(prefix) > 0:
|
||||
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
|
||||
if contain and len(contain) > 0:
|
||||
res_tags -= {tag for tag in res_tags if contain not in tag}
|
||||
return res_tags
|
||||
|
||||
def tags_str(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> str:
|
||||
"""
|
||||
Returns all tags as a string with the TagLine prefix. Optionally filtered
|
||||
using 'Message.filter_tags()'.
|
||||
"""
|
||||
if self.tags:
|
||||
return str(TagLine.from_set(self.filter_tags(prefix, contain)))
|
||||
else:
|
||||
return str(TagLine.from_set(set()))
|
||||
|
||||
def match(self, mfilter: MessageFilter) -> bool: # noqa: 13
|
||||
"""
|
||||
Matches the current Message to the given filter attributes.
|
||||
Return True if all attributes match, else False.
|
||||
"""
|
||||
mytags = self.tags or set()
|
||||
if (((mfilter.tags_or is not None or mfilter.tags_and is not None or mfilter.tags_not is not None)
|
||||
and not match_tags(mytags, mfilter.tags_or, mfilter.tags_and, mfilter.tags_not)) # noqa: W503
|
||||
or (mfilter.ai and (not self.ai or mfilter.ai != self.ai)) # noqa: W503
|
||||
or (mfilter.model and (not self.model or mfilter.model != self.model)) # noqa: W503
|
||||
or (mfilter.question_contains and mfilter.question_contains not in self.question) # noqa: W503
|
||||
or (mfilter.answer_contains and (not self.answer or mfilter.answer_contains not in self.answer)) # noqa: W503
|
||||
or (mfilter.answer_state == 'available' and not self.answer) # noqa: W503
|
||||
or (mfilter.ai_state == 'available' and not self.ai) # noqa: W503
|
||||
or (mfilter.model_state == 'available' and not self.model) # noqa: W503
|
||||
or (mfilter.answer_state == 'missing' and self.answer) # noqa: W503
|
||||
or (mfilter.ai_state == 'missing' and self.ai) # noqa: W503
|
||||
or (mfilter.model_state == 'missing' and self.model)): # noqa: W503
|
||||
return False
|
||||
return True
|
||||
|
||||
def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> None:
|
||||
"""
|
||||
Renames the given tags. The first tuple element is the old name,
|
||||
the second one is the new name.
|
||||
"""
|
||||
if self.tags:
|
||||
self.tags = rename_tags(self.tags, tags_rename)
|
||||
|
||||
def msg_id(self) -> str:
|
||||
"""
|
||||
Returns an ID that is unique throughout all messages in the same (DB) directory.
|
||||
Currently this is the file name. The ID is also used for sorting messages.
|
||||
"""
|
||||
if self.file_path:
|
||||
return self.file_path.name
|
||||
else:
|
||||
raise MessageError("Can't create file ID without a file path")
|
||||
|
||||
def as_dict(self) -> dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
def tokens(self) -> int:
|
||||
"""
|
||||
Returns the number of AI language tokens used by this message.
|
||||
If unknown, 0 is returned.
|
||||
"""
|
||||
if self.answer:
|
||||
return self.question.tokens + self.answer.tokens
|
||||
else:
|
||||
return self.question.tokens
|
||||
@@ -1,115 +0,0 @@
|
||||
import yaml
|
||||
import io
|
||||
import pathlib
|
||||
from .utils import terminal_width, append_message, message_to_chat
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
|
||||
def read_file(fname: str, tags_only: bool = False) -> Dict[str, Any]:
|
||||
with open(fname, "r") as fd:
|
||||
if tags_only:
|
||||
return {"tags": [x.strip() for x in fd.readline().strip().split(':')[1].strip().split(',')]}
|
||||
text = fd.read().strip().split('\n')
|
||||
tags = [x.strip() for x in text.pop(0).split(':')[1].strip().split(',')]
|
||||
question_idx = text.index("=== QUESTION ===") + 1
|
||||
answer_idx = text.index("==== ANSWER ====")
|
||||
question = "\n".join(text[question_idx:answer_idx]).strip()
|
||||
answer = "\n".join(text[answer_idx + 1:]).strip()
|
||||
return {"question": question, "answer": answer, "tags": tags,
|
||||
"file": pathlib.Path(fname).name}
|
||||
|
||||
|
||||
def dump_data(data: Dict[str, Any]) -> str:
|
||||
with io.StringIO() as fd:
|
||||
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
|
||||
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
|
||||
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
|
||||
return fd.getvalue()
|
||||
|
||||
|
||||
def write_file(fname: str, data: Dict[str, Any]) -> None:
|
||||
with open(fname, "w") as fd:
|
||||
fd.write(f'TAGS: {", ".join(data["tags"])}\n')
|
||||
fd.write(f'=== QUESTION ===\n{data["question"]}\n')
|
||||
fd.write(f'==== ANSWER ====\n{data["answer"]}\n')
|
||||
|
||||
|
||||
def save_answers(question: str,
|
||||
answers: list[str],
|
||||
tags: list[str],
|
||||
otags: Optional[list[str]],
|
||||
config: Dict[str, Any]
|
||||
) -> None:
|
||||
wtags = otags or tags
|
||||
num, inum = 0, 0
|
||||
next_fname = pathlib.Path(config['db']) / '.next'
|
||||
try:
|
||||
with open(next_fname, 'r') as f:
|
||||
num = int(f.read())
|
||||
except Exception:
|
||||
pass
|
||||
for answer in answers:
|
||||
num += 1
|
||||
inum += 1
|
||||
title = f'-- ANSWER {inum} '
|
||||
title_end = '-' * (terminal_width() - len(title))
|
||||
print(f'{title}{title_end}')
|
||||
print(answer)
|
||||
write_file(f"{num:04d}.txt", {"question": question, "answer": answer, "tags": wtags})
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write(f'{num}')
|
||||
|
||||
|
||||
def create_chat(question: Optional[str],
|
||||
tags: Optional[List[str]],
|
||||
extags: Optional[List[str]],
|
||||
config: Dict[str, Any],
|
||||
match_all_tags: bool = False,
|
||||
with_tags: bool = False,
|
||||
with_file: bool = False
|
||||
) -> List[Dict[str, str]]:
|
||||
chat: List[Dict[str, str]] = []
|
||||
append_message(chat, 'system', config['system'].strip())
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
elif file.suffix == '.txt':
|
||||
data = read_file(file)
|
||||
else:
|
||||
continue
|
||||
data_tags = set(data.get('tags', []))
|
||||
tags_match: bool
|
||||
if match_all_tags:
|
||||
tags_match = not tags or set(tags).issubset(data_tags)
|
||||
else:
|
||||
tags_match = not tags or bool(data_tags.intersection(tags))
|
||||
extags_do_not_match = \
|
||||
not extags or not data_tags.intersection(extags)
|
||||
if tags_match and extags_do_not_match:
|
||||
message_to_chat(data, chat, with_tags, with_file)
|
||||
if question:
|
||||
append_message(chat, 'user', question)
|
||||
return chat
|
||||
|
||||
|
||||
def get_tags(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
|
||||
result = []
|
||||
for file in sorted(pathlib.Path(config['db']).iterdir()):
|
||||
if file.suffix == '.yaml':
|
||||
with open(file, 'r') as f:
|
||||
data = yaml.load(f, Loader=yaml.FullLoader)
|
||||
elif file.suffix == '.txt':
|
||||
data = read_file(file, tags_only=True)
|
||||
else:
|
||||
continue
|
||||
for tag in data.get('tags', []):
|
||||
if prefix and len(prefix) > 0:
|
||||
if tag.startswith(prefix):
|
||||
result.append(tag)
|
||||
else:
|
||||
result.append(tag)
|
||||
return result
|
||||
|
||||
def get_tags_unique(config: Dict[str, Any], prefix: Optional[str]) -> List[str]:
|
||||
return list(set(get_tags(config, prefix)))
|
||||
184
chatmastermind/tags.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Module implementing tag related functions and classes.
|
||||
"""
|
||||
from typing import Type, TypeVar, Optional, Final
|
||||
|
||||
TagInst = TypeVar('TagInst', bound='Tag')
|
||||
TagLineInst = TypeVar('TagLineInst', bound='TagLine')
|
||||
|
||||
|
||||
class TagError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Tag(str):
|
||||
"""
|
||||
A single tag. A string that can contain anything but the default separator (' ').
|
||||
"""
|
||||
# default separator
|
||||
default_separator: Final[str] = ' '
|
||||
# alternative separators (e. g. for backwards compatibility)
|
||||
alternative_separators: Final[list[str]] = [',']
|
||||
|
||||
def __new__(cls: Type[TagInst], string: str) -> TagInst:
|
||||
"""
|
||||
Make sure the tag string does not contain the default separator.
|
||||
"""
|
||||
if cls.default_separator in string:
|
||||
raise TagError(f"Tag '{string}' contains the separator char '{cls.default_separator}'")
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
|
||||
def delete_tags(tags: set[Tag], tags_delete: set[Tag]) -> set[Tag]:
|
||||
"""
|
||||
Deletes the given tags and returns a new set.
|
||||
"""
|
||||
return tags.difference(tags_delete)
|
||||
|
||||
|
||||
def add_tags(tags: set[Tag], tags_add: set[Tag]) -> set[Tag]:
|
||||
"""
|
||||
Adds the given tags and returns a new set.
|
||||
"""
|
||||
return set(sorted(tags | tags_add))
|
||||
|
||||
|
||||
def merge_tags(tags: set[Tag], tags_merge: list[set[Tag]]) -> set[Tag]:
|
||||
"""
|
||||
Merges the tags in 'tags_merge' into the current one and returns a new set.
|
||||
"""
|
||||
for ts in tags_merge:
|
||||
tags = tags | ts  # build a new set instead of mutating the caller's set
|
||||
return tags
|
||||
|
||||
|
||||
def rename_tags(tags: set[Tag], tags_rename: set[tuple[Tag, Tag]]) -> set[Tag]:
|
||||
"""
|
||||
Renames the given tags and returns a new set. The first tuple element
|
||||
is the old name, the second one is the new name.
|
||||
"""
|
||||
tags = set(tags)  # work on a copy so the caller's set is left untouched
for t in tags_rename:
|
||||
if t[0] in tags:
|
||||
tags.remove(t[0])
|
||||
tags.add(t[1])
|
||||
return set(sorted(tags))
|
||||
|
||||
|
||||
def match_tags(tags: set[Tag], tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]],
|
||||
tags_not: Optional[set[Tag]]) -> bool:
|
||||
"""
|
||||
Checks if the given set 'tags' matches the given tag requirements:
|
||||
- 'tags_or' : matches if this TagLine contains ANY of those tags
|
||||
- 'tags_and': matches if this TagLine contains ALL of those tags
|
||||
- 'tags_not': matches if this TagLine contains NONE of those tags
|
||||
|
||||
Note that it's sufficient if 'tags' matches one of 'tags_or' or 'tags_and',
|
||||
i. e. you can select a TagLine if it either contains one of the tags in 'tags_or'
|
||||
or all of the tags in 'tags_and' but it must never contain any of the tags in
|
||||
'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
|
||||
exclusion is still done if 'tags_not' is not 'None'). If they are empty (set()),
|
||||
they match no tags.
|
||||
"""
|
||||
required_tags_present = False
|
||||
excluded_tags_missing = False
|
||||
if ((tags_or is None and tags_and is None)
|
||||
or (tags_or and any(tag in tags for tag in tags_or)) # noqa: W503
|
||||
or (tags_and and all(tag in tags for tag in tags_and))): # noqa: W503
|
||||
required_tags_present = True
|
||||
if ((tags_not is None)
|
||||
or (not any(tag in tags for tag in tags_not))): # noqa: W503
|
||||
excluded_tags_missing = True
|
||||
return required_tags_present and excluded_tags_missing
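
The OR/AND/NOT rules described in the docstring can be checked with a compact sketch that uses plain `str` sets and set operators. `match_tags_sketch` is a hypothetical name for this example; it is intended to behave like the function above, not to replace it.

```python
from typing import Optional


def match_tags_sketch(tags: set[str], tags_or: Optional[set[str]],
                      tags_and: Optional[set[str]], tags_not: Optional[set[str]]) -> bool:
    """Same rules as above: (ANY of tags_or OR ALL of tags_and) AND NONE of tags_not."""
    required_present = ((tags_or is None and tags_and is None)
                        or bool(tags_or and tags_or & tags)        # at least one in common
                        or bool(tags_and and tags_and <= tags))    # all of them present
    excluded_missing = tags_not is None or not (tags_not & tags)
    return required_present and excluded_missing


tags = {'python', 'yaml'}
results = {
    'no filter': match_tags_sketch(tags, None, None, None),
    'or hit': match_tags_sketch(tags, {'python', 'rust'}, None, None),
    'and miss': match_tags_sketch(tags, None, {'python', 'rust'}, None),
    'or miss, and hit': match_tags_sketch(tags, {'rust'}, {'python', 'yaml'}, None),
    'excluded': match_tags_sketch(tags, {'python'}, None, {'yaml'}),
    'empty sets': match_tags_sketch(tags, set(), set(), None),
}
```

Note the difference between `None` (no requirement, everything matches) and an empty set (a requirement that nothing can satisfy).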
|
||||
|
||||
|
||||
class TagLine(str):
|
||||
"""
|
||||
A line of tags in a '.txt' file. It starts with a prefix ('TAGS:'), followed by
|
||||
a list of tags, separated by the default separator (' '). Any operations on a
|
||||
TagLine will sort the tags.
|
||||
"""
|
||||
# the prefix
|
||||
prefix: Final[str] = 'TAGS:'
|
||||
|
||||
def __new__(cls: Type[TagLineInst], string: str) -> TagLineInst:
|
||||
"""
|
||||
Make sure the tagline string starts with the prefix. Also replace newlines
|
||||
and multiple spaces with ' ', in order to support multiline TagLines.
|
||||
"""
|
||||
if not string.startswith(cls.prefix):
|
||||
raise TagError(f"TagLine '{string}' is missing prefix '{cls.prefix}'")
|
||||
string = ' '.join(string.split())
|
||||
instance = super().__new__(cls, string)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_set(cls: Type[TagLineInst], tags: set[Tag]) -> TagLineInst:
|
||||
"""
|
||||
Create a new TagLine from a set of tags.
|
||||
"""
|
||||
return cls(' '.join([cls.prefix] + sorted([t for t in tags])))
|
||||
|
||||
def tags(self, prefix: Optional[str] = None, contain: Optional[str] = None) -> set[Tag]:
|
||||
"""
|
||||
Returns all tags contained in this line as a set, optionally
|
||||
filtered based on prefix or contained string.
|
||||
"""
|
||||
tagstr = self[len(self.prefix):].strip()
|
||||
if tagstr == '':
|
||||
return set() # no tags, only prefix
|
||||
separator = Tag.default_separator
|
||||
# look for alternative separators and use the first one found
|
||||
# -> we don't support different separators in the same TagLine
|
||||
for s in Tag.alternative_separators:
|
||||
if s in tagstr:
|
||||
separator = s
|
||||
break
|
||||
res_tags = set(sorted([Tag(t.strip()) for t in tagstr.split(separator)]))
|
||||
if prefix and len(prefix) > 0:
|
||||
res_tags -= {tag for tag in res_tags if not tag.startswith(prefix)}
|
||||
if contain and len(contain) > 0:
|
||||
res_tags -= {tag for tag in res_tags if contain not in tag}
|
||||
return res_tags or set()
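
To make the separator handling concrete, here is a simplified sketch of the parsing steps: normalise whitespace, strip the prefix, pick ',' if it occurs and ' ' otherwise, then split. `parse_tag_line` is a hypothetical function for illustration; it uses plain strings instead of `Tag` and raises `ValueError` where the real code raises `TagError`.

```python
def parse_tag_line(line: str, prefix: str = 'TAGS:') -> set[str]:
    """Split a tag line into a set of tags, supporting ' ' and ',' as separators."""
    if not line.startswith(prefix):
        raise ValueError(f"tag line '{line}' is missing prefix '{prefix}'")
    # collapse newlines and runs of spaces, as TagLine.__new__ does
    tagstr = ' '.join(line.split())[len(prefix):].strip()
    if tagstr == '':
        return set()  # prefix only, no tags
    # only one separator is supported per line; ',' wins if it occurs at all
    separator = ',' if ',' in tagstr else ' '
    tags = {t.strip() for t in tagstr.split(separator)}
    for tag in tags:
        if ' ' in tag:
            raise ValueError(f"tag '{tag}' contains the separator char ' '")
    return tags


space_separated = parse_tag_line('TAGS: tag1 tag2\n')
comma_separated = parse_tag_line('TAGS: tag1, tag2\n')
multiline = parse_tag_line('TAGS: tag1\n   tag2')
empty = parse_tag_line('TAGS:\n')
try:
    parse_tag_line('TAGS: tag1, tag2 tag3')  # mixes both separators
    mixed_rejected = False
except ValueError:
    mixed_rejected = True
```

Mixing both separators in one line is rejected because the comma split leaves a tag that still contains a space.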
|
||||
|
||||
def merge(self, taglines: set['TagLine']) -> 'TagLine':
|
||||
"""
|
||||
Merges the tags of all given taglines into the current one and returns a new TagLine.
|
||||
"""
|
||||
tags_merge = [tl.tags() for tl in taglines]
|
||||
return self.from_set(merge_tags(self.tags(), tags_merge))
|
||||
|
||||
def delete_tags(self, tags_delete: set[Tag]) -> 'TagLine':
|
||||
"""
|
||||
Deletes the given tags and returns a new TagLine.
|
||||
"""
|
||||
return self.from_set(delete_tags(self.tags(), tags_delete))
|
||||
|
||||
def add_tags(self, tags_add: set[Tag]) -> 'TagLine':
|
||||
"""
|
||||
Adds the given tags and returns a new TagLine.
|
||||
"""
|
||||
return self.from_set(add_tags(self.tags(), tags_add))
|
||||
|
||||
def rename_tags(self, tags_rename: set[tuple[Tag, Tag]]) -> 'TagLine':
|
||||
"""
|
||||
Renames the given tags and returns a new TagLine. The first
|
||||
tuple element is the old name, the second one is the new name.
|
||||
"""
|
||||
return self.from_set(rename_tags(self.tags(), tags_rename))
|
||||
|
||||
def match_tags(self, tags_or: Optional[set[Tag]], tags_and: Optional[set[Tag]],
|
||||
tags_not: Optional[set[Tag]]) -> bool:
|
||||
"""
|
||||
Checks if the current TagLine matches the given tag requirements:
|
||||
- 'tags_or' : matches if this TagLine contains ANY of those tags
|
||||
- 'tags_and': matches if this TagLine contains ALL of those tags
|
||||
- 'tags_not': matches if this TagLine contains NONE of those tags
|
||||
|
||||
Note that it's sufficient if the TagLine matches one of 'tags_or' or 'tags_and',
|
||||
i.e. you can select a TagLine if it either contains one of the tags in 'tags_or'
|
||||
or all of the tags in 'tags_and' but it must never contain any of the tags in
|
||||
'tags_not'. If 'tags_or' and 'tags_and' are 'None', they match all tags (tag
|
||||
exclusion is still done if 'tags_not' is not 'None').
|
||||
"""
|
||||
return match_tags(self.tags(), tags_or, tags_and, tags_not)
|
||||
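The matching rules described in the `match_tags` docstring can be illustrated with a minimal sketch. This is not the project's `match_tags` helper, whose body is not part of this diff: `match_tags_sketch` is a hypothetical name, tags are plain strings rather than `Tag` objects, and the handling of empty sets (they match nothing) is an assumption based on `test_chat_db_from_dir_filter_tags_empty`.

```python
from typing import Optional


def match_tags_sketch(tags: set[str],
                      tags_or: Optional[set[str]] = None,
                      tags_and: Optional[set[str]] = None,
                      tags_not: Optional[set[str]] = None) -> bool:
    # Exclusion always wins: any tag from 'tags_not' rejects the line.
    if tags_not and tags & tags_not:
        return False
    # No inclusion requirements at all: everything not excluded matches.
    if tags_or is None and tags_and is None:
        return True
    # Either ANY of 'tags_or' or ALL of 'tags_and' is sufficient.
    # Assumption: an empty set is a requirement that nothing satisfies.
    any_hit = tags_or is not None and len(tags & tags_or) > 0
    all_hit = tags_and is not None and len(tags_and) > 0 and tags_and <= tags
    return any_hit or all_hit


line_tags = {"atag1", "btag2"}
print(match_tags_sketch(line_tags, tags_or={"atag1", "other"}))   # True
print(match_tags_sketch(line_tags, tags_and={"atag1", "other"}))  # False
print(match_tags_sketch(line_tags, tags_or={"atag1"}, tags_not={"btag2"}))  # False
```

Checking the exclusion first keeps the "must never contain" rule independent of how the inclusion requirements are combined.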
@ -1,82 +0,0 @@
|
||||
import shutil
|
||||
from pprint import PrettyPrinter
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def terminal_width() -> int:
|
||||
return shutil.get_terminal_size().columns
|
||||
|
||||
|
||||
def pp(*args, **kwargs) -> None:
|
||||
return PrettyPrinter(width=terminal_width()).pprint(*args, **kwargs)
|
||||
|
||||
|
||||
def process_tags(tags: list[str], extags: list[str], otags: list[str]) -> None:
|
||||
printed_messages = []
|
||||
|
||||
if tags:
|
||||
printed_messages.append(f"Tags: {', '.join(tags)}")
|
||||
if extags:
|
||||
printed_messages.append(f"Excluding tags: {', '.join(extags)}")
|
||||
if otags:
|
||||
printed_messages.append(f"Output tags: {', '.join(otags)}")
|
||||
|
||||
if printed_messages:
|
||||
print("\n".join(printed_messages))
|
||||
print()
|
||||
|
||||
|
||||
def append_message(chat: List[Dict[str, str]],
|
||||
role: str,
|
||||
content: str
|
||||
) -> None:
|
||||
chat.append({'role': role, 'content': content.replace("''", "'")})
|
||||
|
||||
|
||||
def message_to_chat(message: Dict[str, str],
|
||||
chat: List[Dict[str, str]],
|
||||
with_tags: bool = False,
|
||||
with_file: bool = False
|
||||
) -> None:
|
||||
append_message(chat, 'user', message['question'])
|
||||
append_message(chat, 'assistant', message['answer'])
|
||||
if with_tags:
|
||||
tags = ", ".join(message['tags'])
|
||||
append_message(chat, 'tags', tags)
|
||||
if with_file:
|
||||
append_message(chat, 'file', message['file'])
|
||||
|
||||
|
||||
def display_source_code(content: str) -> None:
|
||||
try:
|
||||
content_start = content.index('```')
|
||||
content_end = content.rindex('```')
|
||||
if content_start + 3 < content_end:
|
||||
print(content[content_start + 3:content_end].strip())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def display_chat(chat, dump=False, source_code=False) -> None:
|
||||
if dump:
|
||||
pp(chat)
|
||||
return
|
||||
for message in chat:
|
||||
text_too_long = len(message['content']) > terminal_width() - len(message['role']) - 2
|
||||
if source_code:
|
||||
display_source_code(message['content'])
|
||||
continue
|
||||
if message['role'] == 'user':
|
||||
print('-' * terminal_width())
|
||||
if text_too_long:
|
||||
print(f"{message['role'].upper()}:")
|
||||
print(message['content'])
|
||||
else:
|
||||
print(f"{message['role'].upper()}: {message['content']}")
|
||||
|
||||
def display_tags_frequency(tags: List[str], dump=False) -> None:
|
||||
if dump:
|
||||
pp(tags)
|
||||
return
|
||||
for tag in set(tags):
|
||||
print(f"-{tag} : {tags.count(tag)}")
|
||||
56
hooks/gitea_cmm_hook.php
Normal file
@ -0,0 +1,56 @@
|
||||
<?php
|
||||
|
||||
$secret_key = '123';
|
||||
|
||||
// check for POST request
|
||||
if ($_SERVER['REQUEST_METHOD'] != 'POST') {
|
||||
error_log('FAILED - not POST - '. $_SERVER['REQUEST_METHOD']);
|
||||
exit();
|
||||
}
|
||||
|
||||
// get content type
|
||||
$content_type = isset($_SERVER['CONTENT_TYPE']) ? strtolower(trim($_SERVER['CONTENT_TYPE'])) : '';
|
||||
|
||||
if ($content_type != 'application/json') {
|
||||
error_log('FAILED - not application/json - '. $content_type);
|
||||
exit();
|
||||
}
|
||||
|
||||
// get payload
|
||||
$payload = trim(file_get_contents("php://input"));
|
||||
|
||||
if (empty($payload)) {
|
||||
error_log('FAILED - no payload');
|
||||
exit();
|
||||
}
|
||||
|
||||
// get header signature
|
||||
$header_signature = isset($_SERVER['HTTP_X_GITEA_SIGNATURE']) ? $_SERVER['HTTP_X_GITEA_SIGNATURE'] : '';
|
||||
|
||||
if (empty($header_signature)) {
|
||||
error_log('FAILED - header signature missing');
|
||||
exit();
|
||||
}
|
||||
|
||||
// calculate payload signature
|
||||
$payload_signature = hash_hmac('sha256', $payload, $secret_key, false);
|
||||
|
||||
// check payload signature against header signature
|
||||
if (!hash_equals($payload_signature, $header_signature)) {
|
||||
error_log('FAILED - payload signature');
|
||||
exit();
|
||||
}
|
||||
|
||||
// convert json to array
|
||||
$decoded = json_decode($payload, true);
|
||||
|
||||
// check for json decode errors
|
||||
if (json_last_error() !== JSON_ERROR_NONE) {
|
||||
error_log('FAILED - json decode - '. json_last_error());
|
||||
exit();
|
||||
}
|
||||
|
||||
// success, do something
|
||||
$output = shell_exec('/home/kaizen/repos/ChatMastermind/hooks/push_hook.sh');
|
||||
echo "$output";
|
||||
?>
|
||||
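The signature check performed by the hook (a hex-encoded HMAC-SHA256 over the raw request body, compared with the `X-Gitea-Signature` header) can be sketched in Python as follows. This is an illustration under assumptions, not part of the commit: `signature_is_valid` is a hypothetical name, the sample body is made up, and the comparison uses a constant-time function, which is a common choice for this kind of check.

```python
import hashlib
import hmac


def signature_is_valid(payload: bytes, secret: str, header_signature: str) -> bool:
    # Hex-encoded HMAC-SHA256 of the raw request body, keyed with the shared secret.
    expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
    # compare_digest takes time independent of where the two strings differ.
    return hmac.compare_digest(expected, header_signature)


# Hypothetical request body and the signature a sender would attach to it.
body = b'{"ref": "refs/heads/main"}'
good = hmac.new(b"123", body, hashlib.sha256).hexdigest()

print(signature_is_valid(body, "123", good))      # True
print(signature_is_valid(body, "123", "0" * 64))  # False
```

The signature has to be computed over exactly the bytes that were received, before any JSON decoding.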
8
hooks/push_hook.sh
Executable file
@ -0,0 +1,8 @@
#!/usr/bin/bash

. /home/kaizen/.bashrc
set -e
cd /home/kaizen/repos/ChatMastermind
git pull
pre-commit run -a
pytest
1
mypy.ini
@ -5,3 +5,4 @@ strict_optional = True
warn_unused_ignores = False
warn_redundant_casts = True
warn_unused_configs = True
disallow_untyped_defs = True
14
setup.py
@ -12,23 +12,29 @@ setup(
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/ok2/ChatMastermind",
|
||||
packages=find_packages(),
|
||||
packages=find_packages() + ["chatmastermind.ais", "chatmastermind.commands"],
|
||||
classifiers=[
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Console",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Intended Audience :: End Users/Desktop",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Utilities",
|
||||
"Topic :: Text Processing",
|
||||
],
|
||||
install_requires=[
|
||||
"openai",
|
||||
"PyYAML",
|
||||
"argcomplete",
|
||||
"pytest"
|
||||
"pytest",
|
||||
],
|
||||
python_requires=">=3.10",
|
||||
python_requires=">=3.9",
|
||||
test_suite="tests",
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
|
||||
48
tests/test_ai_factory.py
Normal file
@ -0,0 +1,48 @@
|
||||
import argparse
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from chatmastermind.ai_factory import create_ai
|
||||
from chatmastermind.configuration import Config
|
||||
from chatmastermind.ai import AIError
|
||||
from chatmastermind.ais.openai import OpenAI
|
||||
|
||||
|
||||
class TestCreateAI(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.args = MagicMock(spec=argparse.Namespace)
|
||||
self.args.AI = 'myopenai'
|
||||
self.args.model = None
|
||||
self.args.max_tokens = None
|
||||
self.args.temperature = None
|
||||
|
||||
def test_create_ai_from_args(self) -> None:
|
||||
# Create an AI with the default configuration
|
||||
config = Config()
|
||||
self.args.AI = 'myopenai'
|
||||
ai = create_ai(self.args, config)
|
||||
self.assertIsInstance(ai, OpenAI)
|
||||
|
||||
def test_create_ai_from_default(self) -> None:
|
||||
self.args.AI = None
|
||||
# Create an AI with the default configuration
|
||||
config = Config()
|
||||
ai = create_ai(self.args, config)
|
||||
self.assertIsInstance(ai, OpenAI)
|
||||
|
||||
def test_create_empty_ai_error(self) -> None:
|
||||
self.args.AI = None
|
||||
# Create Config with empty AIs
|
||||
config = Config()
|
||||
config.ais = {}
|
||||
# Call create_ai function and assert that it raises AIError
|
||||
with self.assertRaises(AIError):
|
||||
create_ai(self.args, config)
|
||||
|
||||
def test_create_unsupported_ai_error(self) -> None:
|
||||
# Mock argparse.Namespace with AI='invalid_ai'
|
||||
self.args.AI = 'invalid_ai'
|
||||
# Create default Config
|
||||
config = Config()
|
||||
# Call create_ai function and assert that it raises AIError
|
||||
with self.assertRaises(AIError):
|
||||
create_ai(self.args, config)
|
||||
81
tests/test_ais_openai.py
Normal file
@ -0,0 +1,81 @@
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from chatmastermind.ais.openai import OpenAI
|
||||
from chatmastermind.message import Message, Question, Answer
|
||||
from chatmastermind.chat import Chat
|
||||
from chatmastermind.ai import AIResponse, Tokens
|
||||
from chatmastermind.configuration import OpenAIConfig
|
||||
|
||||
|
||||
class OpenAITest(unittest.TestCase):
|
||||
|
||||
@mock.patch('openai.ChatCompletion.create')
|
||||
def test_request(self, mock_create: mock.MagicMock) -> None:
|
||||
# Create a test instance of OpenAI
|
||||
config = OpenAIConfig()
|
||||
openai = OpenAI(config)
|
||||
|
||||
# Set up the mock response from openai.ChatCompletion.create
|
||||
mock_response = {
|
||||
'choices': [
|
||||
{
|
||||
'message': {
|
||||
'content': 'Answer 1'
|
||||
}
|
||||
},
|
||||
{
|
||||
'message': {
|
||||
'content': 'Answer 2'
|
||||
}
|
||||
}
|
||||
],
|
||||
'usage': {
|
||||
'prompt_tokens': 10,
|
||||
'completion_tokens': 20,
|
||||
'total_tokens': 30
|
||||
}
|
||||
}
|
||||
mock_create.return_value = mock_response
|
||||
|
||||
# Create test data
|
||||
question = Message(Question('Question'))
|
||||
chat = Chat([
|
||||
Message(Question('Question 1'), answer=Answer('Answer 1')),
|
||||
Message(Question('Question 2'), answer=Answer('Answer 2')),
|
||||
# add message without an answer -> expect to be skipped
|
||||
Message(Question('Question 3'))
|
||||
])
|
||||
|
||||
# Make the request
|
||||
response = openai.request(question, chat, num_answers=2)
|
||||
|
||||
# Assert the AIResponse
|
||||
self.assertIsInstance(response, AIResponse)
|
||||
self.assertEqual(len(response.messages), 2)
|
||||
self.assertEqual(response.messages[0].answer, 'Answer 1')
|
||||
self.assertEqual(response.messages[1].answer, 'Answer 2')
|
||||
self.assertIsNotNone(response.tokens)
|
||||
self.assertIsInstance(response.tokens, Tokens)
|
||||
assert response.tokens
|
||||
self.assertEqual(response.tokens.prompt, 10)
|
||||
self.assertEqual(response.tokens.completion, 20)
|
||||
self.assertEqual(response.tokens.total, 30)
|
||||
|
||||
# Assert the mock call to openai.ChatCompletion.create
|
||||
mock_create.assert_called_once_with(
|
||||
model=f'{config.model}',
|
||||
messages=[
|
||||
{'role': 'system', 'content': f'{config.system}'},
|
||||
{'role': 'user', 'content': 'Question 1'},
|
||||
{'role': 'assistant', 'content': 'Answer 1'},
|
||||
{'role': 'user', 'content': 'Question 2'},
|
||||
{'role': 'assistant', 'content': 'Answer 2'},
|
||||
{'role': 'user', 'content': 'Question'}
|
||||
],
|
||||
temperature=config.temperature,
|
||||
max_tokens=config.max_tokens,
|
||||
top_p=config.top_p,
|
||||
n=2,
|
||||
frequency_penalty=config.frequency_penalty,
|
||||
presence_penalty=config.presence_penalty
|
||||
)
|
||||
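The `messages` list asserted in `test_request` implies a simple flattening of the chat history: one system entry, then a user/assistant pair for every answered message, then the new question, with unanswered messages skipped. A minimal sketch of that flattening follows. `build_messages` is a hypothetical helper, not the code in `chatmastermind.ais.openai`, and the history is modelled as plain `(question, answer)` tuples instead of `Message` objects.

```python
from typing import Optional


def build_messages(system: str,
                   history: list[tuple[str, Optional[str]]],
                   question: str) -> list[dict[str, str]]:
    # The system prompt always comes first.
    messages = [{'role': 'system', 'content': system}]
    for asked, answered in history:
        # A message without an answer contributes nothing to the context.
        if answered is None:
            continue
        messages.append({'role': 'user', 'content': asked})
        messages.append({'role': 'assistant', 'content': answered})
    # The new question is the final user entry.
    messages.append({'role': 'user', 'content': question})
    return messages


history = [('Question 1', 'Answer 1'),
           ('Question 2', 'Answer 2'),
           ('Question 3', None)]
for entry in build_messages('system prompt', history, 'Question'):
    print(entry)
```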
476
tests/test_chat.py
Normal file
@ -0,0 +1,476 @@
|
||||
import unittest
|
||||
import pathlib
|
||||
import tempfile
|
||||
import time
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
from chatmastermind.tags import TagLine
|
||||
from chatmastermind.message import Message, Question, Answer, Tag, MessageFilter
|
||||
from chatmastermind.chat import Chat, ChatDB, ChatError
|
||||
|
||||
|
||||
class TestChat(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.chat = Chat([])
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('btag2')},
|
||||
file_path=pathlib.Path('0002.txt'))
|
||||
|
||||
def test_filter(self) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.filter(MessageFilter(answer_contains='Answer 1'))
|
||||
|
||||
self.assertEqual(len(self.chat.messages), 1)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
|
||||
def test_sort(self) -> None:
|
||||
self.chat.add_messages([self.message2, self.message1])
|
||||
self.chat.sort()
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
self.chat.sort(reverse=True)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 2')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 1')
|
||||
|
||||
def test_clear(self) -> None:
|
||||
self.chat.add_messages([self.message1])
|
||||
self.chat.clear()
|
||||
self.assertEqual(len(self.chat.messages), 0)
|
||||
|
||||
def test_add_messages(self) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.assertEqual(len(self.chat.messages), 2)
|
||||
self.assertEqual(self.chat.messages[0].question, 'Question 1')
|
||||
self.assertEqual(self.chat.messages[1].question, 'Question 2')
|
||||
|
||||
def test_tags(self) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
tags_all = self.chat.tags()
|
||||
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
|
||||
tags_pref = self.chat.tags(prefix='a')
|
||||
self.assertSetEqual(tags_pref, {Tag('atag1')})
|
||||
tags_cont = self.chat.tags(contain='2')
|
||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||
|
||||
def test_tags_frequency(self) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
tags_freq = self.chat.tags_frequency()
|
||||
self.assertDictEqual(tags_freq, {'atag1': 1, 'btag2': 2})
|
||||
|
||||
def test_find_remove_messages(self) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
msgs = self.chat.find_messages(['0001.txt'])
|
||||
self.assertListEqual(msgs, [self.message1])
|
||||
msgs = self.chat.find_messages(['0001.txt', '0002.txt'])
|
||||
self.assertListEqual(msgs, [self.message1, self.message2])
|
||||
# add new Message with full path
|
||||
message3 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('btag2')},
|
||||
file_path=pathlib.Path('/foo/bla/0003.txt'))
|
||||
self.chat.add_messages([message3])
|
||||
# find new Message by full path
|
||||
msgs = self.chat.find_messages(['/foo/bla/0003.txt'])
|
||||
self.assertListEqual(msgs, [message3])
|
||||
# find Message with full path only by filename
|
||||
msgs = self.chat.find_messages(['0003.txt'])
|
||||
self.assertListEqual(msgs, [message3])
|
||||
# remove last message
|
||||
self.chat.remove_messages(['0003.txt'])
|
||||
self.assertListEqual(self.chat.messages, [self.message1, self.message2])
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.print(paged=False)
|
||||
expected_output = f"""{Question.txt_header}
|
||||
Question 1
|
||||
{Answer.txt_header}
|
||||
Answer 1
|
||||
{Question.txt_header}
|
||||
Question 2
|
||||
{Answer.txt_header}
|
||||
Answer 2
|
||||
"""
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
@patch('sys.stdout', new_callable=StringIO)
|
||||
def test_print_with_tags_and_file(self, mock_stdout: StringIO) -> None:
|
||||
self.chat.add_messages([self.message1, self.message2])
|
||||
self.chat.print(paged=False, with_tags=True, with_files=True)
|
||||
expected_output = f"""{TagLine.prefix} atag1 btag2
|
||||
FILE: 0001.txt
|
||||
{Question.txt_header}
|
||||
Question 1
|
||||
{Answer.txt_header}
|
||||
Answer 1
|
||||
{TagLine.prefix} btag2
|
||||
FILE: 0002.txt
|
||||
{Question.txt_header}
|
||||
Question 2
|
||||
{Answer.txt_header}
|
||||
Answer 2
|
||||
"""
|
||||
self.assertEqual(mock_stdout.getvalue(), expected_output)
|
||||
|
||||
|
||||
class TestChatDB(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.db_path = tempfile.TemporaryDirectory()
|
||||
self.cache_path = tempfile.TemporaryDirectory()
|
||||
|
||||
self.message1 = Message(Question('Question 1'),
|
||||
Answer('Answer 1'),
|
||||
{Tag('tag1')},
|
||||
file_path=pathlib.Path('0001.txt'))
|
||||
self.message2 = Message(Question('Question 2'),
|
||||
Answer('Answer 2'),
|
||||
{Tag('tag2')},
|
||||
file_path=pathlib.Path('0002.yaml'))
|
||||
self.message3 = Message(Question('Question 3'),
|
||||
Answer('Answer 3'),
|
||||
{Tag('tag3')},
|
||||
file_path=pathlib.Path('0003.txt'))
|
||||
self.message4 = Message(Question('Question 4'),
|
||||
Answer('Answer 4'),
|
||||
{Tag('tag4')},
|
||||
file_path=pathlib.Path('0004.yaml'))
|
||||
|
||||
self.message1.to_file(pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.message2.to_file(pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.message3.to_file(pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.message4.to_file(pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
# make the next FID match the current state
|
||||
next_fname = pathlib.Path(self.db_path.name) / '.next'
|
||||
with open(next_fname, 'w') as f:
|
||||
f.write('4')
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[pathlib.Path]:
|
||||
"""
|
||||
List all Message files in the given TemporaryDirectory.
|
||||
"""
|
||||
# exclude '.next'
|
||||
return list(pathlib.Path(tmp_dir.name).glob('*.[ty]*'))
|
||||
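The pattern `'*.[ty]*'` used in `message_list` only matches names in which a dot is directly followed by `t` or `y`, which is why `.txt` and `.yaml` files are listed while `.next` is not. The following stand-alone sketch demonstrates this with a throwaway directory; the file names mirror the ones in `setUp`, and the file contents are placeholders.

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    # Two message-like files plus the '.next' counter file.
    for name in ('0001.txt', '0002.yaml', '.next'):
        (root / name).write_text('x')
    # '[ty]' requires 't' or 'y' right after the dot, so '.next' is skipped.
    matched = sorted(p.name for p in root.glob('*.[ty]*'))

print(matched)  # ['0001.txt', '0002.yaml']
```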
|
||||
def tearDown(self) -> None:
|
||||
self.db_path.cleanup()
|
||||
self.cache_path.cleanup()
|
||||
pass
|
||||
|
||||
def test_chat_db_from_dir(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(len(chat_db.messages), 4)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
# check that the files are sorted
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path,
|
||||
pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_chat_db_from_dir_glob(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
glob='*.txt')
|
||||
self.assertEqual(len(chat_db.messages), 2)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path,
|
||||
pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
|
||||
def test_chat_db_from_dir_filter_tags(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertEqual(len(chat_db.messages), 1)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
|
||||
def test_chat_db_from_dir_filter_tags_empty(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(tags_or=set(),
|
||||
tags_and=set(),
|
||||
tags_not=set()))
|
||||
self.assertEqual(len(chat_db.messages), 0)
|
||||
|
||||
def test_chat_db_from_dir_filter_answer(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
mfilter=MessageFilter(answer_contains='Answer 2'))
|
||||
self.assertEqual(len(chat_db.messages), 1)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.messages[0].file_path,
|
||||
pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[0].answer, 'Answer 2')
|
||||
|
||||
def test_chat_db_from_messages(self) -> None:
|
||||
chat_db = ChatDB.from_messages(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name),
|
||||
messages=[self.message1, self.message2,
|
||||
self.message3, self.message4])
|
||||
self.assertEqual(len(chat_db.messages), 4)
|
||||
self.assertEqual(chat_db.cache_path, pathlib.Path(self.cache_path.name))
|
||||
self.assertEqual(chat_db.db_path, pathlib.Path(self.db_path.name))
|
||||
|
||||
def test_chat_db_fids(self) -> None:
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(chat_db.get_next_fid(), 5)
|
||||
self.assertEqual(chat_db.get_next_fid(), 6)
|
||||
self.assertEqual(chat_db.get_next_fid(), 7)
|
||||
with open(chat_db.next_fname, 'r') as f:
|
||||
self.assertEqual(f.read(), '7')
|
||||
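The behaviour checked in `test_chat_db_fids` (a `.next` file that holds the last ID handed out, incremented on every call) can be reproduced with a small file-backed counter. This is a sketch under assumptions, not the implementation of `ChatDB.get_next_fid`: the free function below and its fallback to `0` for a missing file are hypothetical.

```python
import pathlib
import tempfile


def get_next_fid(next_file: pathlib.Path) -> int:
    # Read the last used ID (assumed to be 0 if the file does not exist yet).
    last = int(next_file.read_text()) if next_file.exists() else 0
    # Persist the incremented value before handing it out.
    next_file.write_text(str(last + 1))
    return last + 1


with tempfile.TemporaryDirectory() as tmp:
    next_file = pathlib.Path(tmp) / '.next'
    next_file.write_text('4')  # same starting state as in setUp
    fids = [get_next_fid(next_file) for _ in range(3)]
    stored = next_file.read_text()

print(fids, stored)  # [5, 6, 7] 7
```

Storing the counter on disk means IDs stay unique across program runs, at the cost of one small write per new ID.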
|
||||
def test_chat_db_write(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# check that Message.file_path is correct
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0001.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0002.yaml'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0003.txt'), cache_dir_files)
|
||||
self.assertIn(pathlib.Path(self.cache_path.name, '0004.yaml'), cache_dir_files)
|
||||
# check that Message.file_path has been correctly updated
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.cache_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.cache_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.cache_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.cache_path.name, '0004.yaml'))
|
||||
|
||||
# check the timestamp of the files in the DB directory
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
old_timestamps = {file: file.stat().st_mtime for file in db_dir_files}
|
||||
# overwrite the messages in the db directory
|
||||
time.sleep(0.05)
|
||||
chat_db.write_db()
|
||||
# check if the written files are in the DB directory
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
||||
# check if all files in the DB dir have actually been overwritten
|
||||
for file in db_dir_files:
|
||||
self.assertGreater(file.stat().st_mtime, old_timestamps[file])
|
||||
# check that Message.file_path has been correctly updated (again)
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
def test_chat_db_read(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
self.assertEqual(len(chat_db.messages), 4)
|
||||
|
||||
# create 2 new files in the DB directory
|
||||
new_message1 = Message(Question('Question 5'),
|
||||
Answer('Answer 5'),
|
||||
{Tag('tag5')})
|
||||
new_message2 = Message(Question('Question 6'),
|
||||
Answer('Answer 6'),
|
||||
{Tag('tag6')})
|
||||
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
# read and check them
|
||||
chat_db.read_db()
|
||||
self.assertEqual(len(chat_db.messages), 6)
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
|
||||
# create 2 new files in the cache directory
|
||||
new_message3 = Message(Question('Question 7'),
|
||||
Answer('Answer 5'),
|
||||
{Tag('tag7')})
|
||||
new_message4 = Message(Question('Question 8'),
|
||||
Answer('Answer 6'),
|
||||
{Tag('tag8')})
|
||||
new_message3.to_file(pathlib.Path(self.cache_path.name, '0007.txt'))
|
||||
new_message4.to_file(pathlib.Path(self.cache_path.name, '0008.yaml'))
|
||||
# read and check them
|
||||
chat_db.read_cache()
|
||||
self.assertEqual(len(chat_db.messages), 8)
|
||||
# check that the new messages have the cache dir path
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.cache_path.name, '0007.txt'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.cache_path.name, '0008.yaml'))
|
||||
# and the old ones keep their path (since they have not been replaced)
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
|
||||
# now overwrite two messages in the DB directory
|
||||
new_message1.question = Question('New Question 1')
|
||||
new_message2.question = Question('New Question 2')
|
||||
new_message1.to_file(pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
new_message2.to_file(pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
# read from the DB dir and check if the modified messages have been updated
|
||||
chat_db.read_db()
|
||||
self.assertEqual(len(chat_db.messages), 8)
|
||||
self.assertEqual(chat_db.messages[4].question, 'New Question 1')
|
||||
self.assertEqual(chat_db.messages[5].question, 'New Question 2')
|
||||
self.assertEqual(chat_db.messages[4].file_path, pathlib.Path(self.db_path.name, '0005.txt'))
|
||||
self.assertEqual(chat_db.messages[5].file_path, pathlib.Path(self.db_path.name, '0006.yaml'))
|
||||
|
||||
# now write the messages from the cache to the DB directory
|
||||
new_message3.to_file(pathlib.Path(self.db_path.name, '0007.txt'))
|
||||
new_message4.to_file(pathlib.Path(self.db_path.name, '0008.yaml'))
|
||||
# read and check them
|
||||
chat_db.read_db()
|
||||
self.assertEqual(len(chat_db.messages), 8)
|
||||
# check that they now have the DB path
|
||||
self.assertEqual(chat_db.messages[6].file_path, pathlib.Path(self.db_path.name, '0007.txt'))
|
||||
self.assertEqual(chat_db.messages[7].file_path, pathlib.Path(self.db_path.name, '0008.yaml'))
|
||||
|
||||
def test_chat_db_clear(self) -> None:
|
||||
# create a new ChatDB instance
|
||||
chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
|
||||
pathlib.Path(self.db_path.name))
|
||||
# check that Message.file_path is correct
|
||||
self.assertEqual(chat_db.messages[0].file_path, pathlib.Path(self.db_path.name, '0001.txt'))
|
||||
self.assertEqual(chat_db.messages[1].file_path, pathlib.Path(self.db_path.name, '0002.yaml'))
|
||||
self.assertEqual(chat_db.messages[2].file_path, pathlib.Path(self.db_path.name, '0003.txt'))
|
||||
self.assertEqual(chat_db.messages[3].file_path, pathlib.Path(self.db_path.name, '0004.yaml'))
|
||||
|
||||
# write the messages to the cache directory
|
||||
chat_db.write_cache()
|
||||
# check if the written files are in the cache directory
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 4)
|
||||
|
||||
# now rewrite them to the DB dir and check for modified paths
|
||||
chat_db.write_db()
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0001.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0002.yaml'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0003.txt'), db_dir_files)
|
||||
self.assertIn(pathlib.Path(self.db_path.name, '0004.yaml'), db_dir_files)
|
||||
|
||||
# add a new message with empty file_path
|
||||
message_empty = Message(question=Question("What the hell am I doing here?"),
|
||||
answer=Answer("You don't belong here!"))
|
||||
# and one for the cache dir
|
||||
message_cache = Message(question=Question("What the hell am I doing here?"),
|
||||
answer=Answer("You're a creep!"),
|
||||
file_path=pathlib.Path(self.cache_path.name, '0005.txt'))
|
||||
chat_db.add_messages([message_empty, message_cache])
|
||||
|
||||
# clear the cache and check the cache dir
|
||||
chat_db.clear_cache()
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 0)
|
||||
# make sure that the DB messages (and the new message) are still there
|
||||
self.assertEqual(len(chat_db.messages), 5)
|
||||
db_dir_files = self.message_list(self.db_path)
|
||||
self.assertEqual(len(db_dir_files), 4)
|
||||
# but not the message with the cache dir path
|
||||
self.assertFalse(any(m.file_path == message_cache.file_path for m in chat_db.messages))
|
||||
|
||||
    def test_chat_db_add(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)

        # add new messages to the cache dir
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        chat_db.add_to_cache([message1])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message1.file_path)
        self.assertEqual(message1.file_path.parent, pathlib.Path(self.cache_path.name))  # type: ignore [union-attr]
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)

        # add new messages to the DB dir
        message2 = Message(question=Question("Question 2"),
                           answer=Answer("Answer 2"))
        chat_db.add_to_db([message2])
        # check if the file_path has been correctly set
        self.assertIsNotNone(message2.file_path)
        self.assertEqual(message2.file_path.parent, pathlib.Path(self.db_path.name))  # type: ignore [union-attr]
        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 5)

        with self.assertRaises(ChatError):
            chat_db.add_to_cache([Message(Question("?"), file_path=pathlib.Path("foo"))])

    def test_chat_db_write_messages(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        # try to write a message without a valid file_path
        message = Message(question=Question("Question 1"),
                          answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.write_messages([message])

        # write a message with a valid file_path
        message.file_path = pathlib.Path(self.cache_path.name) / '123456.txt'
        chat_db.write_messages([message])
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 1)
        self.assertIn(pathlib.Path(self.cache_path.name, '123456.txt'), cache_dir_files)

    def test_chat_db_update_messages(self) -> None:
        # create a new ChatDB instance
        chat_db = ChatDB.from_dir(pathlib.Path(self.cache_path.name),
                                  pathlib.Path(self.db_path.name))

        db_dir_files = self.message_list(self.db_path)
        self.assertEqual(len(db_dir_files), 4)
        cache_dir_files = self.message_list(self.cache_path)
        self.assertEqual(len(cache_dir_files), 0)

        message = chat_db.messages[0]
        message.answer = Answer("New answer")
        # update message without writing
        chat_db.update_messages([message], write=False)
        self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
        # re-read the message and check for old content
        chat_db.read_db()
        self.assertEqual(chat_db.messages[0].answer, Answer("Answer 1"))
        # now check with writing (message should be overwritten)
        chat_db.update_messages([message], write=True)
        chat_db.read_db()
        self.assertEqual(chat_db.messages[0].answer, Answer("New answer"))
        # test without file_path -> expect error
        message1 = Message(question=Question("Question 1"),
                           answer=Answer("Answer 1"))
        with self.assertRaises(ChatError):
            chat_db.update_messages([message1])
160
tests/test_configuration.py
Normal file
@@ -0,0 +1,160 @@
import os
import unittest
import yaml
from tempfile import NamedTemporaryFile
from pathlib import Path
from typing import cast
from chatmastermind.configuration import AIConfig, OpenAIConfig, ConfigError, ai_config_instance, Config


class TestAIConfigInstance(unittest.TestCase):
    def test_ai_config_instance_with_valid_name_should_return_instance_with_default_values(self) -> None:
        ai_config = cast(OpenAIConfig, ai_config_instance('openai'))
        ai_reference = OpenAIConfig()
        self.assertEqual(ai_config.ID, ai_reference.ID)
        self.assertEqual(ai_config.name, ai_reference.name)
        self.assertEqual(ai_config.api_key, ai_reference.api_key)
        self.assertEqual(ai_config.system, ai_reference.system)
        self.assertEqual(ai_config.model, ai_reference.model)
        self.assertEqual(ai_config.temperature, ai_reference.temperature)
        self.assertEqual(ai_config.max_tokens, ai_reference.max_tokens)
        self.assertEqual(ai_config.top_p, ai_reference.top_p)
        self.assertEqual(ai_config.frequency_penalty, ai_reference.frequency_penalty)
        self.assertEqual(ai_config.presence_penalty, ai_reference.presence_penalty)

    def test_ai_config_instance_with_valid_name_and_configuration_should_return_instance_with_custom_values(self) -> None:
        conf_dict = {
            'system': 'Custom system',
            'api_key': '9876543210',
            'model': 'custom_model',
            'max_tokens': 5000,
            'temperature': 0.5,
            'top_p': 0.8,
            'frequency_penalty': 0.7,
            'presence_penalty': 0.2
        }
        ai_config = cast(OpenAIConfig, ai_config_instance('openai', conf_dict))
        self.assertEqual(ai_config.system, 'Custom system')
        self.assertEqual(ai_config.api_key, '9876543210')
        self.assertEqual(ai_config.model, 'custom_model')
        self.assertEqual(ai_config.max_tokens, 5000)
        self.assertAlmostEqual(ai_config.temperature, 0.5)
        self.assertAlmostEqual(ai_config.top_p, 0.8)
        self.assertAlmostEqual(ai_config.frequency_penalty, 0.7)
        self.assertAlmostEqual(ai_config.presence_penalty, 0.2)

    def test_ai_config_instance_with_invalid_name_should_raise_config_error(self) -> None:
        with self.assertRaises(ConfigError):
            ai_config_instance('invalid_name')


class TestConfig(unittest.TestCase):
    def setUp(self) -> None:
        self.test_file = NamedTemporaryFile(delete=False)

    def tearDown(self) -> None:
        os.remove(self.test_file.name)

    def test_from_dict_should_create_config_from_dict(self) -> None:
        source_dict = {
            'db': './test_db/',
            'ais': {
                'myopenai': {
                    'name': 'openai',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
            }
        }
        config = Config.from_dict(source_dict)
        self.assertEqual(config.db, './test_db/')
        self.assertEqual(len(config.ais), 1)
        self.assertEqual(config.ais['myopenai'].name, 'openai')
        self.assertEqual(cast(OpenAIConfig, config.ais['myopenai']).system, 'Custom system')
        # check that 'ID' has been added
        self.assertEqual(config.ais['myopenai'].ID, 'myopenai')

    def test_create_default_should_create_default_config(self) -> None:
        Config.create_default(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            default_config = yaml.load(f, Loader=yaml.FullLoader)
        config_reference = Config()
        self.assertEqual(default_config['db'], config_reference.db)

    def test_from_file_should_load_config_from_file(self) -> None:
        source_dict = {
            'db': './test_db/',
            'ais': {
                'default': {
                    'name': 'openai',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
            }
        }
        with open(self.test_file.name, 'w') as f:
            yaml.dump(source_dict, f)
        config = Config.from_file(self.test_file.name)
        self.assertIsInstance(config, Config)
        self.assertEqual(config.db, './test_db/')
        self.assertEqual(len(config.ais), 1)
        self.assertIsInstance(config.ais['default'], AIConfig)
        self.assertEqual(cast(OpenAIConfig, config.ais['default']).system, 'Custom system')

    def test_to_file_should_save_config_to_file(self) -> None:
        config = Config(
            db='./test_db/',
            ais={
                'myopenai': OpenAIConfig(
                    ID='myopenai',
                    system='Custom system',
                    api_key='9876543210',
                    model='custom_model',
                    max_tokens=5000,
                    temperature=0.5,
                    top_p=0.8,
                    frequency_penalty=0.7,
                    presence_penalty=0.2
                )
            }
        )
        config.to_file(Path(self.test_file.name))
        with open(self.test_file.name, 'r') as f:
            saved_config = yaml.load(f, Loader=yaml.FullLoader)
        self.assertEqual(saved_config['db'], './test_db/')
        self.assertEqual(len(saved_config['ais']), 1)
        self.assertEqual(saved_config['ais']['myopenai']['system'], 'Custom system')

    def test_from_file_error_unknown_ai(self) -> None:
        source_dict = {
            'db': './test_db/',
            'ais': {
                'default': {
                    'name': 'foobla',
                    'system': 'Custom system',
                    'api_key': '9876543210',
                    'model': 'custom_model',
                    'max_tokens': 5000,
                    'temperature': 0.5,
                    'top_p': 0.8,
                    'frequency_penalty': 0.7,
                    'presence_penalty': 0.2
                }
            }
        }
        with open(self.test_file.name, 'w') as f:
            yaml.dump(source_dict, f)
        with self.assertRaises(ConfigError):
            Config.from_file(self.test_file.name)
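The configuration tests above pin down three behaviours: a known AI name yields an instance with default values, a dictionary overrides those defaults, and an unknown name raises an error, while the key under `ais` ends up as the instance's `ID`. A minimal sketch of one way such a factory could be built is shown below. It is not the project's implementation: every name prefixed with `Sketch` or `sketch_` is hypothetical, and all default values are invented for illustration.

```python
from dataclasses import dataclass
from typing import Any, Optional


class SketchConfigError(Exception):
    """Hypothetical stand-in for the project's ConfigError."""


@dataclass
class SketchOpenAIConfig:
    # Field names follow the tests above; every default value here is invented.
    name: str = 'openai'
    ID: str = 'myopenai'
    system: str = 'You are an assistant'
    model: str = 'some-model'
    max_tokens: int = 1000
    temperature: float = 1.0


# Hypothetical registry: AI name -> configuration class.
SKETCH_AI_CLASSES = {'openai': SketchOpenAIConfig}


def sketch_ai_config_instance(name: str,
                              conf_dict: Optional[dict[str, Any]] = None) -> SketchOpenAIConfig:
    """Return a config for `name`, overriding defaults with `conf_dict`."""
    try:
        config_class = SKETCH_AI_CLASSES[name]
    except KeyError:
        raise SketchConfigError(f"Unknown AI '{name}'") from None
    return config_class(**(conf_dict or {}))


def sketch_ais_from_dict(ais: dict[str, dict[str, Any]]) -> dict[str, SketchOpenAIConfig]:
    """Build one config per entry; the dictionary key becomes the config's ID."""
    result = {}
    for ai_id, conf in ais.items():
        conf = dict(conf)          # copy so the caller's dictionary is untouched
        name = conf.pop('name')    # selects the class, is not passed on as a field
        result[ai_id] = sketch_ai_config_instance(name, {**conf, 'ID': ai_id})
    return result
```

Keeping the name-to-class lookup in one registry means an unknown name fails in a single place, which is the behaviour `test_from_file_error_unknown_ai` checks for the real code.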
@@ -1,209 +0,0 @@
import unittest
import io
import pathlib
import argparse
from chatmastermind.utils import terminal_width
from chatmastermind.main import create_parser, handle_question
from chatmastermind.api_client import ai
from chatmastermind.storage import create_chat, save_answers, dump_data
from unittest import mock
from unittest.mock import patch, MagicMock, Mock


class TestCreateChat(unittest.TestCase):

    def setUp(self):
        self.config = {
            'system': 'System text',
            'db': 'test_files'
        }
        self.question = "test question"
        self.tags = ['test_tag']

    @patch('os.listdir')
    @patch('builtins.open')
    def test_create_chat_with_tags(self, open_mock, listdir_mock):
        listdir_mock.return_value = ['testfile.txt']
        open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
            {'question': 'test_content', 'answer': 'some answer',
             'tags': ['test_tag']}))

        test_chat = create_chat(self.question, self.tags, None, self.config)

        self.assertEqual(len(test_chat), 4)
        self.assertEqual(test_chat[0],
                         {'role': 'system', 'content': self.config['system']})
        self.assertEqual(test_chat[1],
                         {'role': 'user', 'content': 'test_content'})
        self.assertEqual(test_chat[2],
                         {'role': 'assistant', 'content': 'some answer'})
        self.assertEqual(test_chat[3],
                         {'role': 'user', 'content': self.question})

    @patch('os.listdir')
    @patch('builtins.open')
    def test_create_chat_with_other_tags(self, open_mock, listdir_mock):
        listdir_mock.return_value = ['testfile.txt']
        open_mock.return_value.__enter__.return_value = io.StringIO(dump_data(
            {'question': 'test_content', 'answer': 'some answer',
             'tags': ['other_tag']}))

        test_chat = create_chat(self.question, self.tags, None, self.config)

        self.assertEqual(len(test_chat), 2)
        self.assertEqual(test_chat[0],
                         {'role': 'system', 'content': self.config['system']})
        self.assertEqual(test_chat[1],
                         {'role': 'user', 'content': self.question})

    @patch('os.listdir')
    @patch('builtins.open')
    def test_create_chat_without_tags(self, open_mock, listdir_mock):
        listdir_mock.return_value = ['testfile.txt', 'testfile2.txt']
        open_mock.side_effect = (
            io.StringIO(dump_data({'question': 'test_content',
                                   'answer': 'some answer',
                                   'tags': ['test_tag']})),
            io.StringIO(dump_data({'question': 'test_content2',
                                   'answer': 'some answer2',
                                   'tags': ['test_tag2']})),
        )

        test_chat = create_chat(self.question, [], None, self.config)

        self.assertEqual(len(test_chat), 6)
        self.assertEqual(test_chat[0],
                         {'role': 'system', 'content': self.config['system']})
        self.assertEqual(test_chat[1],
                         {'role': 'user', 'content': 'test_content'})
        self.assertEqual(test_chat[2],
                         {'role': 'assistant', 'content': 'some answer'})
        self.assertEqual(test_chat[3],
                         {'role': 'user', 'content': 'test_content2'})
        self.assertEqual(test_chat[4],
                         {'role': 'assistant', 'content': 'some answer2'})


class TestHandleQuestion(unittest.TestCase):

    def setUp(self):
        self.question = "test question"
        self.args = argparse.Namespace(
            tags=['tag1'],
            extags=['extag1'],
            output_tags=None,
            question=[self.question],
            source=None,
            only_source_code=False,
            number=3
        )
        self.config = {
            'db': 'test_files',
            'setting1': 'value1',
            'setting2': 'value2'
        }

    @patch("chatmastermind.main.create_chat", return_value="test_chat")
    @patch("chatmastermind.main.process_tags")
    @patch("chatmastermind.main.ai", return_value=(["answer1", "answer2", "answer3"], "test_usage"))
    @patch("chatmastermind.utils.pp")
    @patch("builtins.print")
    def test_handle_question(self, mock_print, mock_pp, mock_ai,
                             mock_process_tags, mock_create_chat):
        open_mock = MagicMock()
        with patch("chatmastermind.storage.open", open_mock):
            handle_question(self.args, self.config, True)
        mock_process_tags.assert_called_once_with(self.args.tags,
                                                  self.args.extags,
                                                  [])
        mock_create_chat.assert_called_once_with(self.question,
                                                 self.args.tags,
                                                 self.args.extags,
                                                 self.config)
        mock_pp.assert_called_once_with("test_chat")
        mock_ai.assert_called_with("test_chat",
                                   self.config,
                                   self.args.number)
        expected_calls = []
        for num, answer in enumerate(mock_ai.return_value[0], start=1):
            title = f'-- ANSWER {num} '
            title_end = '-' * (terminal_width() - len(title))
            expected_calls.append(((f'{title}{title_end}',),))
            expected_calls.append(((answer,),))
        expected_calls.append((("-" * terminal_width(),),))
        expected_calls.append(((f"Usage: {mock_ai.return_value[1]}",),))
        self.assertEqual(mock_print.call_args_list, expected_calls)
        open_expected_calls = list([mock.call(f"{num:04d}.txt", "w") for num in range(2, 5)])
        open_mock.assert_has_calls(open_expected_calls, any_order=True)


class TestSaveAnswers(unittest.TestCase):
    @mock.patch('builtins.open')
    @mock.patch('chatmastermind.storage.print')
    def test_save_answers(self, print_mock, open_mock):
        question = "Test question?"
        answers = ["Answer 1", "Answer 2"]
        tags = ["tag1", "tag2"]
        otags = ["otag1", "otag2"]
        config = {'db': 'test_db'}

        with mock.patch('chatmastermind.storage.pathlib.Path.exists', return_value=True), \
                mock.patch('chatmastermind.storage.yaml.dump'), \
                mock.patch('io.StringIO') as stringio_mock:
            stringio_instance = stringio_mock.return_value
            stringio_instance.getvalue.side_effect = ["question", "answer1", "answer2"]
            save_answers(question, answers, tags, otags, config)

        open_calls = [
            mock.call(pathlib.Path('test_db/.next'), 'r'),
            mock.call(pathlib.Path('test_db/.next'), 'w'),
        ]
        open_mock.assert_has_calls(open_calls, any_order=True)


class TestAI(unittest.TestCase):

    @patch("openai.ChatCompletion.create")
    def test_ai(self, mock_create: MagicMock):
        mock_create.return_value = {
            'choices': [
                {'message': {'content': 'response_text_1'}},
                {'message': {'content': 'response_text_2'}}
            ],
            'usage': {'tokens': 10}
        }

        number = 2
        chat = [{"role": "system", "content": "hello ai"}]
        config = {
            "openai": {
                "model": "text-davinci-002",
                "temperature": 0.5,
                "max_tokens": 150,
                "top_p": 1,
                "n": number,
                "frequency_penalty": 0,
                "presence_penalty": 0
            }
        }

        result = ai(chat, config, number)
        expected_result = (['response_text_1', 'response_text_2'],
                           {'tokens': 10})
        self.assertEqual(result, expected_result)


class TestCreateParser(unittest.TestCase):
    def test_create_parser(self):
        with patch('argparse.ArgumentParser.add_mutually_exclusive_group') as mock_add_mutually_exclusive_group:
            mock_group = Mock()
            mock_add_mutually_exclusive_group.return_value = mock_group
            parser = create_parser()
            self.assertIsInstance(parser, argparse.ArgumentParser)
            mock_add_mutually_exclusive_group.assert_called_once_with(required=True)
            mock_group.add_argument.assert_any_call('-p', '--print', help='File to print')
            mock_group.add_argument.assert_any_call('-q', '--question', nargs='*', help='Question to ask')
            mock_group.add_argument.assert_any_call('-D', '--chat-dump', help="Print chat as Python structure", action='store_true')
            mock_group.add_argument.assert_any_call('-d', '--chat', help="Print chat as readable text", action='store_true')
            self.assertTrue('.config.yaml' in parser.get_default('config'))
            self.assertEqual(parser.get_default('number'), 1)
836
tests/test_message.py
Normal file
@@ -0,0 +1,836 @@
import unittest
import pathlib
import tempfile
from typing import cast
from chatmastermind.message import source_code, Message, MessageError, Question, Answer, AILine, ModelLine, MessageFilter, message_in
from chatmastermind.tags import Tag, TagLine


class SourceCodeTestCase(unittest.TestCase):
    def test_source_code_with_include_delims(self) -> None:
        text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
        expected_result = [
            " ```python\n print(\"Hello, World!\")\n ```\n",
            " ```python\n x = 10\n y = 20\n print(x + y)\n ```\n"
        ]
        result = source_code(text, include_delims=True)
        self.assertEqual(result, expected_result)

    def test_source_code_without_include_delims(self) -> None:
        text = """
Some text before the code block
```python
print("Hello, World!")
```
Some text after the code block
```python
x = 10
y = 20
print(x + y)
```
"""
        expected_result = [
            " print(\"Hello, World!\")\n",
            " x = 10\n y = 20\n print(x + y)\n"
        ]
        result = source_code(text, include_delims=False)
        self.assertEqual(result, expected_result)

    def test_source_code_with_single_code_block(self) -> None:
        text = "```python\nprint(\"Hello, World!\")\n```"
        expected_result = ["```python\nprint(\"Hello, World!\")\n```\n"]
        result = source_code(text, include_delims=True)
        self.assertEqual(result, expected_result)

    def test_source_code_with_no_code_blocks(self) -> None:
        text = "Some text without any code blocks"
        expected_result: list[str] = []
        result = source_code(text, include_delims=True)
        self.assertEqual(result, expected_result)


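The `source_code` tests above describe an extractor that returns each fenced code block as one string, keeps every line's original indentation, ends every line with a newline, and optionally includes the fence lines. The sketch below shows one way to get that behaviour with a simple line scanner. It is an assumption-laden illustration, not the project's `source_code`: the function name `extract_code_blocks` is hypothetical, and dropping an unterminated block is a choice made here, not something the tests confirm.

```python
FENCE = "`" * 3  # the Markdown code fence marker


def extract_code_blocks(text: str, include_delims: bool = True) -> list[str]:
    """Hypothetical stand-in for `source_code`: collect fenced code blocks."""
    blocks: list[str] = []
    current: list[str] = []
    inside = False
    for line in text.splitlines():
        if line.strip().startswith(FENCE):
            if include_delims:
                current.append(line + "\n")
            if inside:
                # closing fence: the block is complete
                blocks.append("".join(current))
                current = []
            inside = not inside
        elif inside:
            current.append(line + "\n")
    # a block that is never closed is silently dropped in this sketch
    return blocks
```

Because every collected line gets a newline appended, a block whose closing fence is the last line of the input still ends with `\n`, which matches the expectation in `test_source_code_with_single_code_block`.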
class QuestionTestCase(unittest.TestCase):
    def test_question_with_header(self) -> None:
        with self.assertRaises(MessageError):
            Question(f"{Question.txt_header}\nWhat is your name?")

    def test_question_with_answer_header(self) -> None:
        with self.assertRaises(MessageError):
            Question(f"{Answer.txt_header}\nBob")

    def test_question_with_legal_header(self) -> None:
        """
        If the header is just a part of a line, it's fine.
        """
        question = Question(f"This is a line containing '{Question.txt_header}'\nWhat does that mean?")
        self.assertIsInstance(question, Question)
        self.assertEqual(question, f"This is a line containing '{Question.txt_header}'\nWhat does that mean?")

    def test_question_without_header(self) -> None:
        question = Question("What is your favorite color?")
        self.assertIsInstance(question, Question)
        self.assertEqual(question, "What is your favorite color?")


class AnswerTestCase(unittest.TestCase):
    def test_answer_with_header(self) -> None:
        with self.assertRaises(MessageError):
            Answer(f"{Answer.txt_header}\nno")

    def test_answer_with_legal_header(self) -> None:
        answer = Answer(f"This is a line containing '{Answer.txt_header}'\nIt is what it is.")
        self.assertIsInstance(answer, Answer)
        self.assertEqual(answer, f"This is a line containing '{Answer.txt_header}'\nIt is what it is.")

    def test_answer_without_header(self) -> None:
        answer = Answer("No")
        self.assertIsInstance(answer, Answer)
        self.assertEqual(answer, "No")


class MessageToFileTxtTestCase(unittest.TestCase):
    def setUp(self) -> None:
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        self.file_path = pathlib.Path(self.file.name)
        self.message_complete = Message(Question('This is a question.'),
                                        Answer('This is an answer.'),
                                        {Tag('tag1'), Tag('tag2')},
                                        ai='ChatGPT',
                                        model='gpt-3.5-turbo',
                                        file_path=self.file_path)
        self.message_min = Message(Question('This is a question.'),
                                   file_path=self.file_path)

    def tearDown(self) -> None:
        self.file.close()
        self.file_path.unlink()

    def test_to_file_txt_complete(self) -> None:
        self.message_complete.to_file(self.file_path)

        with open(self.file_path, "r") as fd:
            content = fd.read()
        expected_content = f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
"""
        self.assertEqual(content, expected_content)

    def test_to_file_txt_min(self) -> None:
        self.message_min.to_file(self.file_path)

        with open(self.file_path, "r") as fd:
            content = fd.read()
        expected_content = f"""{Question.txt_header}
This is a question.
"""
        self.assertEqual(content, expected_content)

    def test_to_file_unsupported_file_type(self) -> None:
        unsupported_file_path = pathlib.Path("example.doc")
        with self.assertRaises(MessageError) as cm:
            self.message_complete.to_file(unsupported_file_path)
        self.assertEqual(str(cm.exception), "File type '.doc' is not supported")

    def test_to_file_no_file_path(self) -> None:
        """
        Provoke an exception using an empty path.
        """
        with self.assertRaises(MessageError) as cm:
            # clear the internal file_path
            self.message_complete.file_path = None
            self.message_complete.to_file(None)
        self.assertEqual(str(cm.exception), "Got no valid path to write message")
        # reset the internal file_path
        self.message_complete.file_path = self.file_path


class MessageToFileYamlTestCase(unittest.TestCase):
    def setUp(self) -> None:
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
        self.file_path = pathlib.Path(self.file.name)
        self.message_complete = Message(Question('This is a question.'),
                                        Answer('This is an answer.'),
                                        {Tag('tag1'), Tag('tag2')},
                                        ai='ChatGPT',
                                        model='gpt-3.5-turbo',
                                        file_path=self.file_path)
        self.message_multiline = Message(Question('This is a\nmultiline question.'),
                                         Answer('This is a\nmultiline answer.'),
                                         {Tag('tag1'), Tag('tag2')},
                                         ai='ChatGPT',
                                         model='gpt-3.5-turbo',
                                         file_path=self.file_path)
        self.message_min = Message(Question('This is a question.'),
                                   file_path=self.file_path)

    def tearDown(self) -> None:
        self.file.close()
        self.file_path.unlink()

    def test_to_file_yaml_complete(self) -> None:
        self.message_complete.to_file(self.file_path)

        with open(self.file_path, "r") as fd:
            content = fd.read()
        expected_content = f"""{Question.yaml_key}: This is a question.
{Answer.yaml_key}: This is an answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
"""
        self.assertEqual(content, expected_content)

    def test_to_file_yaml_multiline(self) -> None:
        self.message_multiline.to_file(self.file_path)

        with open(self.file_path, "r") as fd:
            content = fd.read()
        expected_content = f"""{Question.yaml_key}: |-
  This is a
  multiline question.
{Answer.yaml_key}: |-
  This is a
  multiline answer.
{Message.ai_yaml_key}: ChatGPT
{Message.model_yaml_key}: gpt-3.5-turbo
{Message.tags_yaml_key}:
- tag1
- tag2
"""
        self.assertEqual(content, expected_content)

    def test_to_file_yaml_min(self) -> None:
        self.message_min.to_file(self.file_path)

        with open(self.file_path, "r") as fd:
            content = fd.read()
        expected_content = f"{Question.yaml_key}: This is a question.\n"
        self.assertEqual(content, expected_content)


class MessageFromFileTxtTestCase(unittest.TestCase):
    def setUp(self) -> None:
        self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        self.file_path = pathlib.Path(self.file.name)
        with open(self.file_path, "w") as fd:
            fd.write(f"""{TagLine.prefix} tag1 tag2
{AILine.prefix} ChatGPT
{ModelLine.prefix} gpt-3.5-turbo
{Question.txt_header}
This is a question.
{Answer.txt_header}
This is an answer.
""")
        self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
        self.file_path_min = pathlib.Path(self.file_min.name)
        with open(self.file_path_min, "w") as fd:
            fd.write(f"""{Question.txt_header}
This is a question.
""")

    def tearDown(self) -> None:
        self.file.close()
        self.file_min.close()
        self.file_path.unlink()
        self.file_path_min.unlink()

    def test_from_file_txt_complete(self) -> None:
        """
        Read a complete message (with all optional values).
        """
        message = Message.from_file(self.file_path)
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)
        if message:  # mypy bug
            self.assertEqual(message.question, 'This is a question.')
            self.assertEqual(message.answer, 'This is an answer.')
            self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
            self.assertEqual(message.ai, 'ChatGPT')
            self.assertEqual(message.model, 'gpt-3.5-turbo')
            self.assertEqual(message.file_path, self.file_path)

    def test_from_file_txt_min(self) -> None:
        """
        Read a message with only required values.
        """
        message = Message.from_file(self.file_path_min)
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)
        if message:  # mypy bug
            self.assertEqual(message.question, 'This is a question.')
            self.assertEqual(message.file_path, self.file_path_min)
            self.assertIsNone(message.answer)

    def test_from_file_txt_tags_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(tags_or={Tag('tag1')}))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)
        if message:  # mypy bug
            self.assertEqual(message.question, 'This is a question.')
            self.assertEqual(message.answer, 'This is an answer.')
            self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
            self.assertEqual(message.file_path, self.file_path)

    def test_from_file_txt_tags_dont_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(tags_or={Tag('tag3')}))
        self.assertIsNone(message)

    def test_from_file_txt_no_tags_dont_match(self) -> None:
        message = Message.from_file(self.file_path_min,
                                    MessageFilter(tags_or={Tag('tag1')}))
        self.assertIsNone(message)

    def test_from_file_txt_empty_tags_dont_match(self) -> None:
        message = Message.from_file(self.file_path_min,
                                    MessageFilter(tags_or=set(),
                                                  tags_and=set()))
        self.assertIsNone(message)

    def test_from_file_txt_no_tags_match_tags_not(self) -> None:
        message = Message.from_file(self.file_path_min,
                                    MessageFilter(tags_not={Tag('tag1')}))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)
        if message:  # mypy bug
            self.assertEqual(message.question, 'This is a question.')
            self.assertSetEqual(cast(set[Tag], message.tags), set())
            self.assertEqual(message.file_path, self.file_path_min)

    def test_from_file_not_exists(self) -> None:
        file_not_exists = pathlib.Path("example.txt")
        with self.assertRaises(MessageError) as cm:
            Message.from_file(file_not_exists)
        self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")

    def test_from_file_txt_question_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(question_contains='question'))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)

    def test_from_file_txt_answer_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(answer_contains='answer'))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)

    def test_from_file_txt_answer_available(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(answer_state='available'))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)

    def test_from_file_txt_answer_missing(self) -> None:
        message = Message.from_file(self.file_path_min,
                                    MessageFilter(answer_state='missing'))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)

    def test_from_file_txt_question_doesnt_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(question_contains='answer'))
        self.assertIsNone(message)

    def test_from_file_txt_answer_doesnt_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(answer_contains='question'))
        self.assertIsNone(message)

    def test_from_file_txt_answer_not_exists(self) -> None:
        message = Message.from_file(self.file_path_min,
                                    MessageFilter(answer_contains='answer'))
        self.assertIsNone(message)

    def test_from_file_txt_answer_not_available(self) -> None:
        message = Message.from_file(self.file_path_min,
                                    MessageFilter(answer_state='available'))
        self.assertIsNone(message)

    def test_from_file_txt_answer_not_missing(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(answer_state='missing'))
        self.assertIsNone(message)

    def test_from_file_txt_ai_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(ai='ChatGPT'))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)

    def test_from_file_txt_ai_doesnt_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(ai='Foo'))
        self.assertIsNone(message)

    def test_from_file_txt_model_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(model='gpt-3.5-turbo'))
        self.assertIsNotNone(message)
        self.assertIsInstance(message, Message)

    def test_from_file_txt_model_doesnt_match(self) -> None:
        message = Message.from_file(self.file_path,
                                    MessageFilter(model='Bar'))
        self.assertIsNone(message)


class MessageFromFileYamlTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
with open(self.file_path, "w") as fd:
|
||||
fd.write(f"""
|
||||
{Question.yaml_key}: |-
|
||||
This is a question.
|
||||
{Answer.yaml_key}: |-
|
||||
This is an answer.
|
||||
{Message.ai_yaml_key}: ChatGPT
|
||||
{Message.model_yaml_key}: gpt-3.5-turbo
|
||||
{Message.tags_yaml_key}:
|
||||
- tag1
|
||||
- tag2
|
||||
""")
|
||||
self.file_min = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path_min = pathlib.Path(self.file_min.name)
|
||||
with open(self.file_path_min, "w") as fd:
|
||||
fd.write(f"""
|
||||
{Question.yaml_key}: |-
|
||||
This is a question.
|
||||
""")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.file.close()
|
||||
self.file_path.unlink()
|
||||
self.file_min.close()
|
||||
self.file_path_min.unlink()
|
||||
|
||||
def test_from_file_yaml_complete(self) -> None:
|
||||
"""
|
||||
Read a complete message (with all optional values).
|
||||
"""
|
||||
message = Message.from_file(self.file_path)
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertIsNotNone(message)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.ai, 'ChatGPT')
|
||||
self.assertEqual(message.model, 'gpt-3.5-turbo')
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
|
||||
def test_from_file_yaml_min(self) -> None:
|
||||
"""
|
||||
Read a message with only the required values.
|
||||
"""
|
||||
message = Message.from_file(self.file_path_min)
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertIsNotNone(message)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
self.assertIsNone(message.answer)
|
||||
|
||||
def test_from_file_not_exists(self) -> None:
|
||||
file_not_exists = pathlib.Path("example.yaml")
|
||||
with self.assertRaises(MessageError) as cm:
|
||||
Message.from_file(file_not_exists)
|
||||
self.assertEqual(str(cm.exception), f"Message file '{file_not_exists}' does not exist")
|
||||
|
||||
def test_from_file_yaml_tags_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertEqual(message.answer, 'This is an answer.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), {Tag('tag1'), Tag('tag2')})
|
||||
self.assertEqual(message.file_path, self.file_path)
|
||||
|
||||
def test_from_file_yaml_tags_dont_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(tags_or={Tag('tag3')}))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_no_tags_dont_match(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(tags_or={Tag('tag1')}))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_no_tags_match_tags_not(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(tags_not={Tag('tag1')}))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
if message: # mypy bug
|
||||
self.assertEqual(message.question, 'This is a question.')
|
||||
self.assertSetEqual(cast(set[Tag], message.tags), set())
|
||||
self.assertEqual(message.file_path, self.file_path_min)
|
||||
|
||||
def test_from_file_yaml_question_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(question_contains='question'))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
|
||||
def test_from_file_yaml_answer_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(answer_contains='answer'))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
|
||||
def test_from_file_yaml_answer_available(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(answer_state='available'))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
|
||||
def test_from_file_yaml_answer_missing(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(answer_state='missing'))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
|
||||
def test_from_file_yaml_question_doesnt_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(question_contains='answer'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_answer_doesnt_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(answer_contains='question'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_answer_not_exists(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(answer_contains='answer'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_answer_not_available(self) -> None:
|
||||
message = Message.from_file(self.file_path_min,
|
||||
MessageFilter(answer_state='available'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_answer_not_missing(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(answer_state='missing'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_ai_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(ai='ChatGPT'))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
|
||||
def test_from_file_yaml_ai_doesnt_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(ai='Foo'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
def test_from_file_yaml_model_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(model='gpt-3.5-turbo'))
|
||||
self.assertIsNotNone(message)
|
||||
self.assertIsInstance(message, Message)
|
||||
|
||||
def test_from_file_yaml_model_doesnt_match(self) -> None:
|
||||
message = Message.from_file(self.file_path,
|
||||
MessageFilter(model='Bar'))
|
||||
self.assertIsNone(message)
|
||||
|
||||
|
||||
class TagsFromFileTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file_txt = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt = pathlib.Path(self.file_txt.name)
|
||||
with open(self.file_path_txt, "w") as fd:
|
||||
fd.write(f"""{TagLine.prefix} tag1 tag2 ptag3
|
||||
{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_txt_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt_no_tags = pathlib.Path(self.file_txt_no_tags.name)
|
||||
with open(self.file_path_txt_no_tags, "w") as fd:
|
||||
fd.write(f"""{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_txt_tags_empty = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path_txt_tags_empty = pathlib.Path(self.file_txt_tags_empty.name)
|
||||
with open(self.file_path_txt_tags_empty, "w") as fd:
|
||||
fd.write(f"""TAGS:
|
||||
{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer.
|
||||
""")
|
||||
self.file_yaml = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path_yaml = pathlib.Path(self.file_yaml.name)
|
||||
with open(self.file_path_yaml, "w") as fd:
|
||||
fd.write(f"""
|
||||
{Question.yaml_key}: |-
|
||||
This is a question.
|
||||
{Answer.yaml_key}: |-
|
||||
This is an answer.
|
||||
{Message.tags_yaml_key}:
|
||||
- tag1
|
||||
- tag2
|
||||
- ptag3
|
||||
""")
|
||||
self.file_yaml_no_tags = tempfile.NamedTemporaryFile(delete=False, suffix='.yaml')
|
||||
self.file_path_yaml_no_tags = pathlib.Path(self.file_yaml_no_tags.name)
|
||||
with open(self.file_path_yaml_no_tags, "w") as fd:
|
||||
fd.write(f"""
|
||||
{Question.yaml_key}: |-
|
||||
This is a question.
|
||||
{Answer.yaml_key}: |-
|
||||
This is an answer.
|
||||
""")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.file_txt.close()
|
||||
self.file_path_txt.unlink()
|
||||
self.file_yaml.close()
|
||||
self.file_path_yaml.unlink()
|
||||
self.file_txt_no_tags.close()
|
||||
self.file_path_txt_no_tags.unlink()
|
||||
self.file_txt_tags_empty.close()
|
||||
self.file_path_txt_tags_empty.unlink()
|
||||
self.file_yaml_no_tags.close()
|
||||
self.file_path_yaml_no_tags.unlink()
|
||||
|
||||
def test_tags_from_file_txt(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_txt)
|
||||
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
|
||||
|
||||
def test_tags_from_file_txt_no_tags(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_txt_no_tags)
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_from_file_txt_tags_empty(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_txt_tags_empty)
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_from_file_yaml(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_yaml)
|
||||
self.assertSetEqual(tags, {Tag('tag1'), Tag('tag2'), Tag('ptag3')})
|
||||
|
||||
def test_tags_from_file_yaml_no_tags(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_yaml_no_tags)
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_from_file_txt_prefix(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_txt, prefix='p')
|
||||
self.assertSetEqual(tags, {Tag('ptag3')})
|
||||
tags = Message.tags_from_file(self.file_path_txt, prefix='R')
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_from_file_yaml_prefix(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_yaml, prefix='p')
|
||||
self.assertSetEqual(tags, {Tag('ptag3')})
|
||||
tags = Message.tags_from_file(self.file_path_yaml, prefix='R')
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_from_file_txt_contain(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_txt, contain='3')
|
||||
self.assertSetEqual(tags, {Tag('ptag3')})
|
||||
tags = Message.tags_from_file(self.file_path_txt, contain='R')
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_from_file_yaml_contain(self) -> None:
|
||||
tags = Message.tags_from_file(self.file_path_yaml, contain='3')
|
||||
self.assertSetEqual(tags, {Tag('ptag3')})
|
||||
tags = Message.tags_from_file(self.file_path_yaml, contain='R')
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
|
||||
class TagsFromDirTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.temp_dir = tempfile.TemporaryDirectory()
|
||||
self.temp_dir_no_tags = tempfile.TemporaryDirectory()
|
||||
self.tag_sets = [
|
||||
{Tag('atag1'), Tag('atag2')},
|
||||
{Tag('btag3'), Tag('btag4')},
|
||||
{Tag('ctag5'), Tag('ctag6')}
|
||||
]
|
||||
self.files = [
|
||||
pathlib.Path(self.temp_dir.name, 'file1.txt'),
|
||||
pathlib.Path(self.temp_dir.name, 'file2.yaml'),
|
||||
pathlib.Path(self.temp_dir.name, 'file3.txt')
|
||||
]
|
||||
self.files_no_tags = [
|
||||
pathlib.Path(self.temp_dir_no_tags.name, 'file4.txt'),
|
||||
pathlib.Path(self.temp_dir_no_tags.name, 'file5.yaml'),
|
||||
pathlib.Path(self.temp_dir_no_tags.name, 'file6.txt')
|
||||
]
|
||||
for file, tags in zip(self.files, self.tag_sets):
|
||||
message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
tags)
|
||||
message.to_file(file)
|
||||
for file in self.files_no_tags:
|
||||
message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'))
|
||||
message.to_file(file)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.temp_dir.cleanup()
|
||||
self.temp_dir_no_tags.cleanup()
|
||||
|
||||
def test_tags_from_dir(self) -> None:
|
||||
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name))
|
||||
expected_tags = self.tag_sets[0] | self.tag_sets[1] | self.tag_sets[2]
|
||||
self.assertEqual(all_tags, expected_tags)
|
||||
|
||||
def test_tags_from_dir_prefix(self) -> None:
|
||||
atags = Message.tags_from_dir(pathlib.Path(self.temp_dir.name), prefix='a')
|
||||
expected_tags = self.tag_sets[0]
|
||||
self.assertEqual(atags, expected_tags)
|
||||
|
||||
def test_tags_from_dir_no_tags(self) -> None:
|
||||
all_tags = Message.tags_from_dir(pathlib.Path(self.temp_dir_no_tags.name))
|
||||
self.assertSetEqual(all_tags, set())
|
||||
|
||||
|
||||
class MessageIDTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.file = tempfile.NamedTemporaryFile(delete=False, suffix='.txt')
|
||||
self.file_path = pathlib.Path(self.file.name)
|
||||
self.message = Message(Question('This is a question.'),
|
||||
file_path=self.file_path)
|
||||
self.message_no_file_path = Message(Question('This is a question.'))
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.file.close()
|
||||
self.file_path.unlink()
|
||||
|
||||
def test_msg_id_txt(self) -> None:
|
||||
self.assertEqual(self.message.msg_id(), self.file_path.name)
|
||||
|
||||
def test_msg_id_txt_exception(self) -> None:
|
||||
with self.assertRaises(MessageError):
|
||||
self.message_no_file_path.msg_id()
|
||||
|
||||
|
||||
class MessageHashTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message1 = Message(Question('This is a question.'),
|
||||
tags={Tag('tag1')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
self.message2 = Message(Question('This is a new question.'),
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
self.message3 = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
# message4 is a copy of message1, because only question and
|
||||
# answer are used for hashing and comparison
|
||||
self.message4 = Message(Question('This is a question.'),
|
||||
tags={Tag('tag1'), Tag('tag2')},
|
||||
ai='Blabla',
|
||||
file_path=pathlib.Path('foobla'))
|
||||
|
||||
def test_set_hashing(self) -> None:
|
||||
msgs: set[Message] = {self.message1, self.message2, self.message3, self.message4}
|
||||
self.assertEqual(len(msgs), 3)
|
||||
for msg in [self.message1, self.message2, self.message3]:
|
||||
self.assertIn(msg, msgs)
|
||||
|
||||
|
||||
class MessageTagsStrTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('tag1')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_tags_str(self) -> None:
|
||||
self.assertEqual(self.message.tags_str(), f'{TagLine.prefix} tag1')
|
||||
|
||||
|
||||
class MessageFilterTagsTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_filter_tags(self) -> None:
|
||||
tags_all = self.message.filter_tags()
|
||||
self.assertSetEqual(tags_all, {Tag('atag1'), Tag('btag2')})
|
||||
tags_pref = self.message.filter_tags(prefix='a')
|
||||
self.assertSetEqual(tags_pref, {Tag('atag1')})
|
||||
tags_cont = self.message.filter_tags(contain='2')
|
||||
self.assertSetEqual(tags_cont, {Tag('btag2')})
|
||||
|
||||
|
||||
class MessageInTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message1 = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
self.message2 = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/bla/foo'))
|
||||
|
||||
def test_message_in(self) -> None:
|
||||
self.assertTrue(message_in(self.message1, [self.message1]))
|
||||
self.assertFalse(message_in(self.message1, [self.message2]))
|
||||
|
||||
|
||||
class MessageRenameTagsTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_rename_tags(self) -> None:
|
||||
self.message.rename_tags({(Tag('atag1'), Tag('atag2')), (Tag('btag2'), Tag('btag3'))})
|
||||
self.assertIsNotNone(self.message.tags)
|
||||
self.assertSetEqual(self.message.tags, {Tag('atag2'), Tag('btag3')}) # type: ignore [arg-type]
|
||||
|
||||
|
||||
class MessageToStrTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.message = Message(Question('This is a question.'),
|
||||
Answer('This is an answer.'),
|
||||
tags={Tag('atag1'), Tag('btag2')},
|
||||
file_path=pathlib.Path('/tmp/foo/bla'))
|
||||
|
||||
def test_to_str(self) -> None:
|
||||
expected_output = f"""{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer."""
|
||||
self.assertEqual(self.message.to_str(), expected_output)
|
||||
|
||||
def test_to_str_with_tags_and_file(self) -> None:
|
||||
expected_output = f"""{TagLine.prefix} atag1 btag2
|
||||
FILE: /tmp/foo/bla
|
||||
{Question.txt_header}
|
||||
This is a question.
|
||||
{Answer.txt_header}
|
||||
This is an answer."""
|
||||
self.assertEqual(self.message.to_str(with_tags=True, with_file=True), expected_output)
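The comment in `MessageHashTestCase` states that only question and answer are used for hashing and comparison. A minimal sketch of that idea follows. `SketchMessage` is a hypothetical stand-in written for illustration, not the project's `Message` class, and it uses plain strings instead of `Question`, `Answer` and `Tag`.

```python
from typing import Optional


class SketchMessage:
    """Hypothetical stand-in: identity depends on question and answer only."""

    def __init__(self, question: str, answer: Optional[str] = None,
                 tags: Optional[set[str]] = None,
                 file_path: Optional[str] = None) -> None:
        self.question = question
        self.answer = answer
        self.tags = tags          # ignored by __eq__ and __hash__
        self.file_path = file_path  # ignored by __eq__ and __hash__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SketchMessage):
            return NotImplemented
        return (self.question, self.answer) == (other.question, other.answer)

    def __hash__(self) -> int:
        # Must agree with __eq__: equal objects need equal hashes.
        return hash((self.question, self.answer))


# Same four messages as in the test's setUp, reduced to plain strings.
m1 = SketchMessage('This is a question.', tags={'tag1'}, file_path='/tmp/foo/bla')
m2 = SketchMessage('This is a new question.', file_path='/tmp/foo/bla')
m3 = SketchMessage('This is a question.', 'This is an answer.', file_path='/tmp/foo/bla')
m4 = SketchMessage('This is a question.', tags={'tag1', 'tag2'}, file_path='foobla')
unique = {m1, m2, m3, m4}
```

With this definition `m4` collapses into `m1` inside the set, which is the behaviour `test_set_hashing` checks by expecting three elements.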
|
||||
195
tests/test_question_cmd.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import os
|
||||
import unittest
|
||||
import argparse
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
from chatmastermind.commands.question import create_message
|
||||
from chatmastermind.message import Message, Question, Answer
|
||||
from chatmastermind.chat import ChatDB
|
||||
|
||||
|
||||
class TestMessageCreate(unittest.TestCase):
|
||||
"""
|
||||
Test if messages created by the 'question' command have
|
||||
the correct format.
|
||||
"""
|
||||
def setUp(self) -> None:
|
||||
# create ChatDB structure
|
||||
self.db_path = tempfile.TemporaryDirectory()
|
||||
self.cache_path = tempfile.TemporaryDirectory()
|
||||
self.chat = ChatDB.from_dir(cache_path=Path(self.cache_path.name),
|
||||
db_path=Path(self.db_path.name))
|
||||
# create some messages
|
||||
self.message_text = Message(Question("What is this?"),
|
||||
Answer("It is pure text"))
|
||||
self.message_code = Message(Question("What is this?"),
|
||||
Answer("Text\n```\nIt is embedded code\n```\ntext"))
|
||||
self.chat.add_to_db([self.message_text, self.message_code])
|
||||
# create arguments mock
|
||||
self.args = MagicMock(spec=argparse.Namespace)
|
||||
self.args.source_text = None
|
||||
self.args.source_code = None
|
||||
self.args.AI = None
|
||||
self.args.model = None
|
||||
self.args.output_tags = None
|
||||
# File 1 : no source code block, only text
|
||||
self.source_file1 = tempfile.NamedTemporaryFile(delete=False)
|
||||
self.source_file1_content = """This is just text.
|
||||
No source code.
|
||||
Nope. Go look elsewhere!"""
|
||||
with open(self.source_file1.name, 'w') as f:
|
||||
f.write(self.source_file1_content)
|
||||
# File 2 : one embedded source code block
|
||||
self.source_file2 = tempfile.NamedTemporaryFile(delete=False)
|
||||
self.source_file2_content = """This is just text.
|
||||
```
|
||||
This is embedded source code.
|
||||
```
|
||||
And some text again."""
|
||||
with open(self.source_file2.name, 'w') as f:
|
||||
f.write(self.source_file2_content)
|
||||
# File 3 : all source code
|
||||
self.source_file3 = tempfile.NamedTemporaryFile(delete=False)
|
||||
self.source_file3_content = """This is all source code.
|
||||
Yes, really.
|
||||
Language is called 'brainfart'."""
|
||||
with open(self.source_file3.name, 'w') as f:
|
||||
f.write(self.source_file3_content)
|
||||
# File 4 : two source code blocks
|
||||
self.source_file4 = tempfile.NamedTemporaryFile(delete=False)
|
||||
self.source_file4_content = """This is just text.
|
||||
```
|
||||
This is embedded source code.
|
||||
```
|
||||
And some text again.
|
||||
```
|
||||
This is embedded source code.
|
||||
```
|
||||
Aaaand again some text."""
|
||||
with open(self.source_file4.name, 'w') as f:
|
||||
f.write(self.source_file4_content)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
os.remove(self.source_file1.name)
|
||||
os.remove(self.source_file2.name)
|
||||
os.remove(self.source_file3.name)
os.remove(self.source_file4.name)
|
||||
|
||||
def message_list(self, tmp_dir: tempfile.TemporaryDirectory) -> list[Path]:
|
||||
# exclude '.next'
|
||||
return list(Path(tmp_dir.name).glob('*.[ty]*'))
|
||||
|
||||
def test_message_file_created(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 0)
|
||||
create_message(self.chat, self.args)
|
||||
cache_dir_files = self.message_list(self.cache_path)
|
||||
self.assertEqual(len(cache_dir_files), 1)
|
||||
message = Message.from_file(cache_dir_files[0])
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertEqual(message.question, Question("What is this?")) # type: ignore [union-attr]
|
||||
|
||||
def test_single_question(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertEqual(message.question, Question("What is this?"))
|
||||
self.assertEqual(len(message.question.source_code()), 0)
|
||||
|
||||
def test_multipart_question(self) -> None:
|
||||
self.args.ask = ["What is this", "'bard' thing?", "Is it good?"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
self.assertEqual(message.question, Question("""What is this
|
||||
|
||||
'bard' thing?
|
||||
|
||||
Is it good?"""))
|
||||
|
||||
def test_single_question_with_text_only_file(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
self.args.source_text = [f"{self.source_file1.name}"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
# file contains no source code (only text)
|
||||
# -> don't expect any in the question
|
||||
self.assertEqual(len(message.question.source_code()), 0)
|
||||
self.assertEqual(message.question, Question(f"""What is this?
|
||||
|
||||
{self.source_file1_content}"""))
|
||||
|
||||
def test_single_question_with_text_file_and_embedded_code(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
self.args.source_code = [f"{self.source_file2.name}"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
# file contains 1 source code block
|
||||
# -> expect it in the question
|
||||
self.assertEqual(len(message.question.source_code()), 1)
|
||||
self.assertEqual(message.question, Question("""What is this?
|
||||
|
||||
```
|
||||
This is embedded source code.
|
||||
```
|
||||
"""))
|
||||
|
||||
def test_single_question_with_code_only_file(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
self.args.source_code = [f"{self.source_file3.name}"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
# file is complete source code
|
||||
self.assertEqual(len(message.question.source_code()), 1)
|
||||
self.assertEqual(message.question, Question(f"""What is this?
|
||||
|
||||
```
|
||||
{self.source_file3_content}
|
||||
```"""))
|
||||
|
||||
def test_single_question_with_text_file_and_multi_embedded_code(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
self.args.source_code = [f"{self.source_file4.name}"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
# file contains 2 source code blocks
|
||||
# -> expect them in the question
|
||||
self.assertEqual(len(message.question.source_code()), 2)
|
||||
self.assertEqual(message.question, Question("""What is this?
|
||||
|
||||
```
|
||||
This is embedded source code.
|
||||
```
|
||||
|
||||
|
||||
```
|
||||
This is embedded source code.
|
||||
```
|
||||
"""))
|
||||
|
||||
def test_single_question_with_text_only_message(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
self.args.source_text = [f"{self.chat.messages[0].file_path}"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
# file contains no source code (only text)
|
||||
# -> don't expect any in the question
|
||||
self.assertEqual(len(message.question.source_code()), 0)
|
||||
self.assertEqual(message.question, Question(f"""What is this?
|
||||
|
||||
{self.message_text.answer}"""))
|
||||
|
||||
def test_single_question_with_message_and_embedded_code(self) -> None:
|
||||
self.args.ask = ["What is this?"]
|
||||
self.args.source_code = [f"{self.chat.messages[1].file_path}"]
|
||||
message = create_message(self.chat, self.args)
|
||||
self.assertIsInstance(message, Message)
|
||||
# answer contains 1 source code block
|
||||
# -> expect it in the question
|
||||
self.assertEqual(len(message.question.source_code()), 1)
|
||||
self.assertEqual(message.question, Question("""What is this?
|
||||
|
||||
```
|
||||
It is embedded code
|
||||
```
|
||||
"""))
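The `source_code` tests above expect that only fenced blocks from a file or message end up in the question, each followed by a newline. A rough sketch of such an extraction is shown below. `extract_code_blocks` is a hypothetical helper written for illustration. It is not the implementation behind `Question.source_code()` or `create_message`, and it silently drops an unterminated block, which is an assumption.

```python
# The fence marker is built at runtime so this example can itself sit
# inside a fenced block without confusing the surrounding markup.
FENCE = "`" * 3


def extract_code_blocks(text: str) -> list[str]:
    """Return every fenced block in text, fences included, newline-terminated."""
    blocks: list[str] = []
    current = None  # lines of the block being collected, or None outside a block
    for line in text.splitlines():
        if line.strip().startswith(FENCE):
            if current is None:
                current = [line]           # opening fence
            else:
                current.append(line)       # closing fence
                blocks.append("\n".join(current) + "\n")
                current = None
        elif current is not None:
            current.append(line)           # body line inside a block
    return blocks


sample = "\n".join([
    "This is just text.",
    FENCE,
    "This is embedded source code.",
    FENCE,
    "And some text again.",
])
found = extract_code_blocks(sample)
```

For the sample above, `found` holds one block that matches the text the embedded-code test expects after the question.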
|
||||
163
tests/test_tags.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import unittest
|
||||
from chatmastermind.tags import Tag, TagLine, TagError
|
||||
|
||||
|
||||
class TestTag(unittest.TestCase):
|
||||
def test_valid_tag(self) -> None:
|
||||
tag = Tag('mytag')
|
||||
self.assertEqual(tag, 'mytag')
|
||||
|
||||
def test_invalid_tag(self) -> None:
|
||||
with self.assertRaises(TagError):
|
||||
Tag('tag with space')
|
||||
|
||||
def test_default_separator(self) -> None:
|
||||
self.assertEqual(Tag.default_separator, ' ')
|
||||
|
||||
def test_alternative_separators(self) -> None:
|
||||
self.assertEqual(Tag.alternative_separators, [','])
|
||||
|
||||
|
||||
class TestTagLine(unittest.TestCase):
|
||||
def test_valid_tagline(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1 tag2')
|
||||
self.assertEqual(tagline, 'TAGS: tag1 tag2')
|
||||
|
||||
def test_valid_tagline_with_newline(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1\n tag2')
|
||||
self.assertEqual(tagline, 'TAGS: tag1 tag2')
|
||||
|
||||
def test_invalid_tagline(self) -> None:
|
||||
with self.assertRaises(TagError):
|
||||
TagLine('tag1 tag2')
|
||||
|
||||
def test_prefix(self) -> None:
|
||||
self.assertEqual(TagLine.prefix, 'TAGS:')
|
||||
|
||||
def test_from_set(self) -> None:
|
||||
tags = {Tag('tag1'), Tag('tag2')}
|
||||
tagline = TagLine.from_set(tags)
|
||||
self.assertEqual(tagline, 'TAGS: tag1 tag2')
|
||||
|
||||
def test_tags(self) -> None:
|
||||
tagline = TagLine('TAGS: atag1 btag2')
|
||||
tags = tagline.tags()
|
||||
self.assertEqual(tags, {Tag('atag1'), Tag('btag2')})
|
||||
|
||||
def test_tags_empty(self) -> None:
|
||||
tagline = TagLine('TAGS:')
|
||||
self.assertSetEqual(tagline.tags(), set())
|
||||
|
||||
def test_tags_with_newline(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1\n tag2')
|
||||
tags = tagline.tags()
|
||||
self.assertEqual(tags, {Tag('tag1'), Tag('tag2')})
|
||||
|
||||
def test_tags_prefix(self) -> None:
|
||||
tagline = TagLine('TAGS: atag1 stag2 stag3')
|
||||
tags = tagline.tags(prefix='a')
|
||||
self.assertSetEqual(tags, {Tag('atag1')})
|
||||
tags = tagline.tags(prefix='s')
|
||||
self.assertSetEqual(tags, {Tag('stag2'), Tag('stag3')})
|
||||
tags = tagline.tags(prefix='R')
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_tags_contain(self) -> None:
|
||||
tagline = TagLine('TAGS: atag1 stag2 stag3')
|
||||
tags = tagline.tags(contain='t')
|
||||
self.assertSetEqual(tags, {Tag('atag1'), Tag('stag2'), Tag('stag3')})
|
||||
tags = tagline.tags(contain='1')
|
||||
self.assertSetEqual(tags, {Tag('atag1')})
|
||||
tags = tagline.tags(contain='R')
|
||||
self.assertSetEqual(tags, set())
|
||||
|
||||
def test_merge(self) -> None:
|
||||
tagline1 = TagLine('TAGS: tag1 tag2')
|
||||
tagline2 = TagLine('TAGS: tag2 tag3')
|
||||
merged_tagline = tagline1.merge({tagline2})
|
||||
self.assertEqual(merged_tagline, 'TAGS: tag1 tag2 tag3')
|
||||
|
||||
def test_delete_tags(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1 tag2 tag3')
|
||||
new_tagline = tagline.delete_tags({Tag('tag1'), Tag('tag3')})
|
||||
self.assertEqual(new_tagline, 'TAGS: tag2')
|
||||
|
||||
def test_add_tags(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1')
|
||||
new_tagline = tagline.add_tags({Tag('tag2'), Tag('tag3')})
|
||||
self.assertEqual(new_tagline, 'TAGS: tag1 tag2 tag3')
|
||||
|
||||
def test_rename_tags(self) -> None:
|
||||
tagline = TagLine('TAGS: old1 old2')
|
||||
new_tagline = tagline.rename_tags({(Tag('old1'), Tag('new1')), (Tag('old2'), Tag('new2'))})
|
||||
self.assertEqual(new_tagline, 'TAGS: new1 new2')
|
||||
|
||||
def test_match_tags(self) -> None:
|
||||
tagline = TagLine('TAGS: tag1 tag2 tag3')
|
||||
|
||||
# Test case 1: Match any tag in 'tags_or'
|
||||
tags_or = {Tag('tag1'), Tag('tag4')}
|
||||
tags_and: set[Tag] = set()
|
||||
tags_not: set[Tag] = set()
|
||||
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 2: Match all tags in 'tags_and'
|
||||
tags_or = set()
|
||||
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3')}
|
||||
tags_not = set()
|
||||
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 3: Match any tag in 'tags_or' and match all tags in 'tags_and'
|
||||
tags_or = {Tag('tag1'), Tag('tag4')}
|
||||
tags_and = {Tag('tag1'), Tag('tag2')}
|
||||
tags_not = set()
|
||||
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 4: Match any tag in 'tags_or', match all tags in 'tags_and', and exclude tags in 'tags_not'
|
||||
tags_or = {Tag('tag1'), Tag('tag4')}
|
||||
tags_and = {Tag('tag1'), Tag('tag2')}
|
||||
tags_not = {Tag('tag5')}
|
||||
self.assertTrue(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 5: No matching tags in 'tags_or'
|
||||
tags_or = {Tag('tag4'), Tag('tag5')}
|
||||
tags_and = set()
|
||||
tags_not = set()
|
||||
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 6: Not all tags in 'tags_and' are present
|
||||
tags_or = set()
|
||||
tags_and = {Tag('tag1'), Tag('tag2'), Tag('tag3'), Tag('tag4')}
|
||||
tags_not = set()
|
||||
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 7: Some tags in 'tags_not' are present
|
||||
tags_or = {Tag('tag1')}
|
||||
tags_and = set()
|
||||
tags_not = {Tag('tag2')}
|
||||
self.assertFalse(tagline.match_tags(tags_or, tags_and, tags_not))
|
||||
|
||||
# Test case 8: 'tags_or' and 'tags_and' are None, match all tags
|
||||
tags_not = set()
|
||||
self.assertTrue(tagline.match_tags(None, None, tags_not))
|
||||
|
||||
# Test case 9: 'tags_or' and 'tags_and' are None, match all tags except excluded tags
|
||||
tags_not = {Tag('tag2')}
|
||||
self.assertFalse(tagline.match_tags(None, None, tags_not))
|
||||
|
||||
# Test case 10: 'tags_or' and 'tags_and' are empty, match no tags
|
||||
self.assertFalse(tagline.match_tags(set(), set(), None))
|
||||
|
||||
# Test case 11: 'tags_or' is empty, match no tags
|
||||
self.assertFalse(tagline.match_tags(set(), None, None))
|
||||
|
||||
# Test case 12: 'tags_and' is empty, match no tags
|
||||
self.assertFalse(tagline.match_tags(None, set(), None))
|
||||
|
||||
# Test case 13: 'tags_or' is None, match 'tags_and'
|
||||
tags_and = {Tag('tag1'), Tag('tag2')}
|
||||
self.assertTrue(tagline.match_tags(None, tags_and, None))
|
||||
|
||||
# Test case 14: 'tags_and' is None, match 'tags_or'
|
||||
tags_or = {Tag('tag1'), Tag('tag2')}
|
||||
self.assertTrue(tagline.match_tags(tags_or, None, None))
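The fourteen cases in `test_match_tags` pin down how `None` and an empty set differ: `None` for both `tags_or` and `tags_and` accepts everything, while an empty set on its own matches nothing. The sketch below is one set of rules that satisfies all fourteen cases. `match_tags_sketch` is a hypothetical function inferred from the tests alone, it works on plain string sets, and the real `TagLine.match_tags` may be implemented differently.

```python
from typing import Optional


def match_tags_sketch(tag_set: set[str],
                      tags_or: Optional[set[str]],
                      tags_and: Optional[set[str]],
                      tags_not: Optional[set[str]]) -> bool:
    """Hypothetical matching rules reconstructed from the test cases."""
    if tags_or is None and tags_and is None:
        # No requirement at all: every tag set qualifies.
        required_present = True
    else:
        # An empty or None set never produces a hit by itself.
        or_hit = bool(tags_or) and any(t in tag_set for t in tags_or)
        and_hit = bool(tags_and) and all(t in tag_set for t in tags_and)
        required_present = or_hit or and_hit
    # Exclusion: none of the tags in tags_not may be present.
    excluded_absent = tags_not is None or not any(t in tag_set for t in tags_not)
    return required_present and excluded_absent


line_tags = {'tag1', 'tag2', 'tag3'}
case_1 = match_tags_sketch(line_tags, {'tag1', 'tag4'}, set(), set())
case_9 = match_tags_sketch(line_tags, None, None, {'tag2'})
case_11 = match_tags_sketch(line_tags, set(), None, None)
```

Note that the OR and AND criteria are combined with a logical "or": case 1 passes although `tags_and` is an empty set, because `tags_or` already produced a hit.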
|
||||