diff --git a/trainer/io.py b/trainer/io.py index 6e08aea..cca0a45 100644 --- a/trainer/io.py +++ b/trainer/io.py @@ -180,6 +180,20 @@ def save_best_model( save_func=None, **kwargs, ): + """ + Saves the best model based on the training losses + + Compares the best loss to the current loss. If current loss is better than the previous loss, current is set to best loss + + When starting from a saved checkpoint, the losses are stored in a dict like the following one + {train_loss: value, val_loss: value} + + Needed to handle this when the model training is restarted from a checkpoint + """ + + if isinstance(best_loss, dict): + best_loss = best_loss["train_loss"] + if current_loss < best_loss: best_model_name = f"best_model_{current_step}.pth" checkpoint_path = os.path.join(out_path, best_model_name) @@ -208,6 +222,7 @@ def save_best_model( shortcut_path = os.path.join(out_path, shortcut_name) fs.copy(checkpoint_path, shortcut_path) best_loss = current_loss + return best_loss