train.py
import pickle
import torch
import simplebinmi


class TrainConfig:
    """Groups the model, loss criterion and optimizer passed to Train."""

    def __init__(self, model, criterion, optimizer):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer


class Train:
    """Training loop that also records per-layer mutual-information estimates."""

    def __init__(self, config):
        self.config = config
        self.epochs = 5000
        self.mi_cycle = 1  # estimate mutual information every mi_cycle epochs
        self.n_layers = None
        self.losses = dict()
        self.accuracy = dict()
        self.running_mis_xt = dict()  # input/hidden-layer MI per phase, epoch, layer
        self.running_mis_ty = dict()  # hidden-layer/label MI per phase, epoch, layer
        for phase in ['train', 'test']:
            self.losses[phase] = []
            self.accuracy[phase] = []
            self.running_mis_xt[phase] = []
            self.running_mis_ty[phase] = []

    @staticmethod
    def get_class_masks(data):
        """Return, per phase, a boolean mask of the samples belonging to each class."""
        samples_split = dict()
        n_classes = int(data['train']['class'].max()) + 1
        for phase in ['train', 'test']:
            samples_split[phase] = {}
            classes = data[phase]['class'].detach().numpy()
            for i in range(n_classes):
                samples_split[phase][i] = classes == i
        return samples_split

    def run(self, data):
        """Train and evaluate for self.epochs epochs, tracking loss, accuracy and
        binned mutual-information estimates for every hidden layer."""
        self.n_layers = self.config.model.n_layers
        class_masks = self.get_class_masks(data)
        for i in range(self.epochs):
            to_print = ''
            for phase in ['train', 'test']:
                if phase == 'train':
                    self.config.model.train()
                else:
                    self.config.model.eval()
                self.config.optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs, hiddens = self.config.model(data[phase]['samples'])
                    loss = self.config.criterion(outputs, data[phase]['labels'])
                    if phase == 'train':
                        loss.backward()
                        self.config.optimizer.step()
                # Track scalar loss and accuracy for this phase.
                loss = loss.item()
                self.losses[phase].append(loss)
                acc = (data[phase]['class'] == outputs.argmax(dim=1)).sum() / float(len(data[phase]['labels']))
                self.accuracy[phase].append(float(acc))
                to_print += f'{phase}: loss {loss:>.4f} - acc {acc:>.4f} \t'
                # Every mi_cycle epochs, estimate the mutual information between the
                # input and each hidden layer (xt) and between each hidden layer and
                # the class labels (ty) from the binned activations.
                if i % self.mi_cycle == 0:
                    running_mi_xt = []
                    running_mi_ty = []
                    for j in range(self.n_layers):
                        activity = hiddens[j].detach().numpy()
                        binxm, binym = simplebinmi.bin_calc_information(class_masks[phase], activity, binsize=0.07)
                        running_mi_xt.append(binxm)
                        running_mi_ty.append(binym)
                    self.running_mis_xt[phase].append(running_mi_xt)
                    self.running_mis_ty[phase].append(running_mi_ty)
            print(f'Epoch {i:>4}: {to_print}')

    def dump(self):
        """Pickle the tracked metrics so they can be inspected or plotted later."""
        tracking = {
            'n_layers': self.n_layers,
            'mi_cycle': self.mi_cycle,
            'losses': self.losses,
            'accuracy': self.accuracy,
            'running_mis_xt': self.running_mis_xt,
            'running_mis_ty': self.running_mis_ty,
        }
        with open('train.pkl', 'wb') as f:
            pickle.dump(tracking, f)
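

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original script). It shows
# one way Train/TrainConfig could be wired up, assuming a model whose forward
# pass returns (logits, list_of_hidden_activations) and exposes an `n_layers`
# attribute, and a data dict with 'train'/'test' phases each holding
# 'samples', 'labels' and 'class' tensors. ToyMLP and the synthetic data are
# hypothetical stand-ins; the simplebinmi module imported above must still be
# importable for the mutual-information step.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class ToyMLP(nn.Module):
        """Two-hidden-layer MLP that also returns its hidden activations."""

        def __init__(self, n_in=12, n_hidden=8, n_classes=2):
            super().__init__()
            self.fc1 = nn.Linear(n_in, n_hidden)
            self.fc2 = nn.Linear(n_hidden, n_hidden)
            self.out = nn.Linear(n_hidden, n_classes)
            self.n_layers = 2  # number of hidden layers tracked for MI

        def forward(self, x):
            h1 = torch.tanh(self.fc1(x))
            h2 = torch.tanh(self.fc2(x))
            return self.out(h2), [h1, h2]

    def make_phase(n=64, n_in=12):
        # Synthetic two-class data; 'labels' are class indices so that
        # CrossEntropyLoss can be used directly, and 'class' mirrors them.
        samples = torch.randn(n, n_in)
        classes = (samples[:, 0] > 0).long()
        return {'samples': samples, 'labels': classes, 'class': classes}

    data = {'train': make_phase(), 'test': make_phase()}

    model = ToyMLP()
    config = TrainConfig(model=model,
                         criterion=nn.CrossEntropyLoss(),
                         optimizer=torch.optim.SGD(model.parameters(), lr=0.1))

    trainer = Train(config)
    trainer.epochs = 3  # keep the demo short; the default is 5000
    trainer.run(data)
    trainer.dump()  # writes the tracked metrics to train.pkl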