# run_utils.py
import torch


def train_single_model(model, loss):
    model.optimizer.zero_grad()  # clear gradients of all optimized variables
    loss.backward()  # backward pass
    model.optimizer.step()
    if(hasattr(model.params, 'renormalize_weights') and model.params.renormalize_weights):
        with torch.no_grad():  # tell autograd not to record this operation
            model.w.div_(torch.norm(model.w, dim=0, keepdim=True))
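            # note: the in-place division keeps each column of model.w at unit L2 norm;
            # e.g. a column [3., 4.] has norm 5 and becomes [0.6, 0.8] after the update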


def train_epoch(epoch, model, loader):
    model.train()
    epoch_size = len(loader.dataset)
    num_batches = epoch_size // model.params.batch_size
    correct = 0
    for batch_idx, (data, target) in enumerate(loader):
        data, target = data.to(model.params.device), target.to(model.params.device)
        inputs = []
        if(model.params.model_type.lower() == 'ensemble'):  # TODO: Move this to train_model
            inputs.append(model[0].preprocess_data(data))  # first model preprocesses the input
            for submodule_idx, submodule in enumerate(model):
                loss = model.get_total_loss((inputs[-1], target), submodule_idx)
                train_single_model(submodule, loss)
                # TODO: include an optional parameter to allow gradients to propagate through the entire ensemble
                inputs.append(submodule.get_encodings(inputs[-1]).detach())  # must detach to prevent gradient leaking
        else:
            inputs.append(model.preprocess_data(data))
            loss = model.get_total_loss((inputs[-1], target))
            train_single_model(model, loss)
        if model.params.train_logs_per_epoch is not None:
            if(batch_idx % int(num_batches / model.params.train_logs_per_epoch) == 0):
                batch_step = epoch * model.params.batches_per_epoch + batch_idx
                model.print_update(
                    input_data=inputs[0], input_labels=target, batch_step=batch_step)
    if(model.params.model_type.lower() == 'ensemble'):
        for submodule in model:
            submodule.scheduler.step(epoch)
    else:
        model.scheduler.step(epoch)


def test_single_model(model, data, target, epoch):
    output = model(data)
    # sum the loss over the batch so test_epoch can average over the full dataset
    test_loss = torch.nn.CrossEntropyLoss(reduction='sum')(output, target).item()
    pred = output.max(1, keepdim=True)[1]  # index of the max logit
    correct = pred.eq(target.view_as(pred)).sum().item()
    return (test_loss, correct)


def test_epoch(epoch, model, loader, log_to_file=True):
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in loader:
            data, target = data.to(model.params.device), target.to(model.params.device)
            if(model.params.model_type.lower() == 'ensemble'):
                inputs = [model[0].preprocess_data(data)]
                for submodule in model:
                    if(submodule.params.model_type == 'mlp'):
                        batch_test_loss, batch_correct = test_single_model(
                            submodule, inputs[-1], target, epoch)
                        test_loss += batch_test_loss
                        correct += batch_correct
                    inputs.append(submodule.get_encodings(inputs[-1]))
            else:
                inputs = [model.preprocess_data(data)]
                batch_test_loss, batch_correct = test_single_model(
                    model, inputs[0], target, epoch)
                test_loss += batch_test_loss
                correct += batch_correct
        test_loss /= len(loader.dataset)
        test_accuracy = 100. * correct / len(loader.dataset)
        stat_dict = {
            'test_epoch': epoch,
            'test_loss': test_loss,
            'test_correct': correct,
            'test_total': len(loader.dataset),
            'test_accuracy': test_accuracy}
        if log_to_file:
            js_str = model.js_dumpstring(stat_dict)
            model.log_info('<stats>' + js_str + '</stats>')
        else:
            return stat_dict


def get_inputs_and_outputs(epoch, model, loader, num_batches=1):
    with torch.no_grad():
        model.eval()
        outputs = []
        targets = []
        inputs = []
        batch = 0
        for data, target in loader:
            if batch >= num_batches:
                break  # only collect the requested number of batches
            batch += 1
            data, target = data.to(model.params.device), target.to(model.params.device)
            output = model(data)
            inputs.append(data)
            targets.append(target)
            outputs.append(output)
        return (inputs, targets, outputs)
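

# Minimal usage sketch (an illustrative assumption, not part of the original
# module): the model classes, `params` objects, and data loaders used by the
# functions above come from the rest of this repository. The hypothetical
# stand-in linear classifier and random batch below only show what
# test_single_model consumes and returns.
if __name__ == '__main__':
    import torch.nn as nn

    dummy_model = nn.Linear(10, 3)       # stand-in classifier: 10 features -> 3 classes
    data = torch.randn(8, 10)            # one fake batch of 8 samples
    target = torch.randint(0, 3, (8,))   # fake integer class labels
    batch_loss, batch_correct = test_single_model(dummy_model, data, target, epoch=0)
    print(f'summed batch loss = {batch_loss:.4f}, correct = {batch_correct}/8')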