diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bc9425a0b0cd..1fba6d4b5c4e 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -78,6 +78,9 @@ def __init__( self.require_grad_sync = True self.overlap_allgather = overlap_allgather self.use_fp8 = use_fp8 + self.param_origin_shape = {} + for name, param in module.named_parameters(): + self.param_origin_shape[name] = param.shape shardformer = ShardFormer(shard_config) if custom_policy is not None: diff --git a/colossalai/checkpoint_io/distributed_checkpoint_utils.py b/colossalai/checkpoint_io/distributed_checkpoint_utils.py new file mode 100644 index 000000000000..563ec99dc21e --- /dev/null +++ b/colossalai/checkpoint_io/distributed_checkpoint_utils.py @@ -0,0 +1,238 @@ +import json +import os +from typing import Dict + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.distributed_c10d import _get_default_group + +from colossalai.interface import ModelWrapper +from colossalai.shardformer.layer.parallel_module import ParallelModule +from contextlib import contextmanager + +from .utils import ( + load_state_dict, + search_tp_partition_dim, +) + +MODEL_META_PREFIX = "pytorch_model-meta-dist-" +MODEL_WEIGHT_PREFIX = "pytorch_model-dist-" +SHARD_META_SUFFIX = ".index.json" +UNSHARD_META_SUFFIX = ".json" + + +@contextmanager +def RestoreDefaultStateDictBehavior(model): + original_methods = {} + for name, module in model.named_modules(): + if isinstance(module, ParallelModule): + original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict) + module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module) + module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module) + try: + yield model + finally: + for module, original_method in original_methods.items(): + module._save_to_state_dict, module._load_from_state_dict = original_method + + +def save_metadata(model_metadata, metadata_file, checkpoint_file=None, total_size=None): + metadata_dicts = { + "checkpoint_version": "1.0", + "total_size": total_size, + "metadata": {}, + } + for name, data in model_metadata.items(): + metadata_dicts["metadata"][name] = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + v = v.tolist() + metadata_dicts["metadata"][name][k] = v + if checkpoint_file is not None: + metadata_dicts["metadata"][name]["file"] = checkpoint_file + metadata_dicts["metadata"][name]["rank"] = dist.get_rank(_get_default_group()) + with open(metadata_file, "w") as json_file: + json.dump(metadata_dicts, json_file, indent=4) + + +def load_metadata(checkpoint: str): + metadata_dict = {} + for filename in os.listdir(checkpoint): + if filename.startswith(MODEL_META_PREFIX) and filename.endswith(".json"): + file_path = os.path.join(checkpoint, filename) + try: + with open(file_path, "r") as f: + metadata_json = json.load(f) + for name, item in metadata_json["metadata"].items(): + if name not in metadata_dict: + metadata_dict[name] = {} + metadata_dict[name]["global_shape"] = item["global_shape"] + metadata_dict[name]["shards"] = {} + else: + assert metadata_dict[name]["global_shape"] == item["global_shape"] + shard = {item["rank"]: {}} + for k, v in item.items(): + if k == "rank": + continue + shard[item["rank"]][k] = v + metadata_dict[name]["shards"].update(shard) + except (json.JSONDecodeError, IOError) as e: + 
print(f"Unable to load file {file_path}: {e}") + return metadata_dict + + +def find_covering_shards(shards, target_offsets, target_lengths): + """ + Parameters: + + shards: A list containing information about all shards. + target_offsets: A one-dimensional array representing the starting position of the target tensor in each dimension. + target_lengths: A one-dimensional array representing the lengths of the target tensor in each dimension. + Returns: + + A list of all shards that cover the target range. + """ + target_start = target_offsets + target_end = [start + length for start, length in zip(target_offsets, target_lengths)] + + covering_shards = {} + + global_shape = None + total_lengths = None + for rank, shard in shards.items(): + shard_start = shard["offsets"] + shard_lengths = shard["lengths"] + if global_shape == None: + global_shape = shard["global_shape"] + total_lengths = [0] * len(global_shape) + shard_end = [start + length for start, length in zip(shard_start, shard_lengths)] + + overlap = any( + not (target_end[dim] <= shard_start[dim] or target_start[dim] >= shard_end[dim]) + for dim in range(len(target_start)) + ) + if overlap: + covering_shards.update({rank: shard}) + for dim in range(len(shard_start)): + total_lengths[dim] = max(total_lengths[dim], shard_start[dim] + shard_lengths[dim]) + + assert total_lengths == global_shape + return covering_shards + + +def extract_weight_from_shard_partial(shard, target_offsets, target_lengths): + """ + Extract the target range of weights from shard data, supporting partial overlap. + + param shard: A dictionary containing shard data, including 'offsets', 'lengths', and 'weight'. + param target_offsets: A 1D array indicating the starting position of the target tensor in each dimension. + param target_lengths: A 1D array indicating the length of the target tensor in each dimension. + return: The extracted sub-tensor of the target weights and its position within the target range. 
+ """ + shard_offsets = shard["offsets"] + shard_lengths = shard["lengths"] + weight = shard["weight"] + + slices = [] + target_slices = [] + + for dim, (t_offset, t_length, s_offset, s_length) in enumerate( + zip(target_offsets, target_lengths, shard_offsets, shard_lengths) + ): + intersection_start = max(t_offset, s_offset) + intersection_end = min(t_offset + t_length, s_offset + s_length) + + if intersection_start >= intersection_end: + return None, None + + shard_slice_start = intersection_start - s_offset + shard_slice_end = intersection_end - s_offset + slices.append(slice(shard_slice_start, shard_slice_end)) + + target_slice_start = intersection_start - t_offset + target_slice_end = intersection_end - t_offset + target_slices.append(slice(target_slice_start, target_slice_end)) + + target_weight = weight[tuple(slices)] + return target_weight, target_slices + + +def assemble_tensor_from_shards_partial(shards, target_offsets, target_lengths, dtype): + target_tensor = torch.zeros(target_lengths, dtype=dtype) + + for rank, shard in shards.items(): + target_weight, target_slices = extract_weight_from_shard_partial(shard, target_offsets, target_lengths) + + if target_weight is not None and target_slices is not None: + target_tensor[tuple(target_slices)] = target_weight + + return target_tensor + + +def is_pytorch_model_meta_dist_file(checkpoint_index_file): + if MODEL_META_PREFIX in str(checkpoint_index_file): + return True + return False + + +def load_dist_model( + model_metadata: Dict, + checkpoint: str, +): + """ + Load model from a single file with the given path of checkpoint. + + Args: + model (nn.Module): The model to be loaded. + checkpoint_index_file (str): Path to the checkpoint file. + strict (bool, optional): For name matching during loading state_dict. Defaults to False. + This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled. 
+ """ + metadata_loaded = load_metadata(checkpoint) + + load_files = {} + covered_shards = {} + for key, item in model_metadata.items(): + offsets = item["offsets"] + lengths = item["lengths"] + assert ( + item["global_shape"] == metadata_loaded[key]["global_shape"] + ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}" + shards = metadata_loaded[key]["shards"] + covering_shards = find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths) + covered_shards[key] = covering_shards + for rank, shard in covering_shards.items(): + if rank not in load_files: + load_files[rank] = set() + load_files[rank].add(shard["file"]) + + dtype = None + for rank, files in load_files.items(): + for file in files: + file_path = os.path.join(checkpoint, file) + state_dict_shard = load_state_dict(file_path) + for key, weight in state_dict_shard.items(): + if key not in covered_shards or rank not in covered_shards[key]: + continue + if dtype == None: + dtype = weight.dtype + covered_shards[key][rank]["weight"] = weight + state_dict = {} + for key, shards in covered_shards.items(): + state = assemble_tensor_from_shards_partial( + shards, model_metadata[key]["offsets"], model_metadata[key]["lengths"], dtype=dtype + ) + state_dict[key] = state + + return state_dict + +def get_dist_files_name(weights_name, dist_id): + weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors") + return weights_name + +def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors): + if use_safetensors: + return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}") + return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}") \ No newline at end of file diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 154d5cb5e5f3..2c06cf4e80c1 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -5,6 +5,7 @@ from pathlib import Path from shutil import rmtree from typing import Dict, Iterator, Optional, OrderedDict, Tuple +from contextlib import nullcontext import torch import torch.distributed as dist @@ -24,6 +25,15 @@ from colossalai.utils import get_current_device, get_non_persistent_buffers_set from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat +from .distributed_checkpoint_utils import ( + is_pytorch_model_meta_dist_file, + load_dist_model, + save_metadata, + get_dist_files_name, + get_dist_meta_file_name, + MODEL_WEIGHT_PREFIX, + RestoreDefaultStateDictBehavior +) from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile from .utils import ( @@ -90,13 +100,14 @@ def __init__( self.verbose = verbose self.coordinator = DistCoordinator() - @staticmethod def _model_sharder( + self, model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, + gather_dtensor: bool = True, ) -> Iterator[Tuple[OrderedDict, int]]: # An internel method that breaks state_dict of model into shards within limited size. @@ -106,10 +117,15 @@ def _model_sharder( for name, param in model.named_parameters(): if param is None: continue - # Gather tensor pieces when using tensor parallel. 
- if is_padded_tensor(param): - param = to_unpadded_tensor(param) - param_ = gather_distributed_param(param, keep_vars=False) + + if gather_dtensor: + # Gather tensor pieces when using tensor parallel. + if is_padded_tensor(param): + param = to_unpadded_tensor(param) + param_ = gather_distributed_param(param, keep_vars=False) + else: + param_ = param + if pinned_state_dicts is not None: if (prefix + name) not in pinned_state_dicts: pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") @@ -126,7 +142,7 @@ def _model_sharder( buffer = buf if keep_vars else buf.detach() if pinned_state_dicts is not None: if (prefix + name) not in pinned_state_dicts: - pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") pinned_state_dicts[prefix + name].copy_(buffer) buffer = pinned_state_dicts[prefix + name] block, block_size = state_dict_sharder.append_param(prefix + name, buffer) @@ -142,7 +158,7 @@ def _model_sharder( extra_state = model.get_extra_state() if pinned_state_dicts is not None: if extra_state_key not in pinned_state_dicts: - pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") pinned_state_dicts[extra_state_key].copy_(extra_state) extra_state = pinned_state_dicts[extra_state_key] block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) @@ -199,6 +215,34 @@ def _optimizer_sharder( # Return the last block in sharder. yield state_dict_sharder.current_block, state_dict_sharder.current_block_size + def create_model_metadata( + self, + model: ModelWrapper, + prefix: str = "", + ): + param_origin_shape = model.param_origin_shape + model = model.unwrap() + model_metadata = {} + for name, param in model.named_parameters(): + if param is None: + continue + model_metadata[prefix + name] = {} + original_shape = param_origin_shape[name] + tp_partition_dim = search_tp_partition_dim( + current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size + ) + model_metadata[prefix + name]["offsets"] = [0] * len(original_shape) + model_metadata[prefix + name]["lengths"] = list(param.shape) + model_metadata[prefix + name]["global_shape"] = list(original_shape) + if tp_partition_dim is not None: + partition_size = param.shape[tp_partition_dim] + model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank + if self.tp_rank == self.tp_size - 1: + model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - ( + partition_size * (self.tp_size - 1) + ) + return model_metadata + def save_sharded_model( self, model: ModelWrapper, @@ -230,6 +274,14 @@ def save_sharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() + if self.dp_rank != 0 and self.sp_rank != 0: + return + + model_metadata = None + if not gather_dtensor: + # Build per-parameter shard metadata (offsets, lengths, global shape) for the distributed checkpoint. + model_metadata = self.create_model_metadata(model) + model = model.unwrap() if os.path.isfile(checkpoint): @@ -237,28 +289,30 @@ def save_sharded_model( return Path(checkpoint).mkdir(parents=True, exist_ok=True) - # Devices along the same dp_group share the same copies of model. - # So only let the device with dp_rank == 0 save the model.
- if self.dp_rank != 0: - return # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. - control_saving = self.tp_rank == 0 and self.sp_rank == 0 + control_saving = self.tp_rank == 0 if gather_dtensor else True if control_saving and use_async: if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = {} pinned_state_dicts = self.pinned_state_dicts[id(model)] else: pinned_state_dicts = None - state_dict_shard = HybridParallelCheckpointIO._model_sharder( - model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts + state_dict_shard = self._model_sharder( + model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts, gather_dtensor=gather_dtensor ) + weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - if self.pp_size == 1: + if self.pp_size == 1 or not gather_dtensor: # When pipeline is not used, save the model shards as in general checkpointIO + if not gather_dtensor: + dist_id = self.tp_size * self.pp_rank + self.tp_rank + weights_name = get_dist_files_name(weights_name=weights_name, dist_id=dist_id) + metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_safetensors) + if use_async: total_size, writers = async_save_state_dict_shards( sharded_state_dict=state_dict_shard, @@ -278,16 +332,22 @@ def save_sharded_model( is_master=control_saving, use_safetensors=use_safetensors, ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if not gather_dtensor: + # saving metadata for distributed checkpoint + for k, _ in model_metadata.items(): + model_metadata[k]["file"] = index_file.get_checkpoint_file(k) + save_metadata(model_metadata, metadata_file, total_size=total_size) + else: + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. @@ -298,9 +358,9 @@ def save_sharded_model( Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. 
- weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") - weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") - save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") + weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank:05d}-shard.bin") + weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank:05d}-shard.safetensors") + save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) if use_async: total_size, writers = async_save_state_dict_shards( @@ -374,6 +434,21 @@ def load_sharded_model( """ assert isinstance(model, ModelWrapper), "Please boost the model before loading!" model._force_wait_all_gather() + + if is_pytorch_model_meta_dist_file(checkpoint_index_file): + model_metadata = self.create_model_metadata(model) + checkpoint = checkpoint_index_file.parent + state_dict = load_dist_model( + model_metadata=model_metadata, + checkpoint=checkpoint, + ) + model = model.unwrap() + with RestoreDefaultStateDictBehavior(model): + load_state_dict_into_model( + model, state_dict, missing_keys=[], strict=False, load_sub_module=True + ) + return + model_before_wrapping = model # backup for model before wrapping model = model.unwrap() @@ -762,28 +837,44 @@ def save_unsharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before saving!" model._force_wait_all_gather() - model = model.unwrap() - if self.dp_rank != 0: + + if self.dp_rank != 0 and self.sp_rank != 0: return + if not gather_dtensor: + dist_id = self.tp_size * self.pp_rank + self.tp_rank + Path(checkpoint).mkdir(parents=True, exist_ok=True) + model_metadata = self.create_model_metadata(model) + checkpoint_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin") + if use_async: + checkpoint_file = checkpoint_file.replace(".bin", f".safetensors") + metadata_file = get_dist_meta_file_name(checkpoint=checkpoint, dist_id=dist_id, use_safetensors=use_async) + save_metadata(model_metadata=model_metadata, metadata_file=metadata_file, checkpoint_file=checkpoint_file) + else: + checkpoint_file = checkpoint + + model = model.unwrap() + # The logic of collecting parameter shards along tp degree # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer. - state_dict = model.state_dict() - if self.pp_size == 1: - # When pipeline is not used, let master rank directly save the collected state_dict. - if self.tp_rank == 0: - if use_async: - from colossalai.utils.safetensors import save + ctx = RestoreDefaultStateDictBehavior(model) if not gather_dtensor else nullcontext() + with ctx: + state_dict = model.state_dict() - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - for name, param in state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - state_dict[name] = self.pinned_state_dicts[id(model)][name] - writer = save(path=checkpoint, state_dict=state_dict) - self.async_writers.append(writer) - else: - save_state_dict(state_dict, checkpoint, use_safetensors) + if (self.pp_size == 1 and self.tp_rank == 0) or not gather_dtensor: + # When pipeline is not used, let master rank directly save the collected state_dict. 
+ if use_async: + from colossalai.utils.safetensors import save + + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for name, param in state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint_file, state_dict=state_dict) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint_file, use_safetensors) else: # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. state_dict_list = [None for _ in range(self.pp_size)] @@ -802,10 +893,10 @@ def save_unsharded_model( for name, param in complete_state_dict.items(): self.pinned_state_dicts[id(model)][name].copy_(param) complete_state_dict[name] = self.pinned_state_dicts[id(model)][name] - writer = save(path=checkpoint, state_dict=complete_state_dict) + writer = save(path=checkpoint_file, state_dict=complete_state_dict) self.async_writers.append(writer) else: - save_state_dict(complete_state_dict, checkpoint, use_safetensors) + save_state_dict(complete_state_dict, checkpoint_file, use_safetensors) def load_unsharded_model( self, @@ -829,6 +920,18 @@ def load_unsharded_model( assert isinstance(model, ModelWrapper), "Please boost the model before loading!" model._force_wait_all_gather() + + load_dtensor = False + if os.path.isdir(checkpoint): + for filename in os.listdir(checkpoint): + if is_pytorch_model_meta_dist_file(filename): + load_dtensor = True + break + + model_metadata = None # used for dist model + if load_dtensor: + model_metadata = self.create_model_metadata(model) + strict = False model_before_wrapping = model model = model.unwrap() @@ -836,10 +939,17 @@ def load_unsharded_model( # Load from checkpoint. Since the logic of breaking parameter shards along tp degree # has been implemented by _load_from_state_dict method of ParallelModule in Shardformer, # model.load_state_dict can be directly called. - state_dict = load_state_dict(checkpoint) + if load_dtensor: + state_dict = load_dist_model(model_metadata=model_metadata, checkpoint=checkpoint) + else: + state_dict = load_state_dict(checkpoint) + if not low_cpu_mem_mode: state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads) - model.load_state_dict(state_dict, strict=strict) + + ctx = RestoreDefaultStateDictBehavior(model) if load_dtensor else nullcontext() + with ctx: + model.load_state_dict(state_dict, strict=strict) # Update master params if mixed-precision training is enabled. model_before_wrapping.update_master_params() diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 50b6f1438961..524fc3b2190e 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -309,12 +309,13 @@ def async_save_state_dict_shards( checkpoint_file_path = os.path.join(checkpoint, shard_file) if state_preprocess: - state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".") + state_dict, metadata = _flatten_optim_state_dict(state_dict=shard, seperator=".") else: state_dict = shard + metadata = None # Only save on master rank. 
- writer = save(checkpoint_file_path, state_dict=state_dict) + writer = save(checkpoint_file_path, state_dict=state_dict, metadata=metadata) writers.append(writer) shard_filenames.append(shard_file) del shard @@ -371,9 +372,10 @@ def async_move_save_state_dict_shards( checkpoint_file_path = os.path.join(checkpoint, shard_file) if state_preprocess: - state_dict, _ = _flatten_optim_state_dict(state_dict=shard) + state_dict, metadata = _flatten_optim_state_dict(state_dict=shard) else: state_dict = shard + metadata = None if pinned_state_dict is not None: sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()} @@ -382,7 +384,7 @@ def async_move_save_state_dict_shards( returned_state_dict.update(sub_pinned_state_dict) # Only save on master rank. - writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict) + writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict, metadata) writers.append(writer) shard_filenames.append(shard_file) del shard @@ -854,13 +856,7 @@ def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]: # check if there is only one a file ending with .index.json in this directory index_files = list(checkpoint_path.glob("*.index.*json")) - # if we found a .index.json file, make sure there is only one - if len(index_files) > 0: - assert ( - len(index_files) == 1 - ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}" - - if len(index_files) == 1: + if len(index_files) >= 1: return True, index_files[0] else: return False, None @@ -943,8 +939,8 @@ def get_shard_filename(weights_name: str, idx: int): """ get shard file name """ - shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin") - shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors") + shard_file = weights_name.replace(".bin", f"-{idx:05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx:05d}.safetensors") return shard_file diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py new file mode 100644 index 000000000000..5aa9c4f1fadd --- /dev/null +++ b/tests/test_checkpoint_io/test_dist_checkpointio.py @@ -0,0 +1,140 @@ +import pytest +import torch +import torch.distributed as dist +from torch.optim import Adam +from utils import shared_tempdir + +import colossalai +from colossalai.booster import Booster +from colossalai.booster.plugin import HybridParallelPlugin +from colossalai.shardformer.layer.utils import Randomizer +from colossalai.tensor.d_tensor.api import clear_layout_converter +from colossalai.testing import ( + check_state_dict_equal, + clear_cache_before_run, + parameterize, + rerun_if_address_is_in_use, + spawn, +) +from tests.kit.model_zoo import model_zoo + +TEST_CONFIGS = [ + ( + {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + {"tp_size": 2, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1}, + ) +] + + +@parameterize("shard", [False, True]) +@parameterize("model_name", ["transformers_llama_for_causal_lm"]) +@parameterize("size_per_shard", [1]) +@parameterize("test_config", TEST_CONFIGS) +@parameterize("use_async", [False, True]) +@parameterize("low_cpu_mem_mode", [False, True]) +@clear_cache_before_run() +def exam_state_dict( + shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool, low_cpu_mem_mode: bool +): + (model_fn, data_gen_fn, 
output_transform_fn, loss_fn, _) = next( + iter(model_zoo.get_sub_registry(model_name).values()) + ) + criterion = loss_fn + test_config_0, test_config_1 = test_config + plugin_0 = HybridParallelPlugin(**test_config_0) + booster_0 = Booster(plugin=plugin_0) + + def _criterion(outputs, inputs): + outputs = output_transform_fn(outputs) + loss = criterion(outputs) + return loss + + def _preprocess_data(data): + if booster_0.plugin.stage_manager is not None: + for k, v in data.items(): + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: + new_shape = [1] * v.dim() + new_shape[0] = 4 + data[k] = v.to("cuda").repeat(*new_shape) + return iter([data]) + else: + return {k: v.cuda() for k, v in data.items()} + + model_0 = model_fn().cuda() + optimizer_0 = Adam(model_0.parameters(), lr=1e-3) + model_0, optimizer_0, criterion, _, _ = booster_0.boost(model_0, optimizer_0, criterion) + + data = data_gen_fn() + model_0.train() + if booster_0.plugin.stage_manager is not None: + booster_0.execute_pipeline(_preprocess_data(data), model_0, _criterion, optimizer_0, return_loss=True) + else: + output = model_0(**_preprocess_data(data)) + loss = criterion(output) + optimizer_0.backward(loss) + + optimizer_0.step() + optimizer_0.zero_grad() + with shared_tempdir() as tempdir: + model_ckpt_path_0 = f"{tempdir}/model_0" + + booster_0.save_model( + model_0, + model_ckpt_path_0, + shard=shard, + gather_dtensor=False, + size_per_shard=size_per_shard, + use_async=use_async, + ) + booster_0.checkpoint_io._sync_d2h() + booster_0.checkpoint_io._sync_io() + dist.barrier() + + plugin_1 = HybridParallelPlugin(**test_config_1) + booster_1 = Booster(plugin=plugin_1) + + model_1 = model_fn().cuda() + optimizer_1 = Adam(model_1.parameters(), lr=1e-3) + model_1, optimizer_1, criterion, _, _ = booster_1.boost(model_1, optimizer_1, criterion) + + booster_1.load_model(model_1, model_ckpt_path_0, low_cpu_mem_mode=low_cpu_mem_mode) + + model_ckpt_path_1 = f"{tempdir}/model_1" + booster_1.save_model( + model_1, + model_ckpt_path_1, + shard=shard, + gather_dtensor=False, + size_per_shard=size_per_shard, + use_async=use_async, + ) + booster_1.checkpoint_io._sync_d2h() + booster_1.checkpoint_io._sync_io() + dist.barrier() + + model_2 = model_fn().cuda() + optimizer_2 = Adam(model_2.parameters(), lr=1e-3) + model_2, optimizer_2, criterion, _, _ = booster_0.boost(model_2, optimizer_2, criterion) + + booster_0.load_model(model_2, model_ckpt_path_1, low_cpu_mem_mode=low_cpu_mem_mode) + check_state_dict_equal(model_0.unwrap().state_dict(), model_2.unwrap().state_dict()) + + dist.barrier() + Randomizer.reset_index() + clear_layout_converter() + + +def run_dist(rank, world_size, port): + colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") + exam_state_dict() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_hybrid_ckpIO(world_size): + spawn(run_dist, world_size) + + +if __name__ == "__main__": + test_hybrid_ckpIO(4)
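
Note (not part of the patch): the shard-reassembly helpers in distributed_checkpoint_utils.py can be exercised on their own, which makes the offsets/lengths bookkeeping easier to follow. The sketch below assumes a ColossalAI build that includes this patch; the 4x3 tensor, the two-rank row split, and the target offsets are invented purely for illustration.

import torch

from colossalai.checkpoint_io.distributed_checkpoint_utils import (
    assemble_tensor_from_shards_partial,
    find_covering_shards,
)

# A toy 4x3 "global" weight, split row-wise across two TP ranks, mimicking the
# per-rank entries that create_model_metadata()/save_metadata() would record.
full = torch.arange(12, dtype=torch.float32).reshape(4, 3)
shards = {
    0: {"offsets": [0, 0], "lengths": [2, 3], "global_shape": [4, 3], "weight": full[:2]},
    1: {"offsets": [2, 0], "lengths": [2, 3], "global_shape": [4, 3], "weight": full[2:]},
}

# A hypothetical new layout that needs rows 1..2 pulls from both saved shards:
# find_covering_shards() keeps the shards that overlap the target range, and the
# assembler copies each intersecting region into the right slice of the target.
target_offsets, target_lengths = [1, 0], [2, 3]
covering = find_covering_shards(shards, target_offsets, target_lengths)
rebuilt = assemble_tensor_from_shards_partial(covering, target_offsets, target_lengths, dtype=torch.float32)
assert torch.equal(rebuilt, full[1:3])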
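
Note (not part of the patch): end to end, the new path mirrors tests/test_checkpoint_io/test_dist_checkpointio.py. With gather_dtensor=False, each (pp_rank, tp_rank) pair, identified by dist_id = tp_size * pp_rank + tp_rank, writes its own pytorch_model-dist-* weight file plus a pytorch_model-meta-dist-* metadata file, and a booster with a different TP/PP layout rebuilds its local shards from that metadata on load. The sketch below is illustrative only: it assumes a 2-GPU torchrun launch, and build_model() is a placeholder for any model supported by the plugin.

import colossalai
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # relies on torchrun having set up the process group

# Save: every (pp_rank, tp_rank) pair writes its own shard and metadata files.
booster_a = Booster(plugin=HybridParallelPlugin(tp_size=1, pp_size=2, num_microbatches=4))
model_a = build_model()  # placeholder model constructor
model_a, optimizer_a, *_ = booster_a.boost(model_a, Adam(model_a.parameters(), lr=1e-3))
booster_a.save_model(model_a, "./ckpt_dist", shard=True, gather_dtensor=False)

# Load: a booster with a different layout detects the pytorch_model-meta-dist-*
# files and reassembles each rank's local shards from the overlapping saved shards.
booster_b = Booster(plugin=HybridParallelPlugin(tp_size=2, pp_size=1))
model_b = build_model()
model_b, optimizer_b, *_ = booster_b.boost(model_b, Adam(model_b.parameters(), lr=1e-3))
booster_b.load_model(model_b, "./ckpt_dist")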