Skip to content

Commit

Permalink
Merge pull request #176 from johnbensnyder/mixtral
Browse files Browse the repository at this point in the history
added mixtral and hybrid shard support
  • Loading branch information
johnbensnyder authored Mar 10, 2024
2 parents b72144b + a208f34 commit f5e0b7b
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 56 deletions.
20 changes: 15 additions & 5 deletions 3.test_cases/10.FSDP/1.distributed-training.sbatch
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

#SBATCH --nodes=4 # number of nodes to use, 4 p4d(e) = 32 A100 GPUs
#SBATCH --nodes=4 # number of nodes to use
#SBATCH --job-name=FSDP # name of your job
#SBATCH --exclusive # job has exclusive use of the resource, no sharing

Expand All @@ -13,12 +13,16 @@ set -ex;
###### User Variables #####
###########################

GPUS_PER_NODE=8 # 4 for G5.12x, 8 for P4/P5

###########################
## Environment Variables ##
###########################

## Plenty of EFA level variables
## Comment out for non-efa instances (G5, G4d, P3)
## Comment out for non-efa instances (G4dn, P3)
## For G5.12x, Comment out RDMA and Fork safe
## For G4dn and other G5, comment out all
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
Expand All @@ -30,7 +34,7 @@ export NCCL_DEBUG=INFO
###########################

declare -a TORCHRUN_ARGS=(
--nproc_per_node=8 \
--nproc_per_node=$GPUS_PER_NODE \
--nnodes=$SLURM_JOB_NUM_NODES \
--rdzv_id=$SLURM_JOB_ID \
--rdzv_backend=c10d \
Expand All @@ -53,10 +57,16 @@ declare -a TRAINING_ARGS=(
--num_heads=32 \ # 7b: 32 13b: 40 70b: 64
--model_type=llama_v2 \
--tokenizer="hf-internal-testing/llama-tokenizer" \
--checkpoint_freq=50 \
--checkpoint_freq=5000 \
--validation_freq=500 \
--max_steps 5000 \
--checkpoint_dir=./checkpoints \
--resume_from_checkpoint=./checkpoints
--dataset='c4' \
--dataset_config_name='en' \
--resume_from_checkpoint=./checkpoints \
--train_batch_size=1 \
--val_batch_size=1 \
--sharding_strategy="hybrid"
)

srun -l ${TORCHRUN} "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}"
91 changes: 91 additions & 0 deletions 3.test_cases/10.FSDP/2.distributed-training-mistral.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/bin/bash

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0

# Slurm batch script that uses srun to start torchrun on ./train.py with
# --model_type="mixtral" and the "hybrid" FSDP sharding strategy.
# Edit GPUS_PER_NODE, the EFA exports and TRAINING_ARGS below to match the
# cluster and the model size being trained.

#SBATCH --nodes=4 # number of nodes to use
#SBATCH --job-name=FSDP # name of your job
#SBATCH --exclusive # job has exclusive use of the resource, no sharing

# -e: abort on the first failing command; -x: echo each command as it runs.
set -ex;

###########################
###### User Variables #####
###########################

# Forwarded to torchrun as --nproc_per_node (one process per GPU).
GPUS_PER_NODE=8 # 4 for G5.12x, 8 for P4/P5

###########################
## Environment Variables ##
###########################

## Plenty of EFA level variables
## Comment out for non-efa instances (G4dn, P3)
## For G5.12x, comment out only the RDMA and fork-safe exports
## For G4dn and other G5, comment out all
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
export FI_PROVIDER=efa
export NCCL_DEBUG=INFO

###########################
####### Torch Dist #######
###########################

# torchrun launcher options. The node count and rendezvous id come from the
# Slurm job environment; the rendezvous endpoint is the host that expands
# $(hostname), i.e. the node on which this batch script itself executes.
declare -a TORCHRUN_ARGS=(
--nproc_per_node=$GPUS_PER_NODE \
--nnodes=$SLURM_JOB_NUM_NODES \
--rdzv_id=$SLURM_JOB_ID \
--rdzv_backend=c10d \
--rdzv_endpoint=$(hostname) \
)

# Paths are relative to the directory the job runs in.
# NOTE(review): ./pt_fsdp is presumably the environment created during setup - confirm.
export TORCHRUN=./pt_fsdp/bin/torchrun
export TRAIN_SCRIPT=./train.py

############################
# Mixtral Training Params ##
############################

# Arguments forwarded verbatim to $TRAIN_SCRIPT.
# Do not place comments between the entries: every line ends in a
# backslash continuation.
declare -a TRAINING_ARGS=(
--train_batch_size=4 \
--val_batch_size=4 \
--max_steps=5000 \
--seed=42 \
--bf16=1 \
--grad_clip=1.0 \
--weight_decay=0.2 \
--beta1=0.9 \
--beta2=0.95 \
--activation_checkpointing=1 \
--intermediate_size=14336 \
--num_key_value_heads=8 \
--logging_freq=1 \
--max_context_width=32768 \
--vocab_size=32000 \
--hidden_width=4096 \
--num_layers=32 \
--num_heads=32 \
--resid_pdrop=0.1 \
--embd_pdrop=0.1 \
--attn_pdrop=0.1 \
--summary_first_pdrop=0.1 \
--initializer_range=0.02 \
--model_type="mixtral" \
--rotary_pct=0.25 \
--rotary_emb_base=10000 \
--lr=0.0001 \
--lr_decay_style="cosine" \
--min_lr=1e-5 \
--warmup=0.0032 \
--plateau=0.0 \
--dataset="c4" \
--tokenizer="mistralai/Mixtral-8x7B-v0.1" \
--epochs=3 \
--dataset_config_name="en" \
--limit_all_gathers=1 \
--sharding_strategy="hybrid"
)

# -l prefixes every output line with the task label so interleaved logs
# from different tasks can be told apart.
srun -l ${TORCHRUN} "${TORCHRUN_ARGS[@]}" $TRAIN_SCRIPT "${TRAINING_ARGS[@]}"
16 changes: 9 additions & 7 deletions 3.test_cases/10.FSDP/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Get Started Training Llama 2 with PyTorch FSDP in 5 Minutes
# Get Started Training Llama 2 and Mixtral with PyTorch FSDP in 5 Minutes

These scripts provide an easy way to get started with multinode [FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) training on Slurm. It is designed to be as simple as possible, requires no data preparation, and uses a simple Conda environment.

Expand Down Expand Up @@ -36,21 +36,23 @@ If you'd like to instead use your own dataset, you can do so by [formatting it a

## 3. Launch Training

The script to launch a Slurm batch training job can be found in `1.distributed_training.sbatch`. You can adjust the number of training nodes by modifying `#SBATCH --nodes=4`.
The script to launch a Llama 2 Slurm batch training job can be found in `1.distributed-training.sbatch`. The script to launch a Mixtral training job can be found in `2.distributed-training-mistral.sbatch`. You can adjust the number of training nodes by modifying `#SBATCH --nodes=4`.

If you are using a non-EFA enable instance, such as G5, comment out lines 21-25.
If you are using a non-RDMA enabled instance, such as G5.12x, comment out the `FI_EFA_USE_DEVICE_RDMA` and `FI_EFA_FORK_SAFE` exports. These instances have EFA between nodes, but do not have the GPU direct RDMA access of P4d and P5 instances.

```
## Plenty of EFA level variables
## Comment out for non-efa instances (G5, G4d, P3)
# export FI_EFA_USE_DEVICE_RDMA=1 # use for p4d
# export FI_EFA_FORK_SAFE=1
# export FI_LOG_LEVEL=1
# export FI_PROVIDER=efa
# export NCCL_DEBUG=INFO
export FI_LOG_LEVEL=1
export FI_PROVIDER=efa
export NCCL_DEBUG=INFO
```

Also, make sure `--nproc_per_node` to match the number of GPUs on your instance type (8 for P4d/P5, 4 for G5.12xlarge, 1 for G5.xlarge).
If you are using non-EFA enabled instances, such as G4dn, or single GPU G5 nodes, comment out all of the EFA environment variables (the `FI_*` exports) in the `Environment Variables` section.

Also, under `User Variables` make sure to adjust `GPUS_PER_NODE` to match the number of GPUs on your instance type (8 for P4d/P5, 4 for G5.12xlarge, 1 for G5.xlarge).

You can also adjust the training parameters in `TRAINING_ARGS` (for example, to train Llama 2 70b). Additional parameters can be found in `model_utils/arguments.py`. Note that we use the same directory for both `--checkpoint_dir` and `--resume_from_checkpoint`. If there are multiple checkpoints, `--resume_from_checkpoint` will automatically select the most recent one. This way, if our training is interrupted for any reason, it will automatically pick up the most recent checkpoint.

Expand Down
18 changes: 13 additions & 5 deletions 3.test_cases/10.FSDP/model_utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ def parse_args(): # pylint: disable=too-many-statements
help="enable gradient checkpointing to reduce memory consumption",
)
opt_grp.add_argument(
"--llama_intermediate_size",
"--intermediate_size",
type=int,
default=11008,
help="intermediate_size for Llama v2, a dimension associated with MLP",
help="intermediate_size, a dimension associated with MLP",
)
opt_grp.add_argument(
"--num_key_value_heads",
type=int,
default=None,
help="num_key_value_heads for Llama v2",
help="num_key_value_heads for GQA",
)
parser.add_argument(
"--logging_freq", type=int, default=1, help="number of iterations between logging"
Expand Down Expand Up @@ -79,6 +79,13 @@ def parse_args(): # pylint: disable=too-many-statements
fsdp_grp.add_argument("--offload_activations", type=int, default=0)
fsdp_grp.add_argument("--activation_loading_horizon", type=int, default=2)
fsdp_grp.add_argument("--limit_all_gathers", default=1, type=int)
fsdp_grp.add_argument(
"--sharding_strategy",
type=str,
default="full",
choices=["full", "hybrid"],
help="FSDP sharding strategy https://pytorch.org/docs/stable/fsdp.html",
)

# learning rate
lr_grp = parser.add_argument_group(
Expand Down Expand Up @@ -118,7 +125,8 @@ def parse_args(): # pylint: disable=too-many-statements
help="Percentage of total iterations to keep at max if using plateau lr",
)
io_grp = parser.add_argument_group(title="io", description="location for input and output")
io_grp.add_argument("--dataset_path", type=str, default="c4")
io_grp.add_argument("--dataset", type=str, default="c4")
io_grp.add_argument("--dataset_config_name", type=str, default=None)
io_grp.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
io_grp.add_argument(
"--resume_from_checkpoint",
Expand Down Expand Up @@ -155,4 +163,4 @@ def parse_args(): # pylint: disable=too-many-statements
help="number of batches to estimate validation loss",
)

return parser.parse_known_args()
return parser.parse_known_args()
56 changes: 53 additions & 3 deletions 3.test_cases/10.FSDP/model_utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,15 @@
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from datetime import datetime
import tqdm
import logging
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from transformers import AutoTokenizer
from datasets import load_dataset

from model_utils.concat_dataset import ConcatTokensDataset

g_gigabyte = 1024**3

Expand Down Expand Up @@ -126,7 +131,7 @@ def get_model_config(args):
model_config = LlamaConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_width,
intermediate_size=args.llama_intermediate_size,
intermediate_size=args.intermediate_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_heads,
num_key_value_heads=args.num_key_value_heads,
Expand All @@ -139,8 +144,26 @@ def get_model_config(args):
tie_word_embeddings=False,
rope_scaling=None,
)
elif "mixtral" in args.model_type:
from transformers import MixtralConfig
model_config = MixtralConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_width,
intermediate_size=args.intermediate_size,
num_hidden_layers=args.num_layers,
num_attention_heads=args.num_heads,
num_key_value_heads=args.num_key_value_heads,
hidden_act="silu",
max_position_embeddings=args.max_context_width,
initializer_range=args.initializer_range,
rms_norm_eps=1e-5,
use_cache=False,
tie_word_embeddings=False,
num_experts_per_tok=2,
num_local_experts=8,
)
else:
raise NotImplementedError
raise NotImplementedError(f"Model {args.model_type} not implemented")
return model_config

def compute_num_params(model):
Expand Down Expand Up @@ -202,6 +225,15 @@ def get_transformer_layer(model_type="gpt2"):
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

transformer_layer = LlamaDecoderLayer

elif model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

transformer_layer = MixtralDecoderLayer

else:
raise NotImplementedError(f"Model type {model_type} not implemented")

return transformer_layer

def get_sharding_strategy(strategy: str):
Expand Down Expand Up @@ -412,4 +444,22 @@ def get_learning_rate_scheduler(optimizer, args):
override_lr_scheduler=False,
)

return lr_scheduler
return lr_scheduler

def create_streaming_dataloader(dataset,
                                tokenizer,
                                name=None,
                                global_rank=0,
                                batch_size=1,
                                max_context_width=4096,
                                workers=4,
                                split=None):
    """Build a DataLoader over a streamed, tokenized Hugging Face dataset.

    Args:
        dataset: Dataset identifier passed to ``datasets.load_dataset``.
        tokenizer: Tokenizer identifier passed to
            ``AutoTokenizer.from_pretrained``.
        name: Optional dataset configuration name (e.g. ``"en"`` for c4).
        global_rank: Rank of the calling process; added to the base shuffle
            seed so that each rank shuffles with a different seed.
        batch_size: Batch size handed to the DataLoader.
        max_context_width: Sequence length handed to ``ConcatTokensDataset``.
        workers: Number of DataLoader worker processes. ``0`` loads data in
            the calling process.
        split: Dataset split forwarded to ``load_dataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the ``ConcatTokensDataset``.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(tokenizer)
    # streaming=True avoids downloading the full dataset up front.
    stream = load_dataset(dataset, name=name, streaming=True, split=split)
    stream = stream.shuffle(42 + global_rank)
    # NOTE(review): the trailing True is presumably ConcatTokensDataset's
    # "wrap" flag - confirm against model_utils/concat_dataset.py.
    concat_dataset = ConcatTokensDataset(stream, hf_tokenizer, max_context_width, True)

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": workers,
        "pin_memory": True,
    }
    # DataLoader raises ValueError when prefetch_factor is given together with
    # num_workers=0, so only request prefetching when worker processes exist.
    if workers > 0:
        loader_kwargs["prefetch_factor"] = 4
    return DataLoader(concat_dataset, **loader_kwargs)
Loading

0 comments on commit f5e0b7b

Please sign in to comment.