
Commit

Remove duplicates
flybird11111 committed Jan 21, 2025
1 parent f388bbe commit 6a8a917
Showing 3 changed files with 138 additions and 351 deletions.
285 changes: 30 additions & 255 deletions colossalai/checkpoint_io/distributed_checkpoint_utils.py
@@ -11,6 +11,8 @@

from colossalai.interface import ModelWrapper
from colossalai.utils import get_non_persistent_buffers_set
from colossalai.shardformer.layer.parallel_module import ParallelModule
from contextlib import contextmanager

from .index_file import CheckpointIndexFile
from .utils import (
@@ -32,67 +34,32 @@
MODEL_META_PREFIX = "pytorch_model-meta-dist-"
MODEL_WEIGHT_PREFIX = "pytorch_model-dist-"
SHARD_META_SUFFIX = ".index.json"
UNSHARD_META_SUFFIX = ".json"


def dist_model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
destination = dict()
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
destination[prefix + name] = param
# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
destination[prefix + name] = buffer

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
destination[extra_state_key] = extra_state
return destination


def load_state_dict_into_dist_model(
model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
):
destination = dict()
    # Load parameters.
for name, param in model.named_parameters():
if param is None:
continue
with torch.no_grad():
param.copy_(state_dict[prefix + name])
    # Load buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
with torch.no_grad():
buf.copy_(state_dict[prefix + name])

    # Load extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
with torch.no_grad():
extra_state.copy_(state_dict[extra_state_key])
return destination
@contextmanager
def RestoreDefaultStateDictBehavior(model):
original_methods = {}
for name, module in model.named_modules():
if isinstance(module, ParallelModule):
original_methods[module] = (module._save_to_state_dict, module._load_from_state_dict)
module._save_to_state_dict = nn.Module._save_to_state_dict.__get__(module, nn.Module)
module._load_from_state_dict = nn.Module._load_from_state_dict.__get__(module, nn.Module)
try:
yield model
finally:
for module, original_method in original_methods.items():
module._save_to_state_dict, module._load_from_state_dict = original_method
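The context manager above temporarily swaps every ParallelModule back to the stock nn.Module state-dict hooks, so state_dict() and load_state_dict() see each rank's local tensors instead of the tensor-parallel gathering logic that ParallelModule normally applies. A minimal usage sketch (the model variable and the save path are illustrative assumptions, not part of this commit):

    # model: an unwrapped nn.Module that contains ParallelModule submodules
    with RestoreDefaultStateDictBehavior(model):
        local_state_dict = model.state_dict()  # plain per-rank tensors, no cross-rank gathering
    torch.save(local_state_dict, "local_shard.bin")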



def create_model_metadata(
model: nn.Module,
model: ModelWrapper,
prefix: str = "",
tp_size=None,
tp_rank=None,
tp_size: int = None,
tp_rank: int = None,
zero_size: int = None,
zero_rank: int = None,
):
param_origin_shape = model.param_origin_shape
model = model.unwrap()
@@ -105,7 +72,7 @@ def create_model_metadata(
tp_partition_dim = search_tp_partition_dim(
current_shape=param.shape, original_shape=original_shape, tp_size=tp_size
)
model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
model_metadata[prefix + name]["offsets"] = [0] * len(original_shape)
model_metadata[prefix + name]["lengths"] = list(param.shape)
model_metadata[prefix + name]["global_shape"] = list(original_shape)
if tp_partition_dim is not None:
@@ -257,119 +224,9 @@ def is_pytorch_model_meta_dist_file(checkpoint_index_file):
return False
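For reference, a minimal sketch of the per-parameter entry that create_model_metadata builds above; the parameter name and shapes are invented, and only the offsets / lengths / global_shape keys are taken from the diff:

    # Hypothetical entry for a weight split along dim 0 across tp_size=2 (tp_rank=0 shown):
    model_metadata = {
        "encoder.layers.0.linear.weight": {
            "offsets": [0, 0],             # start index of this rank's slice in each dimension
            "lengths": [2048, 4096],       # local (sharded) shape
            "global_shape": [4096, 4096],  # original, unsharded shape
        }
    }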


def dist_model_sharder(
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
    # An internal method that breaks the model state_dict into shards within a limited size.

state_dict_sharder = StateDictSharder(size_per_shard)

# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(param)
param = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, param)
if block is not None:
yield block, block_size

# Save buffers.
non_persist_buffers_set = get_non_persistent_buffers_set(model)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size

# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if (
getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size

# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
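A sketch of how this shard iterator is typically consumed (the consuming loop is an assumption; only the yielded (block, block_size) pairs come from the code above):

    # Each block is a dict of tensors kept under the size_per_shard budget;
    # the final yield flushes whatever remains in the sharder.
    for block, block_size in dist_model_sharder(model, size_per_shard=1024):
        if block:
            print(f"shard with {len(block)} tensors, reported size {block_size}")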


def save_dist_unshard_model(
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
use_safetensors: bool,
use_async: bool = False,
dist_id=0,
pinned_state_dicts=None,
):
"""
Save model state dict to a single file with given checkpointing path.
Args:
        model (ModelWrapper): Model on local device to be saved.
        model_metadata (Dict): Per-parameter metadata (offsets, lengths, global shape) for this rank.
        checkpoint (str): Checkpointing path, which should be a file path. Can be absolute or relative.
        use_safetensors (bool, optional): Whether to use safetensors. Defaults to False.
        use_async (bool, optional): Whether to save the model state_dict asynchronously. Defaults to False.
        dist_id (int, optional): Id used to name this rank's weight and metadata files. Defaults to 0.
        pinned_state_dicts (Dict, optional): Cache of pinned CPU tensors used when use_async is True.
"""

model = model.unwrap()

# The logic of collecting parameter shards along tp degree
# has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
state_dict = dist_model_state_dict(model)

Path(checkpoint).mkdir(parents=True, exist_ok=True)
file_name = f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin"
if use_async:
file_name = file_name.replace(".bin", ".safetensors")
checkpoint_file = os.path.join(checkpoint, file_name)
metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
save_metadata(model_metadata, metadata_file, file_name)

if use_async:
from colossalai.utils.safetensors import save

if id(model) not in pinned_state_dicts:
pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint_file, state_dict=state_dict)
return writer
else:
save_state_dict(state_dict, checkpoint_file, use_safetensors)
return None
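A small sketch of the on-disk layout this function produces, using hypothetical values for checkpoint and dist_id:

    import os
    checkpoint, dist_id = "ckpt", 3  # illustrative values
    weight_file = os.path.join(checkpoint, f"{MODEL_WEIGHT_PREFIX}{dist_id:05d}.bin")
    meta_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}.json")
    # -> ckpt/pytorch_model-dist-00003.bin and ckpt/pytorch_model-meta-dist-00003.json
    # (the weight file ends in .safetensors instead of .bin when use_async=True)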


def load_dist_model(
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
low_cpu_mem_mode: bool = True,
num_threads: int = 1,
):
"""
Load model from a single file with the given path of checkpoint.
Expand All @@ -380,10 +237,6 @@ def load_dist_model(
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since not all params in checkpoint are needed for each device when pipeline is enabled.
"""

model_before_wrapping = model
model = model.unwrap()

metadata_loaded = load_metadata(checkpoint)

load_files = {}
@@ -420,92 +273,14 @@
)
state_dict[key] = state

if not low_cpu_mem_mode:
state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)

load_state_dict_into_dist_model(model=model, state_dict=state_dict)

# Update master params if mixed-precision training is enabled.
model_before_wrapping.update_master_params()

return state_dict
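After this change, load_dist_model returns the assembled state dict instead of copying it into the model in place. A hedged sketch of how a caller might apply it, pairing it with the RestoreDefaultStateDictBehavior context manager introduced earlier in this diff (the pairing and variable names are assumptions; the actual caller lives outside this file):

    state_dict = load_dist_model(model=model, model_metadata=model_metadata, checkpoint="ckpt")
    with RestoreDefaultStateDictBehavior(model.unwrap()):
        model.unwrap().load_state_dict(state_dict, strict=False)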

def save_dist_sharded_model(
model: ModelWrapper,
model_metadata: Dict,
checkpoint: str,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
dist_id: int = 0,
pinned_state_dicts=None,
) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
      Shard filenames carry the dist id, e.g. "pytorch_model-dist-000XX-shard.bin" (or ".safetensors" when use_async is True).
Args:
        model (ModelWrapper): Model on local device to be saved.
        model_metadata (Dict): Per-parameter metadata (offsets, lengths, global shape) for this rank.
        checkpoint (str): Checkpointing path, which should be a directory path.
        prefix (str, optional): Prefix of the files to save. Defaults to None.
        size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
        use_safetensors (bool, optional): Whether to use safetensors. Defaults to False.
        use_async (bool, optional): Whether to save the model state_dict asynchronously. Defaults to False.
        dist_id (int, optional): Id used to name this rank's shard and metadata files. Defaults to 0.
        pinned_state_dicts (Dict, optional): Cache of pinned CPU tensors used when use_async is True.
"""

model = model.unwrap()

if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copies of model.
# So only let the device with dp_rank == 0 and sp_rank == 0 save the model.

if use_async:
if id(model) not in pinned_state_dicts:
pinned_state_dicts[id(model)] = {}
pinned_state_dicts = pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = dist_model_sharder(model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts)
weights_name, _ = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)

# Manage filenames of sharded weights and index file for each pipeline stage.
    def get_dist_files_name(weights_name, dist_id):
        weights_name = weights_name.replace(".bin", f"-dist-{dist_id:05d}-shard.bin")
        weights_name = weights_name.replace(".safetensors", f"-dist-{dist_id:05d}-shard.safetensors")
        return weights_name

    weights_name = get_dist_files_name(weights_name, dist_id)
    metadata_file = os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
async_writers = []
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=True,
state_preprocess=False,
)
async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=True,
use_safetensors=use_safetensors,
use_pp_format=True,
)
    for k, _ in model_metadata.items():
        model_metadata[k]["file"] = index_file.get_checkpoint_file(k)

save_metadata(model_metadata, metadata_file, total_size=total_size)
return async_writers
def get_dist_meta_file_name(checkpoint, dist_id, use_safetensors):
if use_safetensors:
return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{SHARD_META_SUFFIX}")
return os.path.join(checkpoint, f"{MODEL_META_PREFIX}{dist_id:05d}{UNSHARD_META_SUFFIX}")
