diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fa8dc29e..545411fa 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -184,7 +184,6 @@ jobs: cluster: # H100 clusters - ai2/jupiter-cirrascale-2 - - ai2/pluto-cirrascale - ai2/augusta-google-1 # A100 clusters - ai2/saturn-cirrascale diff --git a/CHANGELOG.md b/CHANGELOG.md index 8276edc6..c85f4245 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Changed storage of shared shard state in sharded checkpoints from smallest shard to lowest rank (normally 0). +- Changed how the trainer handles loading a checkpoint when `load_path` is provided. Now `load_path` is only used if no checkpoint is found in the `save_folder`. ### Fixed diff --git a/src/olmo_core/train/callbacks/checkpointer.py b/src/olmo_core/train/callbacks/checkpointer.py index b6e6fbd1..73273a36 100644 --- a/src/olmo_core/train/callbacks/checkpointer.py +++ b/src/olmo_core/train/callbacks/checkpointer.py @@ -216,7 +216,10 @@ def pre_train(self): path for _, path in sorted(ephemeral_checkpoints, key=lambda x: x[0]) ] for path in self._ephemeral_checkpoints: - log.info(f"Collected existing ephemeral checkpoint at '{path}'") + log.info( + f"Found existing ephemeral checkpoint at '{path}' which will " + "be removed when the next checkpoint is saved" + ) def post_train_batch(self): self._await_last_checkpoint(blocking=False) diff --git a/src/olmo_core/train/common.py b/src/olmo_core/train/common.py index b988899f..5f01c6e3 100644 --- a/src/olmo_core/train/common.py +++ b/src/olmo_core/train/common.py @@ -75,17 +75,19 @@ class LoadStrategy(StrEnum): if_available = "if_available" """ - Only load from the load path if a checkpoint exists there. + The trainer will attempt to load a checkpoint from the save folder or load path (in that order) + but will train from scratch if no checkpoint is found. 
""" always = "always" """ - Always try loading from the load path. + The trainer will attempt to load a checkpoint from the save folder or load path (in that order) + and raise an error if no checkpoint is found. """ never = "never" """ - Never load from the load path. + The trainer will never load a checkpoint even if one exists in the save folder or load path. """ diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 953dc4d7..0c8d17aa 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -173,8 +173,10 @@ class Trainer: load_path: Optional[PathOrStr] = None """ - Where to load a checkpoint from prior to training. - Defaults to ``save_folder``. + An alternative location to load a checkpoint from if no checkpoint is found in the current :data:`save_folder`. + + This can be set to a checkpoint path or the path to a folder of checkpoints such as the :data:`save_folder` + from a different run. """ load_strategy: LoadStrategy = LoadStrategy.if_available @@ -538,7 +540,7 @@ def check_if_canceled(self): def fit(self): """ - Fit the model, potentially loading a checkpoint before hand depending on the + Fit the model, potentially loading a checkpoint first depending on the :data:`load_strategy`. """ self._canceled = False @@ -546,12 +548,30 @@ def fit(self): self._canceling_rank = None # Maybe load a checkpoint. - if not self.checkpoint_loaded: - load_path = self.load_path if self.load_path is not None else self.save_folder - if self.load_strategy == LoadStrategy.always: - self.load_checkpoint(load_path) - elif self.load_strategy == LoadStrategy.if_available: - self.maybe_load_checkpoint(load_path) + if not self.checkpoint_loaded and self.load_strategy != LoadStrategy.never: + # Try loading from the save folder first. + self.maybe_load_checkpoint(self.save_folder) + + # Then fallback to the load path, if provided. 
+ if self.load_path is not None: + if not self.checkpoint_loaded: + self.maybe_load_checkpoint(self.load_path) + else: + log.warning( + f"Ignoring load path ('{self.load_path}') since checkpoint was found in save folder" + ) + + if not self.checkpoint_loaded: + if self.load_strategy == LoadStrategy.always: + raise FileNotFoundError( + f"No checkpoint found in save folder ('{self.save_folder}') or " + f"load path ('{self.load_path}')" + ) + else: + log.warning( + f"No checkpoint found in save folder ('{self.save_folder}') or " + f"load path ('{self.load_path}'), will train from scratch..." + ) log.info(f"Training for {self.max_steps:,d} steps") @@ -709,9 +729,10 @@ def maybe_load_checkpoint( load_optimizer_state=load_optimizer_state, load_trainer_state=load_trainer_state, ) + assert self.checkpoint_loaded + return True else: - log.warning(f"No checkpoint found in '{dir}', will train from scratch...") - return should_load + return False def save_checkpoint(self) -> PathOrStr: """