Commit
improve docs and logging
epwalsh committed Jan 24, 2025
1 parent bcce34d commit 740aa65
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -36,6 +36,7 @@
import torch.distributed.checkpoint as dist_cp
import torch.distributed.checkpoint.state_dict as dist_cp_sd
import torch.nn as nn
+from rich.progress import track
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.metadata import Metadata, TensorStorageMetadata

@@ -323,7 +324,8 @@ class UnshardStrategyType(StrEnum):

chunks = "chunks"
"""
Save multiple tensors to a file.
Like :data:`one_file_per_tensor` but multiple tensors and objects may be grouped into the same file
up to the limit defined by :data:`UnshardStrategy.chunk_size_bytes`.
"""


@@ -340,7 +342,7 @@ class UnshardStrategy:

    chunk_size_bytes: Optional[int] = None
    """
-    The approximate max chunk size, in bytes, for the :data:`UnshardStrategyType.chunks` strategy.
+    The approximate max chunk size (per file size), in bytes, for the :data:`UnshardStrategyType.chunks` strategy.
    """

    def __post_init__(self):
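For context, using the chunks strategy might look like the sketch below. Only `chunk_size_bytes` is visible in this diff, so the `name` field is an assumption and the real constructor may differ:

    # A sketch, not verbatim library usage. Assumes UnshardStrategy is a dataclass
    # with a `name: UnshardStrategyType` field alongside `chunk_size_bytes` (assumption).
    from olmo_core.distributed.checkpoint import UnshardStrategy, UnshardStrategyType

    # Group unsharded tensors/objects into files of roughly 1 GiB each.
    strategy = UnshardStrategy(name=UnshardStrategyType.chunks, chunk_size_bytes=1024**3)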
Expand Down Expand Up @@ -381,6 +383,7 @@ def unshard_checkpoint(
    unshard_strategy: Optional[UnshardStrategy] = None,
    pre_download: bool = False,
    work_dir: Optional[PathOrStr] = None,
+    quiet: bool = False,
) -> Tuple[Path, Optional[Path]]:
    """
    Convert a checkpoint saved via :func:`save_model_and_optim_state()` into unsharded
@@ -395,6 +398,9 @@
    .. warning::
        This should only be called in a non-distributed context. Otherwise a :class:`RuntimeError` is raised.

+    .. seealso::
+        :func:`load_keys()` if you only need to load and unshard certain keys in the checkpoint.
+
    :param dir: The path/URL to the original checkpoint created via :func:`save_model_and_optim_state()`.
    :param target_dir: The directory to save the unsharded model/optimizer checkpoint files to.
        This must be a local directory. URLs are not supported.
@@ -406,6 +412,7 @@
    :param unshard_strategy: The strategy to use. Defaults to :meth:`UnshardStrategy.one_file`.
    :param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
    :param work_dir: A working directory for caching files/directories.
+    :param quiet: Do not show progress messages.

    :return: The path to the unsharded model checkpoint and the path to the unsharded
        optimizer checkpoint if ``optim=True``. These paths may represent files or directories
@@ -438,7 +445,10 @@
    target_dir.mkdir(exist_ok=True, parents=True)

    ext = "pt" if not use_safetensors else "safetensors"
-    metadata = get_checkpoint_metadata(dir)
+    try:
+        metadata = get_checkpoint_metadata(dir)
+    except FileNotFoundError as exc:
+        raise FileNotFoundError(f"'{dir}' does not appear to be a model/optim checkpoint") from exc

    def save(state_dict: Dict[str, Any], path: Path):
        if path.is_file() and not save_overwrite:
@@ -507,13 +517,17 @@ def unshard_chunk(prefix: str, path: Path, keys: List[str]):
        gc_cuda()

    model_path, model_chunks = get_chunks("model")
-    for chunk_path, chunk_keys in model_chunks:
+    for chunk_path, chunk_keys in track(
+        model_chunks, description="Unsharding model chunks...", disable=quiet
+    ):
        unshard_chunk("model", chunk_path, chunk_keys)

    optim_path: Optional[Path] = None
    if optim:
        optim_path, optim_chunks = get_chunks("optim")
-        for chunk_path, chunk_keys in optim_chunks:
+        for chunk_path, chunk_keys in track(
+            optim_chunks, description="Unsharding optim chunks...", disable=quiet
+        ):
            unshard_chunk("optim", chunk_path, chunk_keys)

    return model_path, optim_path
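`rich.progress.track` wraps any iterable in a progress bar and accepts a `disable` flag, so threading the new `quiet` argument through is all that's needed to silence output. A standalone illustration:

    import time

    from rich.progress import track

    quiet = False  # set True to suppress the bar, mirroring the new flag above

    for item in track(range(5), description="Working...", disable=quiet):
        time.sleep(0.1)  # stand-in for unsharding a chunk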
@@ -538,7 +552,8 @@ def load_keys(
    :param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
    :param work_dir: A working directory for caching files/directories.
-    :returns: The (unsharded) objects from the checkpoint corresponding to the given keys.
+    :returns: The (unsharded) objects from the checkpoint corresponding to the given keys, in the
+        same order as the keys.
    """
    if is_distributed():
        raise RuntimeError("'load_keys' cannot be called in a distributed context")
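For illustration, fetching specific tensors might look like the sketch below; the checkpoint path and key names are hypothetical:

    # A sketch; key names depend entirely on the saved state dict (hypothetical here).
    from olmo_core.distributed.checkpoint import load_keys

    keys = ["model.embeddings.weight", "model.blocks.0.attention.w_q.weight"]
    # Results come back in the same order as `keys`, per the docs above.
    embeddings, w_q = load_keys("s3://my-bucket/checkpoints/step1000", keys)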