diff --git a/train.py b/train.py index 44ac730..c799ea6 100644 --- a/train.py +++ b/train.py @@ -139,6 +139,7 @@ def train( tokenizer, epochs=10, gradient_accumulation_steps=4, + resume_from_checkpoint=None, ): """ Train the LongRoPE model. @@ -153,6 +154,7 @@ def train( tokenizer: Tokenizer for encoding/decoding text. epochs (int): Number of training epochs. gradient_accumulation_steps (int): Number of steps to accumulate gradients. + resume_from_checkpoint (str): Path to a checkpoint to resume training from. Returns: None @@ -164,6 +166,7 @@ def train( best_val_loss = float("inf") patience = 0 max_patience = 3 + start_epoch = 0 for epoch in range(epochs): model.train()