From e8659ea9814758ed7acc3c85f7d0aa9310082e63 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 20 Jan 2025 03:26:59 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/checkpoint_io/__init__.py           |  1 -
 .../checkpoint_io/general_checkpoint_io.py     |  2 +-
 .../hybrid_parallel_checkpoint_io.py           | 60 ++++++++++++++-----
 .../test_dist_checkpointio.py                  | 14 ++++-
 4 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/colossalai/checkpoint_io/__init__.py b/colossalai/checkpoint_io/__init__.py
index 5d8b65e3b384..ef37534fe01a 100644
--- a/colossalai/checkpoint_io/__init__.py
+++ b/colossalai/checkpoint_io/__init__.py
@@ -1,7 +1,6 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
-
 from .index_file import CheckpointIndexFile
 from .moe_checkpoint import MoECheckpointIO
 
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index d5ed5b848de3..c38958ee31b9 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -309,4 +309,4 @@ def load_sharded_model(
         )
 
     def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
-        raise NotImplementedError
\ No newline at end of file
+        raise NotImplementedError
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 44c119eef6d5..dd1dd4258d9e 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -24,6 +24,13 @@
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
 
+from .distributed_checkpoint_utils import (
+    create_model_metadata,
+    is_pytorch_model_meta_dist_file,
+    load_dist_model,
+    save_dist_sharded_model,
+    save_dist_unshard_model,
+)
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -47,14 +54,6 @@
     sharded_optimizer_loading_epilogue,
 )
 
-from .distributed_checkpoint_utils import (
-    save_dist_sharded_model,
-    save_dist_unshard_model,
-    load_dist_model,
-    is_pytorch_model_meta_dist_file,
-    create_model_metadata
-)
-
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
@@ -244,9 +243,19 @@ def save_sharded_model(
             return
         dist_id = self.tp_size * self.pp_rank + self.tp_rank
         model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-        save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+        save_dist_sharded_model(
+            model=model,
+            model_metadata=model_metadata,
+            checkpoint=checkpoint,
+            prefix=prefix,
+            size_per_shard=size_per_shard,
+            use_safetensors=use_safetensors,
+            use_async=use_async,
+            dist_id=dist_id,
+            pinned_state_dicts=self.pinned_state_dicts,
+        )
         return
-        
+
         model = model.unwrap()
 
         if os.path.isfile(checkpoint):
@@ -394,9 +403,15 @@ def load_sharded_model(
 
         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-            load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+            load_dist_model(
+                model=model,
+                model_metadata=model_metadata,
+                checkpoint=checkpoint_index_file,
+                low_cpu_mem_mode=low_cpu_mem_mode,
+                num_threads=num_threads,
+            )
             return
-        
+
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()
 
@@ -792,9 +807,17 @@ def save_unsharded_model(
         if self.dp_rank != 0 and self.sp_rank != 0:
             return
         dist_id = self.tp_size * self.pp_rank + self.tp_rank
-        save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+        save_dist_unshard_model(
+            model=model,
+            model_metadata=model_metadata,
+            checkpoint=checkpoint,
+            use_safetensors=use_safetensors,
+            use_async=use_async,
+            dist_id=dist_id,
+            pinned_state_dicts=self.pinned_state_dicts,
+        )
         return
-        
+
         model = model.unwrap()
         if self.dp_rank != 0:
             return
@@ -867,7 +890,13 @@ def load_unsharded_model(
         for filename in os.listdir(checkpoint):
             if is_pytorch_model_meta_dist_file(filename):
                 model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-                load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+                load_dist_model(
+                    model=model,
+                    model_metadata=model_metadata,
+                    checkpoint=checkpoint,
+                    low_cpu_mem_mode=low_cpu_mem_mode,
+                    num_threads=num_threads,
+                )
                 return
 
         strict = False
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state(
                         dist.all_gather(gather_tensor, v, group=dp_group)
                         v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
 
-
                     # Then gather TP shards.
                     partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
                     if partition_dim is not None:
diff --git a/tests/test_checkpoint_io/test_dist_checkpointio.py b/tests/test_checkpoint_io/test_dist_checkpointio.py
index 08354c214a62..850a10c17ce6 100644
--- a/tests/test_checkpoint_io/test_dist_checkpointio.py
+++ b/tests/test_checkpoint_io/test_dist_checkpointio.py
@@ -79,7 +79,12 @@ def _preprocess_data(data):
         model_ckpt_path_0 = f"{tempdir}/model_0"
 
         booster_0.save_model(
-            model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+            model_0,
+            model_ckpt_path_0,
+            shard=shard,
+            gather_dtensor=True,
+            size_per_shard=size_per_shard,
+            use_async=use_async,
         )
         booster_0.checkpoint_io._sync_d2h()
         booster_0.checkpoint_io._sync_io()
@@ -96,7 +101,12 @@ def _preprocess_data(data):
         model_ckpt_path_1 = f"{tempdir}/model_1"
 
         booster_1.save_model(
-            model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+            model_1,
+            model_ckpt_path_1,
+            shard=shard,
+            gather_dtensor=True,
+            size_per_shard=size_per_shard,
+            use_async=use_async,
         )
         booster_1.checkpoint_io._sync_d2h()
         booster_1.checkpoint_io._sync_io()