Skip to content

Commit

Permalink
Adding a validation step in the fine-tuning process
Browse files (browse the repository at this point in the history)
  • Loading branch information
jshuadvd committed Jul 4, 2024
1 parent 81ffdbb commit 591756a
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/main.py
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
Expand Down Expand Up @@ -610,10 +611,10 @@ def fine_tune(model, train_data, val_data, target_length, lambda_factors, n_hat,
"""
model.lambda_factors[f"{target_length // 1000}k"] = lambda_factors
model.n_hat[f"{target_length // 1000}k"] = n_hat
optimizer = optim.Adam(model.parameters(), lr=1e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

best_val_perplexity = float("inf")
best_model = None
best_model_state = None

for step in range(steps):
# Training
Expand All @@ -638,9 +639,12 @@ def fine_tune(model, train_data, val_data, target_length, lambda_factors, n_hat,
print(f"Step {step}, Validation Perplexity: {val_perplexity}")
if val_perplexity < best_val_perplexity:
best_val_perplexity = val_perplexity
best_model = copy.deepcopy(model)
best_model_state = model.state_dict()

return best_model if best_model is not None else model
if best_model_state is not None:
model.load_state_dict(best_model_state)

return model


def evaluate_perplexity(model, data, target_length):
Expand Down

0 comments on commit 591756a

Please sign in to comment.