-
-
Notifications
You must be signed in to change notification settings - Fork 949
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CLI Implementation with Click (#2107)
* Initial CLI implementation with click package * Adding fetch command for pulling examples and deepspeed configs * Automating default options for CliArgs classes * Mimicking existing no config behavior * bugfix in choose_config * Updating fetch to sync instead of re-download * bugfix * isort fix * fixing yaml isort order * pre-commit fixes * simplifying argument parsing -- pass through kwargs to do_cli * make accelerate launch default for non-preprocess commands * fixing arg handling * testing None placeholder approach * removing hacky --use-gpu argument to preprocess command * Adding brief README documentation for CLI * remove (New) * Initial CLI pytest tests * progress on CLI pytest * adding inference CLI tests; cleanup * Refactor train CLI tests to remove various mocking * Major CLI test refator; adding remaining CLI codepath test coverage * pytest fixes * remove integration markers * parallelizing examples, deepspeed config downloads; rename test to match other CLI test naming * moving cli pytest due to isolation issues; cleanup * testing fixes; various minor improvements * fix * tests fix * Update tests/cli/conftest.py Co-authored-by: Wing Lian <[email protected]> --------- Co-authored-by: Dan Saunders <[email protected]> Co-authored-by: Wing Lian <[email protected]>
- Loading branch information
1 parent
e399ba5
commit fc973f4
Showing
25 changed files
with
1,113 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
#!/bin/bash | ||
set -e | ||
|
||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/ | ||
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ | ||
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ | ||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/ | ||
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/ | ||
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/workspace/data/axolotl-artifacts |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
"""CLI definition for various axolotl commands.""" | ||
# pylint: disable=redefined-outer-name | ||
import subprocess # nosec B404 | ||
from typing import Optional | ||
|
||
import click | ||
|
||
from axolotl.cli.utils import ( | ||
add_options_from_config, | ||
add_options_from_dataclass, | ||
build_command, | ||
fetch_from_github, | ||
) | ||
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs | ||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig | ||
|
||
|
||
@click.group() | ||
def cli(): | ||
"""Axolotl CLI - Train and fine-tune large language models""" | ||
|
||
|
||
@cli.command() | ||
@click.argument("config", type=click.Path(exists=True, path_type=str)) | ||
@add_options_from_dataclass(PreprocessCliArgs) | ||
@add_options_from_config(AxolotlInputConfig) | ||
def preprocess(config: str, **kwargs): | ||
"""Preprocess datasets before training.""" | ||
kwargs = {k: v for k, v in kwargs.items() if v is not None} | ||
|
||
from axolotl.cli.preprocess import do_cli | ||
|
||
do_cli(config=config, **kwargs) | ||
|
||
|
||
@cli.command() | ||
@click.argument("config", type=click.Path(exists=True, path_type=str)) | ||
@click.option( | ||
"--accelerate/--no-accelerate", | ||
default=True, | ||
help="Use accelerate launch for multi-GPU training", | ||
) | ||
@add_options_from_dataclass(TrainerCliArgs) | ||
@add_options_from_config(AxolotlInputConfig) | ||
def train(config: str, accelerate: bool, **kwargs): | ||
"""Train or fine-tune a model.""" | ||
kwargs = {k: v for k, v in kwargs.items() if v is not None} | ||
|
||
if accelerate: | ||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"] | ||
if config: | ||
base_cmd.append(config) | ||
cmd = build_command(base_cmd, kwargs) | ||
subprocess.run(cmd, check=True) # nosec B603 | ||
else: | ||
from axolotl.cli.train import do_cli | ||
|
||
do_cli(config=config, **kwargs) | ||
|
||
|
||
@cli.command() | ||
@click.argument("config", type=click.Path(exists=True, path_type=str)) | ||
@click.option( | ||
"--accelerate/--no-accelerate", | ||
default=True, | ||
help="Use accelerate launch for multi-GPU inference", | ||
) | ||
@click.option( | ||
"--lora-model-dir", | ||
type=click.Path(exists=True, path_type=str), | ||
help="Directory containing LoRA model", | ||
) | ||
@click.option( | ||
"--base-model", | ||
type=click.Path(exists=True, path_type=str), | ||
help="Path to base model for non-LoRA models", | ||
) | ||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface") | ||
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode") | ||
@add_options_from_dataclass(TrainerCliArgs) | ||
@add_options_from_config(AxolotlInputConfig) | ||
def inference( | ||
config: str, | ||
accelerate: bool, | ||
lora_model_dir: Optional[str] = None, | ||
base_model: Optional[str] = None, | ||
**kwargs, | ||
): | ||
"""Run inference with a trained model.""" | ||
kwargs = {k: v for k, v in kwargs.items() if v is not None} | ||
del kwargs["inference"] # interferes with inference.do_cli | ||
|
||
if lora_model_dir: | ||
kwargs["lora_model_dir"] = lora_model_dir | ||
if base_model: | ||
kwargs["output_dir"] = base_model | ||
|
||
if accelerate: | ||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"] | ||
if config: | ||
base_cmd.append(config) | ||
cmd = build_command(base_cmd, kwargs) | ||
subprocess.run(cmd, check=True) # nosec B603 | ||
else: | ||
from axolotl.cli.inference import do_cli | ||
|
||
do_cli(config=config, **kwargs) | ||
|
||
|
||
@cli.command() | ||
@click.argument("config", type=click.Path(exists=True, path_type=str)) | ||
@click.option( | ||
"--accelerate/--no-accelerate", | ||
default=False, | ||
help="Use accelerate launch for multi-GPU operations", | ||
) | ||
@click.option( | ||
"--model-dir", | ||
type=click.Path(exists=True, path_type=str), | ||
help="Directory containing model weights to shard", | ||
) | ||
@click.option( | ||
"--save-dir", | ||
type=click.Path(path_type=str), | ||
help="Directory to save sharded weights", | ||
) | ||
@add_options_from_dataclass(TrainerCliArgs) | ||
@add_options_from_config(AxolotlInputConfig) | ||
def shard(config: str, accelerate: bool, **kwargs): | ||
"""Shard model weights.""" | ||
kwargs = {k: v for k, v in kwargs.items() if v is not None} | ||
|
||
if accelerate: | ||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"] | ||
if config: | ||
base_cmd.append(config) | ||
cmd = build_command(base_cmd, kwargs) | ||
subprocess.run(cmd, check=True) # nosec B603 | ||
else: | ||
from axolotl.cli.shard import do_cli | ||
|
||
do_cli(config=config, **kwargs) | ||
|
||
|
||
@cli.command() | ||
@click.argument("config", type=click.Path(exists=True, path_type=str)) | ||
@click.option( | ||
"--accelerate/--no-accelerate", | ||
default=True, | ||
help="Use accelerate launch for weight merging", | ||
) | ||
@click.option( | ||
"--model-dir", | ||
type=click.Path(exists=True, path_type=str), | ||
help="Directory containing sharded weights", | ||
) | ||
@click.option( | ||
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights" | ||
) | ||
@add_options_from_dataclass(TrainerCliArgs) | ||
@add_options_from_config(AxolotlInputConfig) | ||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs): | ||
"""Merge sharded FSDP model weights.""" | ||
kwargs = {k: v for k, v in kwargs.items() if v is not None} | ||
|
||
if accelerate: | ||
base_cmd = [ | ||
"accelerate", | ||
"launch", | ||
"-m", | ||
"axolotl.cli.merge_sharded_fsdp_weights", | ||
] | ||
if config: | ||
base_cmd.append(config) | ||
cmd = build_command(base_cmd, kwargs) | ||
subprocess.run(cmd, check=True) # nosec B603 | ||
else: | ||
from axolotl.cli.merge_sharded_fsdp_weights import do_cli | ||
|
||
do_cli(config=config, **kwargs) | ||
|
||
|
||
@cli.command() | ||
@click.argument("config", type=click.Path(exists=True, path_type=str)) | ||
@click.option( | ||
"--lora-model-dir", | ||
type=click.Path(exists=True, path_type=str), | ||
help="Directory containing the LoRA model to merge", | ||
) | ||
@click.option( | ||
"--output-dir", | ||
type=click.Path(path_type=str), | ||
help="Directory to save the merged model", | ||
) | ||
def merge_lora( | ||
config: str, | ||
lora_model_dir: Optional[str] = None, | ||
output_dir: Optional[str] = None, | ||
): | ||
"""Merge a trained LoRA into a base model""" | ||
kwargs = {} | ||
if lora_model_dir: | ||
kwargs["lora_model_dir"] = lora_model_dir | ||
if output_dir: | ||
kwargs["output_dir"] = output_dir | ||
|
||
from axolotl.cli.merge_lora import do_cli | ||
|
||
do_cli(config=config, **kwargs) | ||
|
||
|
||
@cli.command() | ||
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) | ||
@click.option("--dest", help="Destination directory") | ||
def fetch(directory: str, dest: Optional[str]): | ||
""" | ||
Fetch example configs or other resources. | ||
Available directories: | ||
- examples: Example configuration files | ||
- deepspeed_configs: DeepSpeed configuration files | ||
""" | ||
fetch_from_github(f"{directory}/", dest) | ||
|
||
|
||
def main(): | ||
cli() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.