Skip to content

Commit

Permalink
CLI Implementation with Click (#2107)
Browse files Browse the repository at this point in the history
* Initial CLI implementation with click package

* Adding fetch command for pulling examples and deepspeed configs

* Automating default options for CliArgs classes

* Mimicking existing no config behavior

* bugfix in choose_config

* Updating fetch to sync instead of re-download

* bugfix

* isort fix

* fixing yaml isort order

* pre-commit fixes

* simplifying argument parsing -- pass through kwargs to do_cli

* make accelerate launch default for non-preprocess commands

* fixing arg handling

* testing None placeholder approach

* removing hacky --use-gpu argument to preprocess command

* Adding brief README documentation for CLI

* remove (New)

* Initial CLI pytest tests

* progress on CLI pytest

* adding inference CLI tests; cleanup

* Refactor train CLI tests to remove various mocking

* Major CLI test refator; adding remaining CLI codepath test coverage

* pytest fixes

* remove integration markers

* parallelizing examples, deepspeed config downloads; rename test to match other CLI test naming

* moving cli pytest due to isolation issues; cleanup

* testing fixes; various minor improvements

* fix

* tests fix

* Update tests/cli/conftest.py

Co-authored-by: Wing Lian <[email protected]>

---------

Co-authored-by: Dan Saunders <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent e399ba5 commit fc973f4
Show file tree
Hide file tree
Showing 25 changed files with 1,113 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/tests-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ jobs:
- name: Run tests
run: |
pytest --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ jobs:
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |
Expand Down Expand Up @@ -123,7 +124,8 @@ jobs:
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
pytest tests/patched/
- name: cleanup pip cache
run: |
Expand Down
41 changes: 40 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,46 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```

### Axolotl CLI

If you've installed this package using `pip` from source, we now support a new, more
streamlined CLI using [click](https://click.palletsprojects.com/en/stable/). Rewriting
the above commands:

```bash
# preprocess datasets - optional but recommended
CUDA_VISIBLE_DEVICES="0" axolotl preprocess examples/openllama-3b/lora.yml

# finetune lora
axolotl train examples/openllama-3b/lora.yml

# inference
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out"

# gradio
axolotl inference examples/openllama-3b/lora.yml \
--lora-model-dir="./outputs/lora-out" --gradio

# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
axolotl train https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main/examples/openllama-3b/lora.yml
```

We've also added a new command for fetching `examples` and `deepspeed_configs` to your
local machine. This will come in handy when installing `axolotl` from PyPI.

```bash
# Fetch example YAML files (stores in "examples/" folder)
axolotl fetch examples

# Fetch deepspeed config files (stores in "deepspeed_configs/" folder)
axolotl fetch deepspeed_configs

# Optionally, specify a destination folder
axolotl fetch examples --dest path/to/folder
```

## Badge ❤🏷️

Building something cool with Axolotl? Consider adding a badge to your model card.
Expand Down Expand Up @@ -206,7 +246,6 @@ Thanks to all of our contributors to date. Help drive open source AI progress fo
❌: not supported
❓: untested


## Advanced Setup

### Environment
Expand Down
5 changes: 3 additions & 2 deletions cicd/cicd.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/bin/bash
set -e

pytest -v --durations=10 -n8 --ignore=tests/e2e/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile -v /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 -n1 --dist loadfile /workspace/axolotl/tests/e2e/patched/ /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
1 change: 1 addition & 0 deletions outputs
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,11 @@ def parse_requirements():
packages=find_packages("src"),
install_requires=install_requires,
dependency_links=dependency_links,
entry_points={
"console_scripts": [
"axolotl=axolotl.cli.main:main",
],
},
extras_require={
"flash-attn": [
"flash-attn==2.7.0.post2",
Expand Down
4 changes: 2 additions & 2 deletions src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def choose_config(path: Path):

if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return yaml_files[0]
return str(yaml_files[0])

print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
Expand All @@ -391,7 +391,7 @@ def choose_config(path: Path):
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = yaml_files[choice - 1]
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/cli/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
CLI to run inference on a trained model
"""
from pathlib import Path
from typing import Union

import fire
import transformers
Expand All @@ -16,7 +17,7 @@
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
Expand Down
231 changes: 231 additions & 0 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional

import click

from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
)
from axolotl.common.cli import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig


@click.group()
def cli():
"""Axolotl CLI - Train and fine-tune large language models"""


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

from axolotl.cli.preprocess import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU training",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.train import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli

if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

if accelerate:
base_cmd = [
"accelerate",
"launch",
"-m",
"axolotl.cli.merge_sharded_fsdp_weights",
]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.merge_sharded_fsdp_weights import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir

from axolotl.cli.merge_lora import do_cli

do_cli(config=config, **kwargs)


@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
"""
fetch_from_github(f"{directory}/", dest)


def main():
cli()


if __name__ == "__main__":
main()
3 changes: 2 additions & 1 deletion src/axolotl/cli/merge_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
CLI to run merge a trained LoRA into a base model
"""
from pathlib import Path
from typing import Union

import fire
import transformers
Expand All @@ -11,7 +12,7 @@
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
Expand Down
Loading

0 comments on commit fc973f4

Please sign in to comment.