Skip to content

Commit

Permalink
better error handling
Browse files (browse the repository at this point in the history)
  • Loading branch information
epwalsh committed Jan 24, 2025
1 parent 740aa65 commit 2adddfd
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/olmo_core/distributed/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@

from olmo_core.aliases import PathOrStr
from olmo_core.config import StrEnum
from olmo_core.io import clear_directory, dir_is_empty, is_url, normalize_path
from olmo_core.io import (
clear_directory,
dir_is_empty,
file_exists,
is_url,
join_path,
normalize_path,
)
from olmo_core.utils import gc_cuda, get_element_size, wait_for

from ..utils import barrier, get_fs_local_rank, is_distributed
Expand Down Expand Up @@ -445,10 +452,7 @@ def unshard_checkpoint(
target_dir.mkdir(exist_ok=True, parents=True)

ext = "pt" if not use_safetensors else "safetensors"
try:
metadata = get_checkpoint_metadata(dir)
except FileNotFoundError as exc:
raise FileNotFoundError(f"'{dir}' does not appear to be a model/optim checkpoint") from exc
metadata = get_checkpoint_metadata(dir)

def save(state_dict: Dict[str, Any], path: Path):
if path.is_file() and not save_overwrite:
Expand Down Expand Up @@ -559,6 +563,8 @@ def load_keys(
raise RuntimeError("'load_keys' cannot be called in a distributed context")

dir = normalize_path(dir)
# validate checkpoint.
get_checkpoint_metadata(dir)

keys = list(keys)
state_dict = _load_unsharded_keys(dir, keys, pre_download=pre_download, work_dir=work_dir)
Expand All @@ -573,7 +579,13 @@ def get_checkpoint_metadata(dir: PathOrStr) -> Metadata:
:param dir: The path/URL to the checkpoint.
"""
dir = normalize_path(dir)
storage_reader = RemoteFileSystemReader(dir)
try:
storage_reader = RemoteFileSystemReader(dir)
except FileNotFoundError as exc:
msg = f"'{dir}' does not appear to contain a state dict checkpoint."
if file_exists((suggested_path := join_path(dir, "model_and_optim/.metadata"))):
msg += f" Did you mean to use '{suggested_path}'?"
raise FileNotFoundError(msg) from exc
return storage_reader.read_metadata()


Expand Down

0 comments on commit 2adddfd

Please sign in to comment.