Ray Train Axolotl Integration (#2251)
* current

not clean working version
move torch trainer to do_cli
update code with config changes and clean up
edit config
cleanup
add run name to trainer

* address comments

* use axolotl train in multigpu tests and add ray tests for multi-gpu

* accelerate uses underscores for main_process_port arg

* chore: lint

* fix order of accelerate args

* include ray train in docker images

* fix bf16 resolution behavior

* move dtype logic

* x

Signed-off-by: SumanthRH <[email protected]>

* rename

Signed-off-by: SumanthRH <[email protected]>

* add to sidebar

Signed-off-by: SumanthRH <[email protected]>

* Apply suggestions from code review

Co-authored-by: Eric Tang <[email protected]>

* Update docs/ray-integration.qmd

Co-authored-by: Eric Tang <[email protected]>

* pre-commit fixes

Signed-off-by: SumanthRH <[email protected]>

* use output_dir instead of hardcoded saves path

Co-authored-by: NanoCode012 <[email protected]>

* bugfix storage dir

* change type for resources_per_worker

---------

Signed-off-by: SumanthRH <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: SumanthRH <[email protected]>
Co-authored-by: Sumanth R Hegde <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
Co-authored-by: NanoCode012 <[email protected]>
6 people authored Jan 29, 2025
1 parent 54dd7ab commit 268543a
Showing 16 changed files with 491 additions and 99 deletions.
1 change: 1 addition & 0 deletions _quarto.yml
@@ -38,6 +38,7 @@ website:
- docs/multi-node.qmd
- docs/unsloth.qmd
- docs/amd_hpc.qmd
- docs/ray-integration.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
- section: "Reference"
4 changes: 2 additions & 2 deletions cicd/Dockerfile.jinja
@@ -32,9 +32,9 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
fi

RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh
4 changes: 2 additions & 2 deletions docker/Dockerfile
@@ -20,9 +20,9 @@ WORKDIR /workspace/axolotl

# If AXOLOTL_EXTRAS is set, append it in brackets
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers] $AXOLOTL_ARGS; \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray] $AXOLOTL_ARGS; \
fi

RUN python scripts/unsloth_install.py | sh
Binary file added docs/images/ray-cluster-dashboard.png
93 changes: 93 additions & 0 deletions docs/ray-integration.qmd
@@ -0,0 +1,93 @@
---
title: Ray Train integration
description: How to use Axolotl with Ray Train
---

Axolotl supports using Ray as an alternative to `accelerate` for orchestrating training. This is especially useful for multi-node training, since you only have to set up code and dependencies on a single node and can launch training as if you were using a single node.

With the `--use-ray` CLI flag, Axolotl will use Ray Train's [`TorchTrainer`](https://docs.ray.io/en/latest/train/api/doc/ray.train.torch.TorchTrainer.html#ray.train.torch.TorchTrainer) to run training.

## Ray cluster setup

A prerequisite for using the Ray Train integration is to set up a Ray cluster on your desired node(s). For a detailed guide on how to get started with Ray clusters, check the official Ray docs here: https://docs.ray.io/en/latest/cluster/getting-started.html

Every Ray cluster has one _head_ node and a set of worker nodes. The head node is just like any other worker node, but it also runs certain special processes related to scheduling and orchestration. Ray-enabled scripts run on the head node and, depending on the resources (number of CPUs, GPUs, etc.) they request, are scheduled to run tasks on the worker nodes. For more on the key concepts behind a Ray cluster, you can refer to this [doc](https://docs.ray.io/en/latest/cluster/key-concepts.html#cluster-key-concepts).
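
For reference, one minimal way to bring such a cluster up by hand, assuming the default Ray port and direct access to each machine (managed setups such as the Ray cluster launcher or KubeRay handle this for you), is:

```bash
# on the head node
ray start --head --port=6379

# on each worker node, pointing at the head node's address
ray start --address='<head-node-ip>:6379'
```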

## Sanity check

To run a sanity check on whether your Ray cluster is set up properly, execute the following on the head node:

```bash
ray status
```

The output should show a summary of your Ray cluster: the list of all nodes in the cluster, the number of CPUs and GPUs, and so on. For example, if you have a cluster with one CPU-only head node and two 4xL40S worker nodes, the output can look like this:


```
Node status
---------------------------------------------------------------
Active:
1 head
Idle:
2 4xL40S:48CPU-384GB
Pending:
(no pending nodes)
Recent failures:
(no failures)
Resources
---------------------------------------------------------------
Usage:
0.0/96.0 CPU
0.0/8.0 GPU
0B/800.00GiB memory
0B/229.57GiB object_store_memory
Demands:
(no resource demands)
```

You should also be able to see the same on the [Ray dashboard](https://docs.ray.io/en/latest/ray-observability/getting-started.html).
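
The dashboard is typically served from the head node on port 8265 by default. If your cluster runs on remote machines, one simple way to reach it from a local browser (assuming SSH access to the head node) is to forward that port:

```bash
ssh -N -L 8265:localhost:8265 <head-node>
# then open http://localhost:8265 locally
```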


## Configuring training with Ray Train

You can find an example configuration at `examples/llama-3/lora-1b-ray.yml`.

The key parameters to note here are:

```yaml
...
use_ray: true
ray_num_workers: 4
# optional
resources_per_worker:
  GPU: 1
...
```

- `use_ray`: This is the flag that enables the Ray Train integration. You can either use the corresponding `--use-ray` flag in the CLI or set `use_ray` in the config file.
- `ray_num_workers`: This is the number of workers/GPUs to use for training.
- `resources_per_worker`: This is the Ray [resource request](https://docs.ray.io/en/latest/ray-core/scheduling/resources.html) for each worker. It can be used to request a specific GPU type or a custom resource for each worker. For example, if your Ray cluster has GPUs of different types and you only want to use NVIDIA L40S GPUs, you can do:

```yaml
resources_per_worker:
  accelerator_type:L40S: 0.001
```

## Launching training

You can simply run the following command on the head node:

```bash
axolotl train examples/llama-3/lora-1b-ray.yml --use-ray
```

This will launch training from the head node, and Ray Train will automatically schedule the training workers on the appropriate head or worker nodes.

You can also monitor training progress on the Ray dashboard.

Coming back to the example Ray cluster with one head node and two 4xL40S worker nodes: say you want to make use of all 8 GPUs. You can simply set `ray_num_workers: 8` and run the previous command. The Cluster tab of the dashboard will then show the following:

![Ray dashboard](./images/ray-cluster-dashboard.png)
79 changes: 79 additions & 0 deletions examples/llama-3/lora-1b-ray.yml
@@ -0,0 +1,79 @@
base_model: NousResearch/Llama-3.2-1B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed: deepspeed_configs/zero3.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"

use_ray: true
ray_num_workers: 4
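
With `use_ray: true` already set in this example config, training can presumably be launched from the head node without the extra CLI flag:

```bash
axolotl train examples/llama-3/lora-1b-ray.yml
```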
3 changes: 3 additions & 0 deletions setup.py
@@ -150,5 +150,8 @@ def get_package_version():
            "lomo-optim==0.1.1",
            "torch-optimi==0.2.1",
        ],
        "ray": [
            "ray[train]",
        ],
    },
)
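
With this new extra in place, the Ray dependencies can presumably be installed locally in the same way the Dockerfiles above do, for example:

```bash
pip install --no-build-isolation -e .[ray]
```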
2 changes: 2 additions & 0 deletions src/axolotl/cli/args.py
@@ -25,6 +25,8 @@ class TrainerCliArgs:
    merge_lora: bool = field(default=False)
    prompter: Optional[str] = field(default=None)
    shard: bool = field(default=False)
    main_process_port: Optional[int] = field(default=None)
    num_processes: Optional[int] = field(default=None)


@dataclass
17 changes: 16 additions & 1 deletion src/axolotl/cli/main.py
@@ -67,8 +67,23 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
    # Enable expandable segments for cuda allocation to improve VRAM usage
    set_pytorch_cuda_alloc_conf()

    if "use_ray" in kwargs and kwargs["use_ray"]:
        accelerate = False

    if accelerate:
        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
        accelerate_args = []
        if "main_process_port" in kwargs:
            main_process_port = kwargs.pop("main_process_port", None)
            accelerate_args.append("--main_process_port")
            accelerate_args.append(str(main_process_port))
        if "num_processes" in kwargs:
            num_processes = kwargs.pop("num_processes", None)
            accelerate_args.append("--num-processes")
            accelerate_args.append(str(num_processes))

        base_cmd = ["accelerate", "launch"]
        base_cmd.extend(accelerate_args)
        base_cmd.extend(["-m", "axolotl.cli.train"])
        if config:
            base_cmd.append(config)
        cmd = build_command(base_cmd, kwargs)
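
For illustration, with these pass-through arguments the launcher is expected to assemble an `accelerate launch` command roughly like the following (the axolotl-side flag spellings are an assumption; the accelerate-side spellings come from the code above, and `config.yml` is a placeholder):

```bash
axolotl train config.yml --main-process-port 29501 --num-processes 2
# builds approximately:
accelerate launch --main_process_port 29501 --num-processes 2 -m axolotl.cli.train config.yml
```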
44 changes: 43 additions & 1 deletion src/axolotl/cli/train.py
@@ -5,6 +5,7 @@
from typing import Union

import fire
from accelerate import Accelerator
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser

@@ -15,6 +16,7 @@
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.config import normalize_config, resolve_dtype
from axolotl.utils.dict import DictDefault

LOG = logging.getLogger(__name__)
@@ -63,7 +65,47 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
        return_remaining_strings=True
    )

    do_train(parsed_cfg, parsed_cli_args)
    if parsed_cfg.use_ray:
        from ray.train import RunConfig, ScalingConfig
        from ray.train.torch import TorchTrainer

        train_loop_config = {"cfg": parsed_cfg.to_dict(), "cli_args": parsed_cli_args}
        trainer = TorchTrainer(
            ray_train_func,
            train_loop_config=train_loop_config,
            scaling_config=ScalingConfig(
                num_workers=parsed_cfg.ray_num_workers,
                resources_per_worker=parsed_cfg.resources_per_worker.to_dict(),
                use_gpu=True,
            ),
            run_config=RunConfig(
                name=parsed_cfg.ray_run_name,
                storage_path=Path(parsed_cfg.output_dir).absolute().as_posix(),
            ),
        )
        return trainer.fit()
    return do_train(parsed_cfg, parsed_cli_args)


def ray_train_func(kwargs: dict):
    # cast `cfg` back to DictDefault (ray tune deepcopy has issues with DictDefault so needed it to be dict)
    # also renormalize the config now that TorchTrainer has spawned distributed workers
    cfg = DictDefault(kwargs["cfg"])
    normalize_config(cfg)

    # now that we are on the worker node, we can check `is_torch_bf16_gpu_available` to resolve dtype
    resolve_dtype(cfg)

    # ray serializing objects gets rid of frozen attribute - HF expects dict not DefaultDict
    if cfg.deepspeed:
        cfg.deepspeed = cfg.deepspeed.to_dict()

    # initialize accelerator before model instantiation
    Accelerator(gradient_accumulation_steps=cfg.gradient_accumulation_steps)

    kwargs["cfg"] = cfg

    do_train(**kwargs)


if __name__ == "__main__":
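
As a usage note, the `RunConfig` above sets `name=ray_run_name` and `storage_path` to the absolute `output_dir`, so Ray Train run artifacts are expected to land under the output directory, for example (with the example config's `output_dir` and a hypothetical run name):

```bash
ls ./outputs/lora-out/<ray_run_name>/
```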
4 changes: 3 additions & 1 deletion src/axolotl/train.py
@@ -141,7 +141,9 @@ def train(
    model.config.save_pretrained(str(Path(cfg.output_dir)))

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:
    if (
        cfg.local_rank == 0 and not cfg.use_ray
    ):  # ray workers don't have access to this signal

        def terminate_handler(_, __, model_weakref):
            if model_weakref() is not None: