diff --git a/chatmastermind/ai_factory.py b/chatmastermind/ai_factory.py index 420b287..a3cf9c3 100644 --- a/chatmastermind/ai_factory.py +++ b/chatmastermind/ai_factory.py @@ -17,7 +17,7 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 is not found, it uses the first AI in the list. """ ai_conf: AIConfig - if args.AI: + if 'AI' in args and args.AI: try: ai_conf = config.ais[args.AI] except KeyError: @@ -32,11 +32,11 @@ def create_ai(args: argparse.Namespace, config: Config) -> AI: # noqa: 11 if ai_conf.name == 'openai': ai = OpenAI(cast(OpenAIConfig, ai_conf)) - if args.model: + if 'model' in args and args.model: ai.config.model = args.model - if args.max_tokens: + if 'max_tokens' in args and args.max_tokens: ai.config.max_tokens = args.max_tokens - if args.temperature: + if 'temperature' in args and args.temperature: ai.config.temperature = args.temperature return ai else: diff --git a/chatmastermind/ais/openai.py b/chatmastermind/ais/openai.py index a388a7a..0e7ad41 100644 --- a/chatmastermind/ais/openai.py +++ b/chatmastermind/ais/openai.py @@ -62,7 +62,12 @@ class OpenAI(AI): """ Return all models supported by this AI. """ - raise NotImplementedError + ret = [] + for engine in sorted(openai.Engine.list()['data'], key=lambda x: x['id']): + if engine['ready']: + ret.append(engine['id']) + ret.sort() + return ret def print_models(self) -> None: """ diff --git a/chatmastermind/main.py b/chatmastermind/main.py index 99aca09..7e18185 100755 --- a/chatmastermind/main.py +++ b/chatmastermind/main.py @@ -100,6 +100,7 @@ def create_parser() -> argparse.ArgumentParser: help="Manage configuration", aliases=['c']) config_cmd_parser.set_defaults(func=config_cmd) + config_cmd_parser.add_argument('-A', '--AI', help='AI ID to use') config_group = config_cmd_parser.add_mutually_exclusive_group(required=True) config_group.add_argument('-l', '--list-models', help="List all available models", action='store_true')