From 35f4d0730c3227fe1ab936bd702c021396e205f0 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Fri, 5 Jul 2024 22:44:59 -0700
Subject: [PATCH] Added checkpointing functions

---
 train.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/train.py b/train.py
index 44ac730..0a0271f 100644
--- a/train.py
+++ b/train.py
@@ -125,6 +125,70 @@ def preprocess_data(data, tokenizer, max_length, overlap):
     return sequences
 
 
+def save_checkpoint(
+    model, optimizer, scheduler, epoch, best_val_loss, checkpoint_dir="."
+):
+    """Save the training state to ``checkpoint_epoch_{epoch}.pt``.
+
+    Args:
+        model: Its ``state_dict()`` is stored under ``model_state_dict``.
+        optimizer: Its ``state_dict()`` is stored under
+            ``optimizer_state_dict``.
+        scheduler: Its ``state_dict()`` is stored under
+            ``scheduler_state_dict``.
+        epoch: Stored under ``epoch`` and used in the file name.
+        best_val_loss: Stored unchanged under ``best_val_loss``.
+        checkpoint_dir: Directory that receives the file; created when
+            missing. Defaults to the current working directory.
+    """
+    import os
+
+    checkpoint = {
+        "epoch": epoch,
+        "model_state_dict": model.state_dict(),
+        "optimizer_state_dict": optimizer.state_dict(),
+        "scheduler_state_dict": scheduler.state_dict(),
+        "best_val_loss": best_val_loss,
+    }
+    os.makedirs(checkpoint_dir, exist_ok=True)
+    path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pt")
+    # Write to a temporary sibling and rename it into place, so an
+    # interrupted save cannot leave a truncated file under the final name.
+    tmp_path = f"{path}.tmp"
+    try:
+        torch.save(checkpoint, tmp_path)
+        os.replace(tmp_path, path)
+    finally:
+        # Only present here when the save or the rename failed.
+        if os.path.exists(tmp_path):
+            os.remove(tmp_path)
+    logger.info(f"Checkpoint saved for epoch {epoch}")
+
+
+def load_checkpoint(model, optimizer, scheduler, filename, map_location=None):
+    """Restore training state from ``filename``.
+
+    Args:
+        model: Receives ``model_state_dict`` via ``load_state_dict``.
+        optimizer: Receives ``optimizer_state_dict`` via ``load_state_dict``.
+        scheduler: Receives ``scheduler_state_dict`` via ``load_state_dict``.
+        filename: Checkpoint file, passed to ``torch.load``.
+        map_location: Forwarded to ``torch.load``; pass e.g. ``"cpu"`` to
+            open a checkpoint on a machine without the device it was saved
+            from. ``None`` keeps ``torch.load``'s default placement.
+
+    Returns:
+        The stored ``(epoch, best_val_loss)`` pair.
+    """
+    # NOTE: torch.load is pickle-based; only open checkpoints you trust.
+    checkpoint = torch.load(filename, map_location=map_location)
+    model.load_state_dict(checkpoint["model_state_dict"])
+    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+    logger.info(f"Loaded checkpoint from {filename}")
+    return checkpoint["epoch"], checkpoint["best_val_loss"]
+
+
 def compute_perplexity(loss):
     return torch.exp(loss)
 