training.py
"""
Author: Shilpa Kancharla
Last updated: April 1, 2022
"""
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def training_loop(net, train_loader, val_loader, optimizer, loss_fn):
    num_epochs = 1
    loss_l1 = nn.L1Loss()
    # Store loss history for future plotting
    loss_history, test_loss_history = [], []
    loss_history_mae, test_loss_history_mae = [], []
    history = dict()
    history_val = dict()
    counter = 0
    val_counter = 0
    for epoch in range(num_epochs):
        net.train()  # Put the network in training mode
        for data, targets in train_loader:  # Training loop
            optimizer.zero_grad()  # Clear gradients from the previous step
            data = data.cuda()
            targets = targets.cuda()
            prediction = net(data)  # Forward pass
            loss = torch.sqrt(loss_fn(prediction, targets))  # RMSE, assuming loss_fn is MSE
            loss_mae = loss_l1(prediction, targets)
            loss_history.append(loss.item())
            loss_history_mae.append(loss_mae.item())
            # Gradient calculation and weight update
            loss.backward()  # Backpropagation, compute gradients
            optimizer.step()  # Apply the gradients
            if counter % 10 == 0:  # Print every 10 iterations
                print(f"Training Iteration {counter}: Training RMSE: {loss.item()}, Training MAE: {loss_mae.item()}")
            counter += 1
        with torch.no_grad():  # Validation loop - no gradient tracking needed
            net.eval()  # Put the network in evaluation mode
            for test_data, test_targets in val_loader:
                test_data = test_data.cuda()
                test_targets = test_targets.cuda()
                test_pred = net(test_data)  # Forward pass on the validation batch
                test_loss = torch.sqrt(loss_fn(test_pred, test_targets))
                test_loss_mae = loss_l1(test_pred, test_targets)
                test_loss_history.append(test_loss.item())
                test_loss_history_mae.append(test_loss_mae.item())
                if val_counter % 10 == 0:
                    print(f"Validation Iteration {val_counter}: Validation RMSE: {test_loss.item()}, Validation MAE: {test_loss_mae.item()}")
                val_counter += 1
    history['Training RMSE'] = loss_history
    history['Training MAE'] = loss_history_mae
    history_val['Validation RMSE'] = test_loss_history
    history_val['Validation MAE'] = test_loss_history_mae
    torch.save(net.state_dict(), SRC + 'results/model2.pt')  # SRC (output path prefix) must be defined at module level
    return history, history_val
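

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, self-contained
# example of how training_loop might be called. The toy model, synthetic
# regression data, Adam optimizer, batch size, and the SRC value below are
# illustrative assumptions; loss_fn is nn.MSELoss() so that the
# torch.sqrt(...) calls above yield RMSE. A CUDA device is required.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os
    from torch.utils.data import TensorDataset

    SRC = './'  # placeholder output prefix for this sketch only
    os.makedirs(SRC + 'results', exist_ok=True)  # torch.save does not create directories

    # Synthetic regression data: 256 samples, 10 features, 1 target
    X, y = torch.randn(256, 10), torch.randn(256, 1)
    train_loader = DataLoader(TensorDataset(X[:200], y[:200]), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(X[200:], y[200:]), batch_size=32)

    # Small feed-forward network standing in for the real model
    net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1)).cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    history, history_val = training_loop(net, train_loader, val_loader, optimizer, nn.MSELoss())
    print(f"Logged {len(history['Training RMSE'])} training batches and "
          f"{len(history_val['Validation RMSE'])} validation batches")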