Merge branch 'fix-rerun-checkpoint' into 'main'
Fix checkpointing of rerun state machine

See merge request ADLR/megatron-lm!2444
deepakn94 committed Dec 14, 2024
2 parents: 3f5d5d4 + be8534a · commit 71c394b
Showing 3 changed files with 108 additions and 85 deletions.
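For context, this commit renames the rerun state machine's checkpoint hooks from get_checkpoint_state/set_checkpoint_state to state_dict/load_state_dict and teaches them to emit either a plain per-rank dict or a ShardedObject for distributed checkpoints. A minimal sketch of the resulting save/restore flow, adapted from the docstrings in the diff below (the checkpoint dict layout and the two wrapper function names are illustrative, not the exact Megatron-LM call sites):

    from megatron.core.rerun_state_machine import get_rerun_state_machine

    def save_my_model_checkpoint(train_data_iterator, use_dist_ckpt=False):
        checkpoint = {}
        rerun_state_machine = get_rerun_state_machine()
        # New API: each rank contributes its own state dict (or ShardedObject),
        # replacing the old all-gathered list from get_checkpoint_state().
        checkpoint['rerun_state_machine'] = rerun_state_machine.state_dict(
            train_data_iterator, use_dist_ckpt
        )
        return checkpoint

    def load_my_model_checkpoint(checkpoint):
        if 'rerun_state_machine' in checkpoint:
            rerun_state_machine = get_rerun_state_machine()
            rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])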
156 changes: 80 additions & 76 deletions megatron/core/rerun_state_machine.py
@@ -12,6 +12,9 @@
import numpy as np
import torch

import megatron.core.parallel_state as mpu
from megatron.core.dist_checkpointing.mapping import ShardedObject

"""DISCLAIMER: THIS IS AN EXPERIMENTAL FEATURE.
The rerun state machine implementation in this file is alpha-level code to help
@@ -34,6 +37,7 @@
EXIT_CODE_FAILED_ON_RESULT_VALIDATION: int = 17

SerializableStateType = Union[list, dict]
DataIteratorArgType = Optional[Union["RerunDataIterator", list["RerunDataIterator"]]]


class Caller(NamedTuple):
@@ -203,22 +207,22 @@ def __init__(

self.saved_results: dict[Call, Any] = {}
self.stats: dict[Caller, QuickStats] = defaultdict(lambda: QuickStats())
logger.warning(f"RerunStateMachine initialized in mode {mode}")
if _safe_get_rank() == 0:
logger.warning(f"RerunStateMachine initialized in mode {mode}")

def set_mode(self, mode: RerunMode) -> None:
"""Method to set the operating mode"""

logger.warning(f"Setting RerunStateMachine mode {mode}")
if _safe_get_rank() == 0:
logger.warning(f"Setting RerunStateMachine mode {mode}")
self.mode = mode

def get_mode(self) -> RerunMode:
"""Method to get the operating mode"""

return self.mode

def should_run_forward_backward(
self, data_iterator: Optional[Union["RerunDataIterator", list]]
) -> bool:
def should_run_forward_backward(self, data_iterator: DataIteratorArgType) -> bool:
"""Method instructing whether to (re)run the forward-backward pass.
Args:
@@ -243,16 +247,7 @@ def train_step(data_iterator, ...):

self.validation_counts = defaultdict(int)

data_iterators: list[RerunDataIterator] = []
if self.mode != RerunMode.DISABLED and data_iterator is not None:
if not isinstance(data_iterator, list):
data_iterators = [data_iterator]
else:
data_iterators = data_iterator
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
), "data iterator is not wrapped with RerunDataIterator"
data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)

# Are we about to start the initial run?
if self.state == RerunState.NOT_RUNNING_YET:
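As the assertion above notes, should_run_forward_backward expects every iterator it receives to already be wrapped in RerunDataIterator. A minimal sketch of the expected call pattern, assuming RerunDataIterator can wrap a plain Python iterator directly (the wrapping is normally done by Megatron's training loop, and the forward_backward_func parameter here is an illustrative stand-in):

    from megatron.core.rerun_state_machine import (
        RerunDataIterator,
        get_rerun_state_machine,
    )

    def train_step(data_iterator, forward_backward_func):
        rerun_state_machine = get_rerun_state_machine()
        # Hypothetical wrapping step; assumes RerunDataIterator(iterator) is valid.
        if not isinstance(data_iterator, RerunDataIterator):
            data_iterator = RerunDataIterator(data_iterator)
        loss = None
        # Re-run the forward/backward pass as many times as the state machine asks.
        while rerun_state_machine.should_run_forward_backward(data_iterator):
            loss = forward_backward_func(data_iterator)
        return loss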
@@ -264,7 +259,7 @@ def train_step(data_iterator, ...):
len(self.data_iterator_checkpoints) == len(data_iterators),
), "data iterator has different length than checkpointed data iterator"
for i, d in enumerate(data_iterators):
d.set_checkpoint_state(self.data_iterator_checkpoints[i])
d.load_state_dict(self.data_iterator_checkpoints[i])
self.data_iterator_checkpoints = None
self._save_state()
if data_iterators:
@@ -630,17 +625,15 @@ def train_step(data_iterator, ...):
self.last_loss = loss
return result

def get_checkpoint_state(
self, data_iterator: Optional[Union["RerunDataIterator", list]]
) -> list[dict[str, Any]]:
def state_dict(self, data_iterator: DataIteratorArgType, use_dist_ckpt: bool) -> dict[str, Any]:
"""Method that returns a state dict to be checkpointed.
Args:
data_iterator: the data iterator that needs to be checkpointed (or None
if this checkpoint is not requested by the rerun state machine).
use_dist_ckpt: generate a distributed checkpoint.
Returns:
A list of state dicts, each state dict representing the rerun state machine
for one rank.
A state dict representing the rerun state machine.
Example usage:
@@ -649,25 +642,15 @@ def save_my_model_checkpoint(data_iterator, ...):
...
rerun_state_machine = get_rerun_state_machine()
checkpoint['rerun_state_machine'] = (
rerun_state_machine.get_checkpoint_state(data_iterator)
rerun_state_machine.state_dict(data_iterator, False)
)
...
return checkpoint
"""

data_iterators: list[RerunDataIterator]
if self.mode == RerunMode.DISABLED:
data_iterators = []
elif isinstance(data_iterator, (list, tuple)):
data_iterators = data_iterator
else:
data_iterators = [data_iterator] if data_iterator is not None else []
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
), "data iterator is not wrapped with RerunDataIterator"
data_iterators: list[RerunDataIterator] = self._sanitize_data_iterators(data_iterator)

state: dict[str, Any] = {
state_dict: dict[str, Any] = {
'mode': self.mode,
'state': self.state,
'current_iteration': self.current_iteration,
@@ -676,37 +659,39 @@ def save_my_model_checkpoint(data_iterator, ...):
'restart_again_requested': self.restart_again_requested,
'continue_requested': self.continue_requested,
# logged_sdc_enabled should not be saved (set at the job startup time).
'error_injector_checkpoint': self.error_injector.get_checkpoint_state(),
'error_injector_checkpoint': self.error_injector.state_dict(),
# validation_counts should not be saved (reset at the beginning of the training loop).
'failed_validation_call': self.failed_validation_call,
'initial_result': self.initial_result,
'suspicious_node': self.suspicious_node,
'suspicious_device': self.suspicious_device,
# No need to save saved_state (RNG state already captured in checkpoint).
'data_iterator_checkpoints': (
[d.get_checkpoint_state() for d in data_iterators] if data_iterators else None
[d.state_dict() for d in data_iterators] if data_iterators else None
),
'last_loss': self.last_loss,
# No need to save saved_results and stats (resets when job resumes).
}
state_list: list[dict[str, Any]]
if (
torch.distributed.is_initialized()
and torch.distributed.get_world_size() > 1
and self.mode != RerunMode.DISABLED
):
state_list = [None for i in range(torch.distributed.get_world_size())]
torch.distributed.all_gather_object(state_list, state)
else:
state_list = [state]
return state_list
if use_dist_ckpt:
pp_rank = mpu.get_pipeline_model_parallel_rank()
pp_size = mpu.get_pipeline_model_parallel_world_size()
tp_rank = mpu.get_tensor_model_parallel_rank()
tp_size = mpu.get_tensor_model_parallel_world_size()
state_dict = ShardedObject(
'rerun_state_machine_state',
state_dict,
(pp_size, tp_size),
(pp_rank, tp_rank),
replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
)
return state_dict

def set_checkpoint_state(self, state_list: list[dict[str, Any]]) -> None:
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Method that restores the state from a checkpoint.
Args:
state_list: the list of state dicts saved in the checkpoint and originally
obtained from get_checkpoint_state().
state_dict: the state dict saved in the checkpoint and originally
obtained from state_dict().
Returns:
None
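In the distributed-checkpoint branch above, each rank's state is registered as a ShardedObject laid out on the (pipeline, tensor) parallel grid, with data-parallel ranks acting as replicas of the same shard. A hedged restatement of that call with the arguments annotated (all names are taken from the diff; nothing new is introduced):

    # Illustrative restatement of the ShardedObject construction above.
    state_dict = ShardedObject(
        'rerun_state_machine_state',   # key the object is saved under
        state_dict,                    # this rank's plain rerun state dict
        (pp_size, tp_size),            # global shape: one object per (PP, TP) coordinate
        (pp_rank, tp_rank),            # this rank's coordinate in that grid
        # Ranks that differ only in their data-parallel (and context-parallel)
        # position hold identical state, so they are marked as replicas.
        replica_id=mpu.get_data_parallel_rank(with_context_parallel=True),
    )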
@@ -716,31 +701,43 @@ def load_checkpoint(checkpoint, ...)
...
if 'rerun_state_machine' in checkpoint:
rerun_state_machine = get_rerun_state_machine()
rerun_state_machine.set_checkpoint_state(checkpoint['rerun_state_machine'])
rerun_state_machine.load_state_dict(checkpoint['rerun_state_machine'])
"""

if self.mode == RerunMode.DISABLED:
return
rank: int = _safe_get_rank()
if rank == 0:
logger.warning(
"Getting RerunStaeMachine state from checkpoint, args rerun options ignored"
)
state = state_list[rank]
self.mode = state['mode']
self.state = state['state']
self.current_iteration = state['current_iteration']
self.rerun_requested = state['rerun_requested']
self.checkpoint_requested = state['checkpoint_requested']
self.restart_again_requested = state['restart_again_requested']
self.continue_requested = state['continue_requested']
self.error_injector.set_checkpoint_state(state['error_injector_checkpoint'])
self.failed_validation_call = state['failed_validation_call']
self.initial_result = state['initial_result']
self.suspicious_node = state['suspicious_node']
self.suspicious_device = state['suspicious_device']
self.data_iterator_checkpoints = state['data_iterator_checkpoints']
self.last_loss = state['last_loss']
logger.warning("Getting RerunStaeMachine state from checkpoint, args rerun options ignored")
self.mode = state_dict['mode']
self.state = state_dict['state']
self.current_iteration = state_dict['current_iteration']
self.rerun_requested = state_dict['rerun_requested']
self.checkpoint_requested = state_dict['checkpoint_requested']
self.restart_again_requested = state_dict['restart_again_requested']
self.continue_requested = state_dict['continue_requested']
self.error_injector.load_state_dict(state_dict['error_injector_checkpoint'])
self.failed_validation_call = state_dict['failed_validation_call']
self.initial_result = state_dict['initial_result']
self.suspicious_node = state_dict['suspicious_node']
self.suspicious_device = state_dict['suspicious_device']
self.data_iterator_checkpoints = state_dict['data_iterator_checkpoints']
self.last_loss = state_dict['last_loss']

def _sanitize_data_iterators(
self, data_iterator: DataIteratorArgType
) -> list["RerunDataIterator"]:
data_iterators: list[RerunDataIterator]
if self.mode == RerunMode.DISABLED:
data_iterators = []
elif not isinstance(data_iterator, list):
data_iterators = [data_iterator]
else:
data_iterators = data_iterator
data_iterators = [d for d in data_iterators if d is not None]
for d in data_iterators:
assert (
isinstance(d, RerunDataIterator),
), "data iterator is not wrapped with RerunDataIterator"
return data_iterators

def _get_validation_call_info(self) -> Call:
"""Internal method to get the context about the caller to validate_result()."""
@@ -867,7 +864,7 @@ def advance(self) -> None:
self.replaying = False
self.saved_microbatches = []

def get_checkpoint_state(self) -> SerializableStateType:
def state_dict(self) -> SerializableStateType:
"""Method to capture the state of the iterator as a serializable dict."""

return {
@@ -876,7 +873,7 @@ def get_checkpoint_state(self) -> SerializableStateType:
'replay_pos': self.replay_pos,
}

def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
def load_state_dict(self, state_dict: SerializableStateType) -> None:
"""Method to restore the state saved as a serializable dict."""

self.saved_microbatches = state_dict['saved_microbatches']
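For readers unfamiliar with the iterator wrapper: the three fields being checkpointed here are what let the state machine re-feed the same microbatches on a rerun. A simplified, self-contained illustration of that replay idea (a toy sketch, not the actual RerunDataIterator implementation):

    class ReplayableIterator:
        """Toy version of the save/replay behaviour captured by the three fields above."""

        def __init__(self, iterable):
            self._it = iter(iterable)
            self.saved_microbatches = []   # items handed out since the last advance()
            self.replaying = False         # True while re-serving saved items
            self.replay_pos = 0            # next index to serve while replaying

        def __iter__(self):
            return self

        def __next__(self):
            if self.replaying:
                if self.replay_pos < len(self.saved_microbatches):
                    item = self.saved_microbatches[self.replay_pos]
                    self.replay_pos += 1
                    return item
                self.replaying = False     # replay exhausted; resume the real iterator
            item = next(self._it)
            self.saved_microbatches.append(item)
            return item

        def rewind(self):
            # Called when a rerun is requested: serve the saved items again.
            self.replaying = True
            self.replay_pos = 0

        def advance(self):
            # Called when an iteration is accepted: drop the saved items.
            self.replaying = False
            self.saved_microbatches = []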
@@ -1048,7 +1045,7 @@ def maybe_miscompare(
else:
raise RuntimeError("Should not be here")

def get_checkpoint_state(self) -> SerializableStateType:
def state_dict(self) -> SerializableStateType:
"""Method to capture the state of the error injector as a serializable dict."""

return {
@@ -1058,7 +1055,7 @@ def get_checkpoint_state(self) -> SerializableStateType:
'injected_error_type': self.injected_error_type,
}

def set_checkpoint_state(self, state_dict: SerializableStateType) -> None:
def load_state_dict(self, state_dict: SerializableStateType) -> None:
"""Method to restore the state saved as a serializable dict."""

self.error_injection_rate = state_dict['error_injection_rate']
@@ -1104,7 +1101,14 @@ def _set_rerun_state_machine(rerun_state_machine) -> None:
def _safe_get_rank() -> int:
"""Internal function that safely checks and returns the rank of the caller."""

return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
if torch.distributed.is_initialized():
return torch.distributed.get_rank()

# If torch.distributed is not initialized, try to read environment variables.
try:
return int(os.environ.get("RANK", 0))
except (ValueError, TypeError):
return 0


def _compare_floats(a: torch.Tensor, b: torch.Tensor) -> float:
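The widened fallback matters because the constructor and set_mode now gate their warnings on _safe_get_rank() == 0, and both can run before torch.distributed.init_process_group(); in that window the rank is recovered from the RANK environment variable that launchers such as torchrun export. A small sketch of the resulting pattern (log_once is a hypothetical helper, not part of this diff):

    import logging
    import os

    import torch

    logger = logging.getLogger(__name__)

    def log_once(msg: str) -> None:
        # Same idea as the rank-0 gating used above: works both before and after
        # torch.distributed is initialized, thanks to the RANK fallback.
        if torch.distributed.is_initialized():
            rank = torch.distributed.get_rank()
        else:
            try:
                rank = int(os.environ.get("RANK", 0))
            except (ValueError, TypeError):
                rank = 0
        if rank == 0:
            logger.warning(msg)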
30 changes: 22 additions & 8 deletions megatron/training/checkpointing.py
@@ -361,6 +361,12 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
# Collect rng state across data parallel ranks.
rng_state = get_rng_state(ckpt_type != CheckpointType.LEGACY)

# Collect rerun state across all ranks
rerun_state_machine = get_rerun_state_machine()
rerun_state = rerun_state_machine.state_dict(
data_iterator=train_data_iterator, use_dist_ckpt=ckpt_type != CheckpointType.LEGACY
)

# Checkpoint name.
return_base_dir = (ckpt_type != CheckpointType.LEGACY)
checkpoint_name = get_checkpoint_name(save_dir, iteration, release=False, pipeline_parallel=pipeline_parallel,
@@ -409,7 +415,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
use_dist_ckpt=ckpt_type != CheckpointType.LEGACY,
iteration=iteration,
optim_sd_kwargs=optim_sd_kwargs,
train_data_iterator=train_data_iterator,
rerun_state=rerun_state,
)

if args.enable_ft_package and ft_client is not None:
@@ -593,7 +599,7 @@ def save_dataloader_state(train_iterator, iteration, dataloader_save_path):

def generate_state_dict(args, model, optimizer, opt_param_scheduler,
rng_state, use_dist_ckpt=False, iteration=None,
optim_sd_kwargs=None, train_data_iterator=None):
optim_sd_kwargs=None, rerun_state=None):
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
@@ -623,10 +629,7 @@ def generate_state_dict(args, model, optimizer, opt_param_scheduler,
opt_param_scheduler.state_dict()

# Rerun state
rerun_state_machine = get_rerun_state_machine()
state_dict['rerun_state_machine'] = rerun_state_machine.get_checkpoint_state(
train_data_iterator
)
state_dict['rerun_state_machine'] = rerun_state

# RNG states.
if not args.no_save_rng:
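Putting the two checkpointing.py changes together: save_checkpoint now collects the rerun state once (in distributed or legacy form) and hands the result to generate_state_dict, instead of passing the data iterator down. A hedged sketch of the new calling convention, abridged but using only names that appear in this diff:

    # Sketch of the save path after this change (other bookkeeping omitted).
    rerun_state_machine = get_rerun_state_machine()
    rerun_state = rerun_state_machine.state_dict(
        data_iterator=train_data_iterator,
        use_dist_ckpt=(ckpt_type != CheckpointType.LEGACY),
    )
    state_dict = generate_state_dict(
        args, model, optimizer, opt_param_scheduler, rng_state,
        use_dist_ckpt=(ckpt_type != CheckpointType.LEGACY),
        iteration=iteration,
        optim_sd_kwargs=optim_sd_kwargs,
        rerun_state=rerun_state,   # previously: train_data_iterator=train_data_iterator
    )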
@@ -1136,6 +1139,17 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
gen_sd_optim = None
gen_sd_opt_param_scheduler = None

# Determine if rerun state will be loaded
if (ckpt_tp_pp == run_tp_pp and not release and not args.finetune):
rerun_state_machine = get_rerun_state_machine()
gen_sd_rerun_state = rerun_state_machine.state_dict(
data_iterator=None, use_dist_ckpt=True
)
else:
gen_sd_rerun_state = None
if ckpt_tp_pp != run_tp_pp:
print_rank_0("{}: Rerun state will be ignored".format(mismatch_msg))

# [ModelOpt]: Initial loading from non-resume sharded checkpoint to a Distillation Model
# will result in key mismatch with loss modules potentially containing parameters, since
# it requires generating a state_dict before loading. Here we hide those modules if present.
@@ -1145,7 +1159,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
stack.enter_context(m.hide_loss_modules())
load_kwargs['sharded_state_dict'] = generate_state_dict(
args, model, gen_sd_optim, gen_sd_opt_param_scheduler, gen_sd_rng_state,
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, train_data_iterator=None
use_dist_ckpt=True, optim_sd_kwargs=optim_sd_kwargs, rerun_state=gen_sd_rerun_state
)

# When "--fp8-param-gather" is disabled, this function doesn't modify anything.
@@ -1268,7 +1282,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
# rerun state
try:
if 'rerun_state_machine' in state_dict:
get_rerun_state_machine().set_checkpoint_state(state_dict['rerun_state_machine'])
get_rerun_state_machine().load_state_dict(state_dict['rerun_state_machine'])
except Exception as e:
print(f"Unable to restore RerunMachine from checkpoint: {e}")
sys.exit()