Fix finetuning #101

Merged 24 commits on Jan 8, 2025
2 changes: 1 addition & 1 deletion dolomite_engine/checkpointing.py
@@ -39,7 +39,6 @@


def ensure_last_checkpoint_is_saved() -> None:
global _FUTURE
if _FUTURE is not None:
_FUTURE.result()

@@ -194,6 +193,7 @@ def save_checkpoint(

save_args(args, save_path, mode=Mode.training)

global _FUTURE
_FUTURE = dcp.async_save(
{
"state": _Saver(
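
A side note on the checkpointing change above: Python only requires a global declaration in the function that assigns a module-level name; a function that only reads it (like ensure_last_checkpoint_is_saved) does not need one. Below is a minimal, self-contained sketch of the same async-checkpoint pattern; fake_async_save and _EXECUTOR are hypothetical stand-ins so the snippet runs without torch.distributed.checkpoint.

from concurrent.futures import Future, ThreadPoolExecutor

_FUTURE: Future | None = None
_EXECUTOR = ThreadPoolExecutor(max_workers=1)


def fake_async_save(state: dict) -> Future:
    # hypothetical stand-in for dcp.async_save: starts the save and returns a Future
    return _EXECUTOR.submit(lambda: len(state))


def save_checkpoint(state: dict) -> None:
    global _FUTURE  # assignment to the module-level name needs the declaration here
    _FUTURE = fake_async_save(state)


def ensure_last_checkpoint_is_saved() -> None:
    # reading _FUTURE requires no global declaration
    if _FUTURE is not None:
        _FUTURE.result()


save_checkpoint({"step": 1})
ensure_last_checkpoint_is_saved()
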
157 changes: 135 additions & 22 deletions dolomite_engine/finetune.py
@@ -1,8 +1,9 @@
from contextlib import nullcontext
from contextlib import AbstractContextManager, nullcontext

import torch
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torch.distributed.tensor.parallel import loss_parallel
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR
from transformers import set_seed

from .arguments import TrainingArgs, get_args
@@ -11,16 +12,127 @@
from .data import ResumableDataLoader, custom_iterator, get_finetuning_dataloader, get_next_batch
from .distributed import dtensor_to_tensor, wrap_model_container_for_distributed_training
from .enums import DatasetSplit, Mode, TuningMethod
from .model_wrapper import get_model_container
from .model_wrapper import ModelWrapper, get_model_container
from .optimization import get_optimizer_container, get_scheduler_container
from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics, train_step
from .utils import ExperimentsTracker, MetricsTrackingDict, ProcessGroupManager, init_distributed, setup_tf32
from .train_utils import all_reduce_metrics_tracker, get_torch_profiler, track_metrics
from .utils import (
ExperimentsTracker,
MetricsTrackingDict,
ProcessGroupManager,
StepTracker,
init_distributed,
is_torchao_available,
setup_tf32,
)


if is_torchao_available():
from .distributed import FP8Manager


def train_step_without_pipeline_parallel(
model: ModelWrapper,
optimizer: Optimizer,
lr_scheduler: LambdaLR,
train_dataloader: ResumableDataLoader,
gradient_clipping: float,
forward_context: AbstractContextManager,
backward_context: AbstractContextManager,
sync_every_gradient_accumulation_step: bool,
) -> MetricsTrackingDict:
"""runs backpropagation and applies the gradient if at the edge of gradient accumulation boundary

Args:
model (ModelWrapper): model
optimizer (Optimizer): optimizer
lr_scheduler (LambdaLR): learning rate scheduler
train_dataloader (ResumableDataLoader): training dataloader
gradient_clipping (float): gradient clipping value
forward_context (AbstractContextManager): a context that is used for every model forward call
backward_context (AbstractContextManager): a context that is used for every model backward call
sync_every_gradient_accumulation_step (bool): whether to sync on every gradient accumulation step

Returns:
MetricsTrackingDict: metrics to track
"""

fsdp_algorithm = 2 if hasattr(model, "set_requires_gradient_sync") else 1

no_sync = nullcontext
if not sync_every_gradient_accumulation_step:
if fsdp_algorithm == 1:
no_sync = model.no_sync
else:
model.set_requires_gradient_sync(False)

metrics_tracker = MetricsTrackingDict({})
grad_norm = None
optimizer.zero_grad()

gradient_accumulation_steps = StepTracker.get_gradient_accumulation_steps()

# note: no separate division by gradient_accumulation_steps is needed since lm_loss_multiplier already normalizes over the whole accumulation window
batches = [get_next_batch(train_dataloader) for _ in range(gradient_accumulation_steps)]
lm_loss_multiplier = 1 / sum([(batch["labels"] != -100).sum() for batch in batches])

with no_sync():
for batch in batches[:-1]:
with forward_context():
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

# compute gradients
with backward_context():
loss_micro_step_dict["loss"].backward()

with torch.inference_mode():
metrics_tracker = metrics_tracker + loss_micro_step_dict

if fsdp_algorithm == 2:
model.set_requires_gradient_sync(True)

batch = batches[-1]
with forward_context():
loss_micro_step_dict = model(batch, lm_loss_multiplier=lm_loss_multiplier)

# compute gradients
with backward_context():
loss_micro_step_dict["loss"].backward()

with torch.inference_mode():
metrics_tracker = metrics_tracker + loss_micro_step_dict

if gradient_clipping is not None:
if fsdp_algorithm == 1:
grad_norm = model.clip_grad_norm_(gradient_clipping)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)

if is_torchao_available():
FP8Manager.sync_float8_amax_and_scale_history([model])

optimizer.step()
lr_scheduler.step()

if is_torchao_available():
FP8Manager.precompute_float8_dynamic_scale_for_fsdp([model])

with torch.inference_mode():
metrics_tracker["grad_norm"] = (
torch.tensor(0, device=torch.cuda.current_device()) if grad_norm is None else grad_norm
)

for key in metrics_tracker:
metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key])

metrics_tracker = all_reduce_metrics_tracker(metrics_tracker)

return metrics_tracker
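
Since the finetuning wrapper now returns a sum-reduced LM loss (see model_wrapper/finetuning.py further down) and lm_loss_multiplier is one over the number of non-ignored label tokens in the whole accumulation window, the accumulated loss equals the true per-token mean even when tokens are spread unevenly across micro-batches. A small self-contained sketch of that identity; the labels and per-token losses below are made-up values, not repo code.

import torch

# made-up per-token losses and labels for two micro-batches; -100 marks ignored positions
labels_a = torch.tensor([1, 2, -100, 3])
labels_b = torch.tensor([-100, 4, 5, -100])
loss_a = torch.tensor([0.5, 1.0, 0.0, 1.5])  # loss is 0 at ignored positions
loss_b = torch.tensor([0.0, 2.0, 1.0, 0.0])

# multiplier computed over the whole accumulation window, as in the train step above
lm_loss_multiplier = 1 / ((labels_a != -100).sum() + (labels_b != -100).sum())

# accumulating sum-reduced micro-batch losses scaled by the multiplier ...
accumulated = loss_a.sum() * lm_loss_multiplier + loss_b.sum() * lm_loss_multiplier

# ... matches the mean over every supervised token in the window
reference = torch.cat([loss_a[labels_a != -100], loss_b[labels_b != -100]]).mean()
assert torch.isclose(accumulated, reference)
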


def train(
args: TrainingArgs,
model_container: ModelContainer,
pipeline_schedule: _PipelineSchedule,
optimizer_container: OptimizerContainer,
lr_scheduler_container: LRSchedulerContainer,
train_dataloader: ResumableDataLoader,
@@ -43,7 +155,6 @@ def train(
"""

num_training_steps = args.training_parameters.num_training_steps
gradient_accumulation_steps = args.training_parameters.gradient_accumulation_steps
gradient_clipping = args.training_parameters.gradient_clipping

eval_during_training = args.training_parameters.eval_during_training
@@ -73,20 +184,15 @@
while global_step < num_training_steps:
global_step += 1

loss_step_dict = train_step(
model_container=model_container,
pipeline_schedule=pipeline_schedule,
optimizer_container=optimizer_container,
lr_scheduler_container=lr_scheduler_container,
loss_step_dict = train_step_without_pipeline_parallel(
model=model_container[0],
optimizer=optimizer_container[0],
lr_scheduler=lr_scheduler_container[0],
train_dataloader=train_dataloader_infinite,
gradient_accumulation_steps=gradient_accumulation_steps,
gradient_clipping=gradient_clipping,
forward_context=forward_context,
backward_context=backward_context,
sync_every_gradient_accumulation_step=args.distributed_args.sync_every_gradient_accumulation_step,
is_pipeline_parallel_enabled=args.distributed_args.num_pipeline_stages > 1,
local_batch_size=None,
sequence_length=None,
)

metrics_tracker = metrics_tracker + loss_step_dict
@@ -124,7 +230,7 @@ def train(
ensure_last_checkpoint_is_saved()

if torch_profiler is not None:
torch_profiler.__exit__()
torch_profiler.__exit__(None, None, None)


@torch.no_grad()
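
The profiler fix above is just the context-manager protocol: __exit__ is defined as __exit__(exc_type, exc_value, traceback), so calling it by hand has to pass all three arguments (None, None, None when no exception occurred). A toy illustration with a stand-in class, not the actual torch profiler:

class Profiler:
    # toy context manager standing in for the torch profiler object
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        print("profiler closed")


profiler = Profiler().__enter__()
# a manual call must supply the three protocol arguments
profiler.__exit__(None, None, None)
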
@@ -165,13 +271,15 @@ def evaluate(

metrics_tracker = MetricsTrackingDict({})
val_dataloader = custom_iterator(val_dataloader, infinite=False)
loss_tokens = 0

for _ in range(num_steps):
batch = get_next_batch(val_dataloader)
loss_tokens += (batch["labels"] != -100).sum()
loss_step_dict = model_container[0](batch)
metrics_tracker = metrics_tracker + loss_step_dict

metrics_tracker = metrics_tracker / num_steps
metrics_tracker = metrics_tracker / loss_tokens.item()

for key in metrics_tracker:
metrics_tracker[key] = dtensor_to_tensor(metrics_tracker[key])
@@ -193,8 +301,6 @@ def evaluate(
def main() -> None:
"""main program"""

assert False

mode = Mode.training

setup_tf32()
@@ -217,6 +323,11 @@ def main() -> None:
use_async_tensor_parallel=args.distributed_args.use_async_tensor_parallel,
)

StepTracker(
micro_batch_size=args.training_parameters.micro_batch_size,
gradient_accumulation_steps=args.training_parameters.gradient_accumulation_steps,
)

set_seed(args.random_args.seed)

assert args.distributed_args.num_pipeline_stages == 1, "pipeline parallel is not supported with finetuning"
@@ -241,7 +352,7 @@
is_encoder_decoder=model_container[0].is_encoder_decoder,
)

model_container, pipeline_schedule = wrap_model_container_for_distributed_training(args, model_container)
model_container, _ = wrap_model_container_for_distributed_training(args, model_container)

optimizer_container = get_optimizer_container(
optimizer_class_name=args.optimizer_args.class_name,
@@ -261,6 +372,9 @@
extra_lr_scheduler_args=args.lr_scheduler_args.extra_lr_scheduler_args,
)

assert len(model_container) == len(optimizer_container)
assert len(optimizer_container) == len(lr_scheduler_container)

log_model_optimizer_container(model_container, optimizer_container)

starting_iteration = 0
@@ -283,7 +397,6 @@
train(
args,
model_container=model_container,
pipeline_schedule=pipeline_schedule,
optimizer_container=optimizer_container,
lr_scheduler_container=lr_scheduler_container,
train_dataloader=train_dataloader,
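
StepTracker is constructed once in main() and queried statically inside the train step via get_gradient_accumulation_steps(). The sketch below shows how such a tracker could work; it is an illustration based only on the calls visible in this diff, not the repo's actual implementation.

class StepTracker:
    # hypothetical minimal version: the constructor stashes the values class-wide so any
    # module can query them without threading the arguments through every call
    _micro_batch_size: int | None = None
    _gradient_accumulation_steps: int | None = None

    def __init__(self, micro_batch_size: int, gradient_accumulation_steps: int) -> None:
        StepTracker._micro_batch_size = micro_batch_size
        StepTracker._gradient_accumulation_steps = gradient_accumulation_steps

    @classmethod
    def get_gradient_accumulation_steps(cls) -> int:
        assert cls._gradient_accumulation_steps is not None, "StepTracker was never initialized"
        return cls._gradient_accumulation_steps


StepTracker(micro_batch_size=4, gradient_accumulation_steps=8)
print(StepTracker.get_gradient_accumulation_steps())  # 8
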
37 changes: 32 additions & 5 deletions dolomite_engine/model_wrapper/finetuning.py
@@ -1,14 +1,16 @@
import torch
import torch.distributed
from torch.distributed._tensor.placement_types import Replicate

from ..communication import Communication
from ..distributed import tensor_to_dtensor
from ..hf_models import get_autoregressive_language_modeling_loss
from ..utils import MetricsTrackingDict, ProcessGroupManager
from .base import ModelWrapper


class ModelWrapperForFinetuning(ModelWrapper):
def forward(self, batch: dict) -> MetricsTrackingDict:
def forward(self, batch: dict, lm_loss_multiplier: float = 1) -> MetricsTrackingDict:
"""forward function for a batch

Args:
@@ -25,17 +27,42 @@ def forward(self, batch: dict) -> MetricsTrackingDict:

model_outputs = self.model(**batch)

loss = get_autoregressive_language_modeling_loss(
lm_logits=model_outputs.logits,
return self.get_loss(
model_outputs=model_outputs,
labels=labels,
upcast_logits_for_loss=self.upcast_logits_for_loss,
cu_seqlens=batch.get("cu_seqlens", None),
lm_loss_multiplier=lm_loss_multiplier,
)

def get_loss(
self, model_outputs, labels: torch.Tensor, cu_seqlens: torch.Tensor | None, lm_loss_multiplier: float = 1
) -> torch.Tensor | dict:
logits: torch.Tensor = model_outputs.logits
aux_loss = model_outputs.aux_loss if hasattr(model_outputs, "aux_loss") else None

lm_loss = get_autoregressive_language_modeling_loss(
lm_logits=logits,
labels=labels,
upcast_logits_for_loss=self.upcast_logits_for_loss,
cu_seqlens=cu_seqlens,
use_padding_free_transformer=self.use_padding_free_transformer,
reduction="sum",
tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
)

return MetricsTrackingDict({"loss": loss})
lm_loss = lm_loss * lm_loss_multiplier

if aux_loss is None:
loss = lm_loss
output = {"loss": loss}
else:
if ProcessGroupManager.is_tensor_parallel_enabled():
aux_loss = tensor_to_dtensor(aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate())

loss = lm_loss + self.router_aux_loss_coef * aux_loss
output = {"loss": loss, "lm_loss": lm_loss, "aux_loss": aux_loss}

return output

def _broadcast_inputs_for_tensor_parallel(self, batch: dict) -> dict:
device = torch.cuda.current_device()
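
The new get_loss rescales a sum-reduced LM loss by lm_loss_multiplier and, for MoE models, folds in the weighted router auxiliary loss. A standalone sketch of just that arithmetic; combine_losses and the numbers are hypothetical, and the tensor-parallel conversion via tensor_to_dtensor is omitted.

import torch


def combine_losses(
    lm_loss_sum: torch.Tensor,
    lm_loss_multiplier: float,
    aux_loss: torch.Tensor | None,
    router_aux_loss_coef: float,
) -> dict:
    # rescale the sum-reduced LM loss (here to a per-token mean), then add the
    # weighted auxiliary loss when the model produced one
    lm_loss = lm_loss_sum * lm_loss_multiplier
    if aux_loss is None:
        return {"loss": lm_loss}
    return {"loss": lm_loss + router_aux_loss_coef * aux_loss, "lm_loss": lm_loss, "aux_loss": aux_loss}


print(combine_losses(torch.tensor(6.0), 1 / 5, torch.tensor(0.2), 0.01))
# loss is about 1.202, lm_loss 1.2, aux_loss 0.2
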
2 changes: 1 addition & 1 deletion dolomite_engine/model_wrapper/pretraining.py
@@ -120,7 +120,7 @@ def forward(self, batch: dict, prev_aux_loss: torch.Tensor | None = None, lm_los

return output

def get_loss(self, model_outputs, labels: torch.Tensor, lm_loss_multiplier: float = 1) -> torch.Tensor:
def get_loss(self, model_outputs, labels: torch.Tensor, lm_loss_multiplier: float = 1) -> torch.Tensor | dict:
if isinstance(model_outputs, torch.Tensor):
logits = model_outputs
aux_loss = None