Merge branch 'main' into refactor-tp
mayank31398 committed Aug 13, 2024
2 parents 7d0814e + 722351c commit 4e6f95f
Showing 60 changed files with 1,445 additions and 768 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/unit-tests.yml
@@ -23,7 +23,9 @@ jobs:
python-version: 3.11

- name: Installation
run: make install-dev
run: |
make install-dev
git clone -b granitemoe https://github.com/mayank31398/transformers && cd transformers && pip install . && cd ..
- name: Unit Tests
run: make test
5 changes: 5 additions & 0 deletions README.md
@@ -122,6 +122,11 @@
sh scripts/generate.sh configs/sst2/inference.yml
sh scripts/unshard.sh configs/sst2/unshard.yml
```

## Running basic inference

For a simple HuggingFace inference example, refer to [tools/inference.py](tools/inference.py).
For an example running tensor parallel inference, refer to [tools/tensor_parallel_inference.py](tools/tensor_parallel_inference.py).

## Using custom datasets
The data directory should obey the following structure:
```text
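For orientation, a basic HuggingFace inference script along the lines of what [tools/inference.py](tools/inference.py) points to could look like the sketch below; the checkpoint path and prompt are placeholders, not taken from the file:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "path/to/unsharded-checkpoint"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)

inputs = tokenizer("Hello, how are you?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```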
5 changes: 2 additions & 3 deletions configs/research/cross-layer-attention/base.yml
@@ -39,15 +39,14 @@ model_args:
pretrained_config:
activation_function: swiglu
add_bias: true
apply_residual_connection_post_layernorm: false
attention_softmax_in_fp32: true
attn_pdrop: 0
bos_token_id: 0
embd_pdrop: 0
eos_token_id: 0
initializer_range: 0.02
layer_norm_epsilon: 1e-05
model_type: gpt_megatron
model_type: gpt_dolomite
n_embd: 3072
n_head: 12
n_inner: 8192
@@ -86,7 +85,7 @@ training_parameters:
gradient_accumulation_steps: 4

optimizer_args:
class_name: ApexFusedAdam
class_name: TorchAdamW
class_args:
lr: 3e-4
weight_decay: 0.1
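The optimizer change swaps the Apex fused Adam for the PyTorch implementation. Assuming the engine's `TorchAdamW` class wraps `torch.optim.AdamW`, the new `optimizer_args` block roughly corresponds to:

```python
import torch

model = torch.nn.Linear(16, 16)  # toy stand-in for the actual model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
```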
3 changes: 1 addition & 2 deletions configs/research/cross-layer-attention/cla.yml
@@ -39,7 +39,6 @@ model_args:
pretrained_config:
activation_function: swiglu
add_bias: true
apply_residual_connection_post_layernorm: false
attention_softmax_in_fp32: true
attn_pdrop: 0
bos_token_id: 0
@@ -119,7 +118,7 @@ training_parameters:
gradient_accumulation_steps: 4

optimizer_args:
class_name: ApexFusedAdam
class_name: TorchAdamW
class_args:
lr: 3e-4
weight_decay: 0.1
15 changes: 6 additions & 9 deletions dolomite_engine/arguments.py
@@ -58,8 +58,6 @@ class ModelArgs(BaseArgs):
reset_attention_mask: bool = False
# whether to reset position ids for pretraining
reset_position_ids: bool = False
# whether to upcast logits for loss
upcast_logits_for_loss: bool = False

def model_post_init(self, __context: Any) -> None:
_check_not_None([(self.model_class, "model_class")])
@@ -77,11 +75,6 @@ def model_post_init(self, __context: Any) -> None:

self.model_class: AutoModelForCausalLM | AutoModelForSeq2SeqLM = getattr(transformers, self.model_class)

if self.pretrained_config is not None:
assert self.upcast_logits_for_loss == self.pretrained_config.get(
"upcast_logits_for_loss", False
), "`upcast_logits_for_loss` should match in the model pretrained_config and the model_args"


class PromptTuningArgs(BaseArgs):
# prompt tuning init method
@@ -197,6 +190,9 @@ class LoadArgs(BaseArgs):
load_experiments_tracker_state: bool = True
# whether to load starting iteration
load_starting_iteration: bool = True
# whether to resume learning rate during training
# this is a NO-OP if we are loading LR scheduler
resume_learning_rate: bool = True

def model_post_init(self, __context: Any) -> None:
_check_not_None([(self.load_path, "load_path")])
@@ -206,6 +202,9 @@ def model_post_init(self, __context: Any) -> None:
not self.load_lr_scheduler
), "lr_scheduler loading doesn't make sense if you aren't loading optimizer"

if self.load_lr_scheduler:
assert self.resume_learning_rate, "resume learning rate needs to be True when reloading LR scheduler"


class DatasetArgs(BaseArgs):
# dataset class
@@ -381,8 +380,6 @@ class LoggingArgs(BaseArgs):
logging_level: str = "INFO"
# log interval
log_interval: int = 1
# running mean window
running_mean_window: int = 10
# arguments if using aim
aim_args: AimArgs | None = None
# arguments if using wandb
58 changes: 47 additions & 11 deletions dolomite_engine/checkpointing.py
@@ -8,6 +8,7 @@
import torch.distributed
import torch.distributed.checkpoint as dcp
import yaml
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import _CHECKPOINT_WRAPPED_MODULE
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict import (
@@ -30,6 +31,7 @@
from .enums import DistributedBackend, Mode, TuningMethod
from .hf_models.models.gpt_dolomite_TP import fix_unsharded_state_dict
from .model_wrapper import ModelWrapper, get_model
from .optimization import get_scheduler
from .utils import (
ExperimentsTracker,
ProcessGroupManager,
@@ -93,12 +95,12 @@ def save_checkpoint(
):
model_state_dict = model.state_dict()
if dp_rank == 0:
torch.save(model_state_dict, _get_model_path(save_path))
torch.save(model_state_dict, f"{_get_model_path(save_path)}.pt")

if save_optimizer:
optimizer_state_dict = FSDP.optim_state_dict(model=model, optim=optimizer)
if dp_rank == 0:
torch.save(optimizer_state_dict, _get_optimizer_path(save_path))
torch.save(optimizer_state_dict, f"{_get_optimizer_path(save_path)}.pt")
else:
dcp.save(get_model_state_dict(model), checkpoint_id=_get_model_path(save_path))

@@ -189,10 +191,6 @@ def load_checkpoint_for_training(
log_rank_0(logging.INFO, f"loading checkpoint saved at {load_path}")

if distributed_backend == DistributedBackend.deepspeed:
from deepspeed import DeepSpeedEngine

assert isinstance(model, DeepSpeedEngine)

model.load_checkpoint(
args.load_args.load_path,
tag=_get_checkpoint_tag(iteration),
@@ -201,21 +199,21 @@
)
elif distributed_backend == DistributedBackend.torch:
if args.distributed_args.fsdp_algorithm == 1:
assert isinstance(model, FSDP)

# TODO add support for local state dict
with FSDP.state_dict_type(
model,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
model.load_state_dict(torch.load(_get_model_path(load_path)))
model.load_state_dict(torch.load(f"{_get_model_path(load_path)}.pt", map_location="cpu"))

if load_optimizer:
optimizer.load_state_dict(
FSDP.optim_state_dict_to_load(
model=model, optim=optimizer, optim_state_dict=torch.load(_get_optimizer_path(load_path))
model=model,
optim=optimizer,
optim_state_dict=torch.load(f"{_get_optimizer_path(load_path)}.pt", map_location="cpu"),
)
)
else:
@@ -232,7 +230,12 @@
del optimizer_state_dict

if load_lr_scheduler:
assert load_optimizer, "load_lr_scheduler requires loading of optimizer"

lr_scheduler.load_state_dict(torch.load(_get_lr_scheduler_path(load_path)))
else:
if args.load_args.resume_learning_rate:
_resume_learning_rate(args, optimizer=optimizer, lr_scheduler=lr_scheduler, iteration=iteration)
else:
raise ValueError(f"unexpected distributed_backend ({distributed_backend})")

@@ -315,7 +318,7 @@ def load_checkpoint_for_inference(
for key in list(state.keys()):
state[key] = state[key].to(dtype)
# fix for gradient checkpointing
state[key.replace("._checkpoint_wrapped_module", "")] = state.pop(key)
state[key.replace(f".{_CHECKPOINT_WRAPPED_MODULE}", "")] = state.pop(key)

strict = True

@@ -388,6 +391,10 @@ def load_checkpoint_for_inference(
if use_meta:
model = model.to_empty(device="cpu")

dtype = string_to_torch_dtype(model.dtype)
for key in list(state.keys()):
state[key] = state[key].to(dtype)

model.load_state_dict(state)
else:
raise ValueError(f"unexpected distributed_backend ({args['distributed_args']['distributed_backend']})")
@@ -409,6 +416,35 @@ def save_args(args: TrainingArgs | InferenceArgs, save_path: str, mode: Mode) ->
yaml.dump(args.to_dict(), open(save_path, "w"), indent=2)


def _resume_learning_rate(
args: TrainingArgs, optimizer: Optimizer, lr_scheduler: LambdaLR, iteration: int | None = None
) -> None:
initial_lr = []
for grp in optimizer.param_groups:
initial_lr.append(grp["initial_lr"])
grp["initial_lr"] = grp["lr"]

# we create lr scheduler again here since optimizer is loaded from disk and lr scheduler is now out of sync
# this helps to resume phase 2
lr_scheduler_tmp = get_scheduler(
optimizer=optimizer,
num_warmup_steps=args.lr_scheduler_args.num_warmup_steps,
num_constant_steps=args.lr_scheduler_args.num_constant_steps,
num_decay_steps=args.lr_scheduler_args.num_decay_steps,
num_training_steps=args.training_parameters.num_training_steps,
lr_decay_style=args.lr_scheduler_args.lr_decay_style,
lr_decay_factor=args.lr_scheduler_args.lr_decay_factor,
extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args,
last_epoch=-1 if iteration is None else iteration - 1,
)

for grp, lr_ in zip(optimizer.param_groups, initial_lr):
grp["initial_lr"] = lr_

lr_scheduler.load_state_dict(lr_scheduler_tmp.state_dict())
del lr_scheduler_tmp


def _get_checkpoint_tag(iteration: int) -> str:
return f"global_step{iteration}"

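The new `_resume_learning_rate` helper rebuilds the scheduler with `last_epoch=iteration - 1` so the learning rate is fast-forwarded to the resume point even when the scheduler state itself is not loaded. A standalone sketch of that fast-forward trick, using a toy model and a hypothetical warmup schedule rather than the engine's `get_scheduler`:

```python
import torch

model = torch.nn.Linear(8, 8)  # toy stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


def lr_lambda(step: int) -> float:
    # hypothetical schedule: linear warmup over 100 steps, then constant
    return min(1.0, (step + 1) / 100)


# LambdaLR built with last_epoch != -1 expects "initial_lr" in every param group;
# a scheduler that ran before the restart would normally have written it already
for group in optimizer.param_groups:
    group.setdefault("initial_lr", group["lr"])

# resume at iteration 500: the freshly built scheduler immediately reports the
# learning rate for step 500 instead of starting the warmup from scratch
resumed = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=499)
print(resumed.get_last_lr())
```

`_resume_learning_rate` applies the same idea and then copies the temporary scheduler's state into the live one via `load_state_dict`.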
4 changes: 3 additions & 1 deletion dolomite_engine/data/megatron/indexed_dataset.py
@@ -5,6 +5,8 @@

# Essentially re-written in entirety

from __future__ import annotations

import logging
import os
import shutil
@@ -109,7 +111,7 @@ def __init__(self, idx_path: str, dtype: type[numpy.number]) -> None:
self.idx_path = idx_path
self.dtype = dtype

def __enter__(self) -> "_IndexWriter":
def __enter__(self) -> _IndexWriter:
"""Enter the context introduced by the 'with' keyword
Returns:
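The `from __future__ import annotations` import is what allows the quotes to be dropped from the `__enter__` return annotation: with postponed evaluation, a class can name itself in an annotation before the class body has finished executing. A minimal illustration with a hypothetical class:

```python
from __future__ import annotations


class _Writer:  # hypothetical stand-in for _IndexWriter
    def __enter__(self) -> _Writer:  # no string quotes needed; the annotation is evaluated lazily
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        pass
```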
7 changes: 6 additions & 1 deletion dolomite_engine/data/utils.py
@@ -77,10 +77,15 @@ def collate_fn(
if not use_padding_free_transformer:
input_ids = torch.tensor(input_ids)
attention_mask = torch.tensor(attention_mask)

if labels is not None:
labels = torch.tensor(labels)

result = {"input_ids": input_ids, "attention_mask": attention_mask}
result = {"input_ids": input_ids}

if not use_padding_free_transformer:
result["attention_mask"] = attention_mask

if mode == Mode.training:
result["labels"] = labels

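The `collate_fn` change stops building and returning an `attention_mask` when a padding-free transformer is used, since packed sequences carry their own length information. A simplified sketch of that behaviour (not the engine's actual `collate_fn` signature):

```python
import torch


def collate(batch: list[list[int]], use_padding_free_transformer: bool) -> dict:
    if use_padding_free_transformer:
        # packed mode: keep ragged token lists; no padding, hence no attention_mask
        return {"input_ids": [torch.tensor(ids) for ids in batch]}

    # padded mode: pad to the longest sequence and mask the padding positions
    max_len = max(len(ids) for ids in batch)
    input_ids = torch.tensor([ids + [0] * (max_len - len(ids)) for ids in batch])
    attention_mask = torch.tensor([[1] * len(ids) + [0] * (max_len - len(ids)) for ids in batch])
    return {"input_ids": input_ids, "attention_mask": attention_mask}


print(collate([[5, 6, 7], [8, 9]], use_padding_free_transformer=False)["attention_mask"])
```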