-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathTrainer.py
93 lines (73 loc) · 2.82 KB
/
Trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from copy import deepcopy
from typing import Tuple
import numpy as np
from utils import *
class Trainer:
    '''
    Trains a neural network by iterating over mini-batches and delegating
    parameter updates to an optimizer, with periodic evaluation and a
    simple early-stopping rollback.
    '''
    def __init__(self,
                 net,
                 optim):
        '''
        Requires a neural network and an optimizer in order for training
        to occur.  The network is attached to the optimizer so that
        ``optim.step()`` can update its parameters.

        Parameters
        ----------
        net : network object — assumed to expose ``layers``,
            ``train_batch``, ``forward`` and ``loss`` (used in ``fit``;
            exact contract not visible here — confirm against the
            network class)
        optim : optimizer object — assumed to expose ``step()``
        '''
        self.net = net
        self.optim = optim
        # Sentinel "infinite" loss so the first evaluation always improves it.
        self.best_loss = 1e9
        # Plain attribute assignment instead of setattr with a string
        # literal — identical effect, clearer intent.
        self.optim.net = self.net

    def generate_batches(self,
                         X,
                         y,
                         size=32):
        '''
        Yield successive ``(X_batch, y_batch)`` pairs of at most ``size``
        rows each; the final batch may be smaller.

        Raises
        ------
        ValueError
            If X and y do not have the same number of rows.  (A real
            exception replaces the original ``assert``, which is
            silently stripped under ``python -O``.)
        '''
        if X.shape[0] != y.shape[0]:
            raise ValueError(
                '''
            features and target must have the same number of rows, instead
            features has {0} and target has {1}
            '''.format(X.shape[0], y.shape[0]))

        N = X.shape[0]
        for ii in range(0, N, size):
            X_batch, y_batch = X[ii:ii+size], y[ii:ii+size]
            yield X_batch, y_batch

    def fit(self, X_train, y_train,
            X_test, y_test,
            epochs=100,
            eval_every=10,
            batch_size=32,
            seed=1,
            restart=True):
        '''
        Fits the neural network on the training data for ``epochs`` epochs.
        Every ``eval_every`` epochs the network is evaluated on the test
        data; if the validation loss worsened, the model snapshot taken at
        the start of that epoch is restored and training stops early.

        Parameters
        ----------
        X_train, y_train : training features and targets
        X_test, y_test : held-out features and targets used for evaluation
        epochs : total number of passes over the training data
        eval_every : evaluate (and snapshot) every this many epochs
        batch_size : rows per mini-batch
        seed : numpy RNG seed for reproducible shuffling
        restart : reset layer state and best_loss before training
            (``layer.first = True`` presumably triggers parameter
            re-initialization — confirm against the layer class)
        '''
        np.random.seed(seed)
        if restart:
            for layer in self.net.layers:
                layer.first = True
            self.best_loss = 1e9

        for e in range(epochs):
            if (e+1) % eval_every == 0:
                # Snapshot the model *before* this epoch trains, so we can
                # roll back if this epoch makes validation loss worse.
                last_model = deepcopy(self.net)

            # Reshuffle each epoch; permute_data is defined in utils.
            X_train, y_train = permute_data(X_train, y_train)
            batch_generator = self.generate_batches(X_train, y_train,
                                                    batch_size)
            for X_batch, y_batch in batch_generator:
                self.net.train_batch(X_batch, y_batch)
                self.optim.step()

            if (e+1) % eval_every == 0:
                test_preds = self.net.forward(X_test)
                loss = self.net.loss.forward(test_preds, y_test)
                if loss < self.best_loss:
                    print(f"Validation loss after {e+1} epochs is {loss:.3f}")
                    self.best_loss = loss
                else:
                    # BUG FIX: the snapshot was taken right before epoch
                    # e+1 trained, i.e. it is the model after epoch e.
                    # The original message claimed epoch e+1-eval_every,
                    # which is only correct when eval_every == 1.
                    print(f"""Loss increased after epoch {e+1}, final loss was {self.best_loss:.3f}, using the model from epoch {e}""")
                    self.net = last_model
                    # ensure self.optim is still updating self.net
                    self.optim.net = self.net
                    break