diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py
index 22b79d977332..5e94954c5df4 100644
--- a/colossalai/checkpoint_io/distributed_checkpoint_utils.py
+++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -11,6 +11,8 @@
 from colossalai.interface import ModelWrapper
 from colossalai.utils import get_non_persistent_buffers_set
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from contextlib import contextmanager
 
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -32,67 +34,32 @@
 MODEL_META_PREFIX = "pytorch_model-meta-dist-"
 MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
 SHARD_META_SUFFIX = ".index.json"
+UNSHARD_META_SUFFIX = ".json"
 
 
-def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
-    destination = dict()
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        destination[prefix + name] = param
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            buffer = buf if keep_vars else buf.detach()
-            destination[prefix + name] = buffer
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        destination[extra_state_key] = extra_state
-    return destination
-
-
-def load_state_dict_into_dist_model(
-    model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
-):
-    destination = dict()
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        with torch.no_grad():
-            param.copy_(state_dict[prefix + name])
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            with torch.no_grad():
-                buf.copy_(state_dict[prefix + name])
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        with torch.no_grad():
-            extra_state.copy_(state_dict[extra_state_key])
-    return destination
+@contextmanager
+def RestoreDefaultStateDictBehavior(model):
+    original_methods = {}
+    for name, module in model.named_modules():
+        if isinstance(module, ParallelModule):
+            original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
+            module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
+            module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
+    try:
+        yield model
+    finally:
+        for module, original_method in original_methods.items():
+            module._save_to_state_dict, module._load_from_state_dict = original_method
+
 
 def create_model_metadata(
-    model: nn.Module,
+    model: ModelWrapper,
     prefix: str = "",
-    tp_size=None,
-    tp_rank=None,
+    tp_size: int = None,
+    tp_rank: int = None,
+    zero_size: int = None,
+    zero_rank: int = None,
 ):
     param_origin_shape = model.param_origin_shape
     model = model.unwrap()
@@ -105,7 +72,7 @@ def create_model_metadata(
         tp_partition_dim = search_tp_partition_dim(
             current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
         )
-        model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+        model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
         model_metadata[prefix + name]["lengths"] = list(param.shape)
         model_metadata[prefix + name]["global_shape"] = list(original_shape)
         if tp_partition_dim is not None:
@@ -257,119 +224,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
     return False
 
 
-def dist_model_sharder(
-    model: nn.Module,
-    prefix: str = "",
-    keep_vars: bool = False,
-    size_per_shard: int = 1024,
-    pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
-) -> Iterator[Tuple[OrderedDict, int]]:
-    # An internel method that breaks state_dict of model into shards within limited size.
-
-    state_dict_sharder = StateDictSharder(size_per_shard)
-
-    # Save parameters.
-    for name, param in model.named_parameters():
-        if param is None:
-            continue
-        if pinned_state_dicts is not None:
-            if (prefix + name) not in pinned_state_dicts:
-                pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
-            pinned_state_dicts[prefix + name].copy_(param)
-            param = pinned_state_dicts[prefix + name]
-        block, block_size = state_dict_sharder.append_param(prefix + name, param)
-        if block is not None:
-            yield block, block_size
-
-    # Save buffers.
-    non_persist_buffers_set = get_non_persistent_buffers_set(model)
-    for name, buf in model.named_buffers():
-        if buf is not None and name not in non_persist_buffers_set:
-            buffer = buf if keep_vars else buf.detach()
-            if pinned_state_dicts is not None:
-                if (prefix + name) not in pinned_state_dicts:
-                    pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
-                pinned_state_dicts[prefix + name].copy_(buffer)
-                buffer = pinned_state_dicts[prefix + name]
-            block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
-            if block is not None:
-                yield block, block_size
-
-    # Save extra states.
-    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
-    if (
-        getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
-        is not torch.nn.Module.get_extra_state
-    ):
-        extra_state = model.get_extra_state()
-        if pinned_state_dicts is not None:
-            if extra_state_key not in pinned_state_dicts:
-                pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
-            pinned_state_dicts[extra_state_key].copy_(extra_state)
-            extra_state = pinned_state_dicts[extra_state_key]
-        block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
-        if block is not None:
-            yield block, block_size
-
-    # Return the last block in sharder.
-    yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
-
-
-def save_dist_unshard_model(
-    model: ModelWrapper,
-    model_metadata: Dict,
-    checkpoint: str,
-    use_safetensors: bool,
-    use_async: bool = False,
-    dist_id=0,
-    pinned_state_dicts=None,
-):
-    """
-    Save model state dict to a single file with given checkpointing path.
-
-    Args:
-        model (nn.Module): Model on local device to be saved.
-        checkpoint (str): Checkpointing path which should be a file path. Can be absolute or relative path.
-        gather_dtensor (bool, optional): Whether to gather dtensor, currently not used. Defaults to True.
-        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-    """
-
-    model = model.unwrap()
-
-    # The logic of collecting parameter shards along tp degree
-    # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
-    state_dict = dist_model_state_dict(model)
-
-    Path(checkpoint).mkdir(parents=True, exist_ok=True)
-    file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
-    if use_async:
-        file_name = file_name.replace(".bin", ".safetensors")
-    checkpoint_file = os.path.join(checkpoint, file_name)
-    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
-    save_metadata(model_metadata, metadata_file, file_name)
-
-    if use_async:
-        from colossalai.utils.safetensors import save
-
-        if id(model) not in pinned_state_dicts:
-            pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-        for name, param in state_dict.items():
-            pinned_state_dicts[id(model)][name].copy_(param)
-            state_dict[name] = pinned_state_dicts[id(model)][name]
-        writer = save(path=checkpoint_file, state_dict=state_dict)
-        return writer
-    else:
-        save_state_dict(state_dict, checkpoint_file, use_safetensors)
-        return None
-
-
 def load_dist_model(
-    model: ModelWrapper,
     model_metadata: Dict,
     checkpoint: str,
-    low_cpu_mem_mode: bool = True,
-    num_threads: int = 1,
 ):
     """
     Load model from a single file with the given path of checkpoint.
@@ -380,10 +237,6 @@ def load_dist_model(
         strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                                  This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
     """
-
-    model_before_wrapping = model
-    model = model.unwrap()
-
     metadata_loaded = load_metadata(checkpoint)
 
     load_files = {}
@@ -420,92 +273,14 @@ def load_dist_model(
         )
         state_dict[key] = state
 
-    if not low_cpu_mem_mode:
-        state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-
-    load_state_dict_into_dist_model(model=model, state_dict=state_dict)
-
-    # Update master params if mixed-precision training is enabled.
-    model_before_wrapping.update_master_params()
-
+    return state_dict
 
 
-def save_dist_sharded_model(
-    model: ModelWrapper,
-    model_metadata: Dict,
-    checkpoint: str,
-    prefix: Optional[str] = None,
-    size_per_shard: int = 1024,
-    use_safetensors: bool = False,
-    use_async: bool = False,
-    dist_id: int = 0,
-    pinned_state_dicts=None,
-) -> None:
-    """
-    Save sharded model checkpoint under the given checkpointing path.
-    The following files will be created under the path:
-    - An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
-    - Multiple files that store state tensors of models.
-      If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
-      If pipeline parallelism is not used, "pytorch_model.<prefix>-000XX.bin"
-
-
-    Args:
-        model (nn.Module): Model on local device to be saved.
-        checkpoint (str): Checkpointing path which should be a directory path.
-        gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
-        prefix (str, optional): Perfix of file to save. Defaults to None.
-        size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
-        use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
-        use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
-    """
-
-    model = model.unwrap()
-
-    if os.path.isfile(checkpoint):
-        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
-        return
-
-    Path(checkpoint).mkdir(parents=True, exist_ok=True)
-    # Devices along the same dp_group share the same copies of model.
-    # So only let the device with dp_rank == 0 and sp_rank == 0 save the model.
-
-    if use_async:
-        if id(model) not in pinned_state_dicts:
-            pinned_state_dicts[id(model)] = {}
-        pinned_state_dicts = pinned_state_dicts[id(model)]
-    else:
-        pinned_state_dicts = None
-    state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
-    weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
-    index_file = CheckpointIndexFile(checkpoint)
-
-    # Manage filenames of sharded weights and index file for each pipeline stage.
+def get_dist_files_name(weights_name, dist_id):
     weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
     weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
-    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
-    async_writers = []
-    if use_async:
-        total_size, writers = async_save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            state_preprocess=False,
-        )
-        async_writers.extend(writers)
-    else:
-        total_size = save_state_dict_shards(
-            sharded_state_dict=state_dict_shard,
-            checkpoint=checkpoint,
-            index_file=index_file,
-            base_filename=weights_name,
-            is_master=True,
-            use_safetensors=use_safetensors,
-            use_pp_format=True,
-        )
-    for k, _ in model_metadata.items():
-        model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
+    return weights_name
 
-    save_metadata(model_metadata, metadata_file, total_size=total_size)
-    return async_writers
+def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
+    if use_safetensors:
+        return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
+    return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
\ No newline at end of file
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index cbad7d78854a..93c836e22cb5 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -5,6 +5,7 @@
 from pathlib import Path
 from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
+from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
@@ -28,8 +29,13 @@
     create_model_metadata,
     is_pytorch_model_meta_dist_file,
     load_dist_model,
-    save_dist_sharded_model,
-    save_dist_unshard_model,
+    save_metadata,
+    get_dist_files_name,
+    get_dist_meta_file_name,
+    MODEL_META_PREFIX,
+    MODEL_WEIGHT_PREFIX,
+    SHARD_META_SUFFIX,
+    RestoreDefaultStateDictBehavior
 )
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -97,13 +103,14 @@ def __init__(
         self.verbose = verbose
         self.coordinator = DistCoordinator()
 
-    @staticmethod
     def _model_sharder(
+        self,
         model: nn.Module,
         prefix: str = "",
         keep_vars: bool = False,
         size_per_shard: int = 1024,
         pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
+        gather_dtensor: bool = True,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         # An internel method that breaks state_dict of model into shards within limited size.
@@ -113,10 +120,15 @@ def _model_sharder(
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            # Gather tensor pieces when using tensor parallel.
-            if is_padded_tensor(param):
-                param = to_unpadded_tensor(param)
-            param_ = gather_distributed_param(param, keep_vars=False)
+
+            if gather_dtensor:
+                # Gather tensor pieces when using tensor parallel.
+                if is_padded_tensor(param):
+                    param = to_unpadded_tensor(param)
+                param_ = gather_distributed_param(param, keep_vars=False)
+            else:
+                param_ = param
+
             if pinned_state_dicts is not None:
                 if (prefix + name) not in pinned_state_dicts:
                     pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
                 pinned_state_dicts[prefix + name].copy_(param_)
@@ -237,26 +249,14 @@ def save_sharded_model(
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
 
         model._force_wait_all_gather()
-
-        if gather_dtensor:
-            if self.dp_rank != 0 and self.sp_rank != 0:
-                return
-            dist_id = self.tp_size * self.pp_rank + self.tp_rank
-            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-            async_writers = save_dist_sharded_model(
-                model=model,
-                model_metadata=model_metadata,
-                checkpoint=checkpoint,
-                prefix=prefix,
-                size_per_shard=size_per_shard,
-                use_safetensors=use_safetensors,
-                use_async=use_async,
-                dist_id=dist_id,
-                pinned_state_dicts=self.pinned_state_dicts,
-            )
-            self.async_writers.extend(async_writers)
+        if self.dp_rank != 0 and self.sp_rank != 0:
             return
-
+
+        model_metadata = None
+        if not gather_dtensor:
+            # Manage filenames of sharded weights and index file for each pipeline stage.
+            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -264,28 +264,30 @@ def save_sharded_model(
             return
 
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
-        # Devices along the same dp_group share the same copies of model.
-        # So only let the device with dp_rank == 0 save the model.
-        if self.dp_rank != 0:
-            return
 
         # Then collect the sharded parameters & buffers along tp_group.
         # Only devices with tp_rank == 0 are responsible for model saving.
-        control_saving = self.tp_rank == 0 and self.sp_rank == 0
+        control_saving = self.tp_rank == 0 if gather_dtensor else True
 
         if control_saving and use_async:
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = {}
             pinned_state_dicts = self.pinned_state_dicts[id(model)]
         else:
             pinned_state_dicts = None
-        state_dict_shard = HybridParallelCheckpointIO._model_sharder(
-            model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
+        state_dict_shard = self._model_sharder(
+            model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts, gather_dtensor=gather_dtensor
         )
+
         weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
         index_file = CheckpointIndexFile(checkpoint)
 
-        if self.pp_size == 1:
+        if self.pp_size == 1 or not gather_dtensor:
             # When pipeline is not used, save the model shards as in general checkpointIO
+            if not gather_dtensor:
+                dist_id = self.tp_size * self.pp_rank + self.tp_rank
+                weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id)
+                metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors)
+
             if use_async:
                 total_size, writers = async_save_state_dict_shards(
                     sharded_state_dict=state_dict_shard,
@@ -305,16 +307,22 @@ def save_sharded_model(
                     is_master=control_saving,
                     use_safetensors=use_safetensors,
                 )
-            if control_saving:
-                index_file.append_meta_data("total_size", total_size)
-                index_file.write_index_file(save_index_file)
-                save_config_file(model, checkpoint)
-                if self.verbose and self.coordinator.is_master():
-                    logging.info(
-                        f"The model is split into checkpoint shards. "
-                        f"You can find where each parameters has been saved in the "
-                        f"index located at {save_index_file}."
-                    )
+            if not gather_dtensor:
+                # saving metadata for distributed checkpoint
+                for k, _ in model_metadata.items():
+                    model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
+                save_metadata(model_metadata, metadata_file, total_size=total_size)
+            else:
+                if control_saving:
+                    index_file.append_meta_data("total_size", total_size)
+                    index_file.write_index_file(save_index_file)
+                    save_config_file(model, checkpoint)
+                    if self.verbose and self.coordinator.is_master():
+                        logging.info(
+                            f"The model is split into checkpoint shards. "
+                            f"You can find where each parameters has been saved in the "
+                            f"index located at {save_index_file}."
+                        )
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
@@ -405,13 +413,15 @@ def load_sharded_model(
         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
             checkpoint = checkpoint_index_file.parent
-            load_dist_model(
-                model=model,
+            state_dict = load_dist_model(
                 model_metadata=model_metadata,
                 checkpoint=checkpoint,
-                low_cpu_mem_mode=low_cpu_mem_mode,
-                num_threads=num_threads,
             )
+            model = model.unwrap()
+            with RestoreDefaultStateDictBehavior(model):
+                load_state_dict_into_model(
+                    model, state_dict, missing_keys=[], strict=False, load_sub_module=True
+                )
             return
 
         model_before_wrapping = model  # backup for model before wrapping
@@ -803,47 +813,43 @@ def save_unsharded_model(
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
 
         model._force_wait_all_gather()
-        model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+        if self.dp_rank != 0 and self.sp_rank != 0:
+            return
 
-        if gather_dtensor:
-            if self.dp_rank != 0 and self.sp_rank != 0:
-                return
+        if not gather_dtensor:
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
-            writer = save_dist_unshard_model(
-                model=model,
-                model_metadata=model_metadata,
-                checkpoint=checkpoint,
-                use_safetensors=use_safetensors,
-                use_async=use_async,
-                dist_id=dist_id,
-                pinned_state_dicts=self.pinned_state_dicts,
-            )
-            if writer is not None:
-                self.async_writers.append(writer)
-            return
+            Path(checkpoint).mkdir(parents=True, exist_ok=True)
+            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
+            checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin")
+            if use_async:
+                checkpoint_file = checkpoint_file.replace(".bin", f".safetensors")
+            metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_async)
+            save_metadata(model_metadata=model_metadata, metadata_file=metadata_file, checkpoint_file=checkpoint_file)
+        else:
+            checkpoint_file = checkpoint
 
         model = model.unwrap()
 
-        if self.dp_rank != 0:
-            return
-
         # The logic of collecting parameter shards along tp degree
         # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
-        state_dict = model.state_dict()
-        if self.pp_size == 1:
-            # When pipeline is not used, let master rank directly save the collected state_dict.
-            if self.tp_rank == 0:
-                if use_async:
-                    from colossalai.utils.safetensors import save
+        ctx = RestoreDefaultStateDictBehavior(model) if not gather_dtensor else nullcontext()
+        with ctx:
+            state_dict = model.state_dict()
 
-                    if id(model) not in self.pinned_state_dicts:
-                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
-                    for name, param in state_dict.items():
-                        self.pinned_state_dicts[id(model)][name].copy_(param)
-                        state_dict[name] = self.pinned_state_dicts[id(model)][name]
-                    writer = save(path=checkpoint, state_dict=state_dict)
-                    self.async_writers.append(writer)
-                else:
-                    save_state_dict(state_dict, checkpoint, use_safetensors)
+        if (self.pp_size == 1 and self.tp_rank == 0) or not gather_dtensor:
+            # When pipeline is not used, let master rank directly save the collected state_dict.
+            if use_async:
+                from colossalai.utils.safetensors import save
+
+                if id(model) not in self.pinned_state_dicts:
+                    self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                for name, param in state_dict.items():
+                    self.pinned_state_dicts[id(model)][name].copy_(param)
+                    state_dict[name] = self.pinned_state_dicts[id(model)][name]
+                writer = save(path=checkpoint_file, state_dict=state_dict)
+                self.async_writers.append(writer)
+            else:
+                save_state_dict(state_dict, checkpoint_file, use_safetensors)
         else:
             # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
             state_dict_list = [None for _ in range(self.pp_size)]
@@ -862,10 +868,10 @@ def save_unsharded_model(
                     for name, param in complete_state_dict.items():
                         self.pinned_state_dicts[id(model)][name].copy_(param)
                         complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
-                    writer = save(path=checkpoint, state_dict=complete_state_dict)
+                    writer = save(path=checkpoint_file, state_dict=complete_state_dict)
                     self.async_writers.append(writer)
                 else:
-                    save_state_dict(complete_state_dict, checkpoint, use_safetensors)
+                    save_state_dict(complete_state_dict, checkpoint_file, use_safetensors)
 
     def load_unsharded_model(
         self,
@@ -890,18 +896,16 @@ def load_unsharded_model(
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
         model._force_wait_all_gather()
 
+        load_dtensor = False
        if os.path.isdir(checkpoint):
             for filename in os.listdir(checkpoint):
                 if is_pytorch_model_meta_dist_file(filename):
-                    model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-                    load_dist_model(
-                        model=model,
-                        model_metadata=model_metadata,
-                        checkpoint=checkpoint,
-                        low_cpu_mem_mode=low_cpu_mem_mode,
-                        num_threads=num_threads,
-                    )
-                    return
+                    load_dtensor = True
+                    break
+
+        model_metadata = None  # used for dist model
+        if load_dtensor:
+            model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
 
         strict = False
         model_before_wrapping = model
@@ -910,10 +914,17 @@ def load_unsharded_model(
         # Load from checkpoint. Since the logic of breaking parameter shards along tp degree
         # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
         # model.load_state_dict can be directly called.
-        state_dict = load_state_dict(checkpoint)
+        if load_dtensor:
+            state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint)
+        else:
+            state_dict = load_state_dict(checkpoint)
+
         if not low_cpu_mem_mode:
             state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-        model.load_state_dict(state_dict, strict=strict)
+
+        ctx = RestoreDefaultStateDictBehavior(model) if load_dtensor else nullcontext()
+        with ctx:
+            model.load_state_dict(state_dict, strict=strict)
 
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()
diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py
index 09d6eb345bab..74c1efc2cbc2 100644
--- a/tests/test_checkpoint_io/test_dist_checkpointio.py
+++ b/tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -76,13 +76,13 @@ def _preprocess_data(data):
     optimizer_0.step()
     optimizer_0.zero_grad()
     with shared_tempdir() as tempdir:
         model_ckpt_path_0 = f"{tempdir}/model_0"
 
         booster_0.save_model(
             model_0,
             model_ckpt_path_0,
             shard=shard,
-            gather_dtensor=True,
+            gather_dtensor=False,
             size_per_shard=size_per_shard,
             use_async=use_async,
         )
@@ -104,7 +104,7 @@ def _preprocess_data(data):
             model_1,
             model_ckpt_path_1,
             shard=shard,
-            gather_dtensor=True,
+            gather_dtensor=False,
             size_per_shard=size_per_shard,
             use_async=use_async,
         )
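Editor's note: the sketch below illustrates the distributed (gather_dtensor=False) checkpoint flow that the updated test exercises. It is a minimal, assumed setup (torchrun launch, 2-way tensor parallelism, a Shardformer-supported GPT-2 model, placeholder checkpoint directory); only the save_model/load_model calls with gather_dtensor=False mirror the test above, everything else is illustrative and not part of the diff.

# Minimal sketch, not part of the diff. Assumed launch: torchrun --nproc_per_node=2 this_script.py
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from torch.optim import Adam
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch()
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)  # assumed sizes for illustration
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config(n_layer=2))  # any Shardformer-supported model works
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# gather_dtensor=False keeps each rank's shard: per-rank weight files plus
# pytorch_model-meta-dist-*.json metadata are written instead of a gathered checkpoint.
booster.save_model(model, "ckpt_dir", shard=True, gather_dtensor=False, size_per_shard=32)

# Loading detects the dist metadata files and restores each rank's local shard in place.
booster.load_model(model, "ckpt_dir")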