train.py
import argparse
import time
import math
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from dataset import create_dataloader
from utils import parse_configuration
from models import create_model
from utils.visualizer import Visualizer
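
# A hypothetical sketch of the JSON configuration this script consumes,
# inferred from the keys read below; the real schema likely carries more
# fields, and all values here are illustrative:
#
# {
#     "train_dataset_params": {"loader_params": {"batch_size": 4}},
#     "val_dataset_params": {"loader_params": {"batch_size": 4}},
#     "model_params": {"load_checkpoint": 0, "max_epochs": 100,
#                      "lr_policy": "plateau"},
#     "visualization_params": {"log_path": "runs/"},
#     "validation_freq": 5,
#     "model_update_freq": 10
# }
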
"""Performs training of a specified model.
Input params:
config_file: Either a string with the path to the JSON
system-specific config file or a dictionary containing
the system-specific, dataset-specific and
model-specific settings.
export: Whether to export the final model (default=True).
"""
def train(config_file, export=False):
    print('Reading config file...')
    configuration = parse_configuration(config_file)

    print('Initializing dataset...')
    train_dataset = create_dataloader(configuration['train_dataset_params'])
    train_dataset_size = len(train_dataset)
    print('The number of training samples = {0}'.format(train_dataset_size))
    val_dataset = create_dataloader(configuration['val_dataset_params'])
    val_dataset_size = len(val_dataset)
    print('The number of validation samples = {0}'.format(val_dataset_size))

    print('Initializing model...')
    model = create_model(configuration['model_params'])
    model.setup()

    print('Initializing visualization...')
    visualizer = Visualizer(configuration['visualization_params'])  # create a visualizer that displays images and plots

    # TensorBoard writers for the training and validation curves
    training_log = SummaryWriter(configuration['visualization_params']['log_path'] + 'tensorboard/training/')
    validation_log = SummaryWriter(configuration['visualization_params']['log_path'] + 'tensorboard/validation/')

    starting_epoch = configuration['model_params']['load_checkpoint']
    num_epochs = configuration['model_params']['max_epochs'] + 1

    # the epoch index doubles as the global step for both summary writers
    for epoch in range(starting_epoch, num_epochs):
        epoch_start_time = time.time()
        train_iterations = len(train_dataset)
        validation_iterations = len(val_dataset)
        train_batch_size = configuration['train_dataset_params']['loader_params']['batch_size']
        model.train()
        loss_total = []
        loss_smooth = []
        loss_consistency = []
        loss_similarity = []
        for i, data in enumerate(train_dataset):  # inner loop within one epoch
            image_list, image_id = data
            for image in image_list:
                model.set_input(image)  # unpack data from dataset and apply preprocessing
                model.forward()
                model.backward()
                model.optimize_parameters()  # calculate loss functions, get gradients, update network weights
                losses = model.get_current_losses()
                visualizer.print_losses(epoch, num_epochs, i, train_iterations, losses, image_id, name='training')
                # accumulate per-iteration losses for the epoch-level TensorBoard summaries
                loss_total.append(losses['total'])
                loss_smooth.append(losses['smooth_mean'])
                loss_similarity.append(losses['similarity_mean'])
                loss_consistency.append(losses['consistency_strain_mean'])
        training_log.add_scalar('loss/total', np.mean(np.stack(loss_total)), epoch)
        training_log.add_scalar('loss/smooth', np.mean(np.stack(loss_smooth)), epoch)
        training_log.add_scalar('loss/similarity', np.mean(np.stack(loss_similarity)), epoch)
        training_log.add_scalar('loss/consistency', np.mean(np.stack(loss_consistency)), epoch)
        training_log.add_scalar('learning rate', model.lr, epoch)
        training_log.flush()

        if epoch % configuration['validation_freq'] == 0:
            # model.eval()
            loss_total = []
            loss_smooth = []
            loss_consistency = []
            loss_similarity = []
            print('************************VALIDATION***************************')
            for i, data in enumerate(val_dataset):
                image_list, image_id = data
                for image in image_list:
                    model.set_input(image)
                    model.test()
                    model.backward_val()
                    losses = model.get_current_losses()
                    visualizer.print_losses(epoch, num_epochs, i, math.floor(validation_iterations / train_batch_size), losses, image_id, name='validation')
                    # accumulate per-iteration losses for the epoch-level TensorBoard summaries
                    loss_total.append(losses['total'])
                    loss_smooth.append(losses['smooth_mean'])
                    loss_similarity.append(losses['similarity_mean'])
                    loss_consistency.append(losses['consistency_strain_mean'])
            validation_log.add_scalar('loss/total', np.mean(np.stack(loss_total)), epoch)
            validation_log.add_scalar('loss/smooth', np.mean(np.stack(loss_smooth)), epoch)
            validation_log.add_scalar('loss/similarity', np.mean(np.stack(loss_similarity)), epoch)
            validation_log.add_scalar('loss/consistency', np.mean(np.stack(loss_consistency)), epoch)
            validation_log.flush()

        if epoch == num_epochs - 1:  # last epoch of the run; range() stops before num_epochs
            model.save_networks(epoch)
            model.save_optimizers(epoch)
        if epoch % configuration['model_update_freq'] == 0:
            print('Saving model at the end of epoch {0}'.format(epoch))
            model.save_networks(epoch)
            model.save_optimizers(epoch)

        print('End of epoch {0} / {1} \t Time Taken: {2} sec'.format(epoch, num_epochs, time.time() - epoch_start_time))

        if configuration['model_params']['lr_policy'] == 'plateau':
            model.update_learning_rate(np.mean(np.stack(loss_total)))  # update learning rates at the end of every epoch
        else:
            model.update_learning_rate()  # update learning rates at the end of every epoch

    if export:
        print('Exporting model')
        model.eval()
        custom_configuration = configuration['train_dataset_params']
        custom_configuration['loader_params']['batch_size'] = 1  # set batch size to 1 for tracing
        dl = create_dataloader(custom_configuration)
        sample_input = next(iter(dl))  # sample input from the training dataset
        model.set_input(sample_input)
        model.export()
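        # NOTE (assumption): model.export() presumably writes a traced
        # artifact (e.g. TorchScript); if so, it could later be reloaded with
        # something like torch.jit.load(<path>). The exact format and path
        # are defined by the model class, not by this script.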

    return model.get_hyperparam_result()

if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('spawn', True)
    parser = argparse.ArgumentParser(description='Perform model training.')
    parser.add_argument('configfile', help='path to the configfile')
    args = parser.parse_args()
    train(args.configfile)
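
# Example invocation (path illustrative):
#   python train.py config.json
#
# train() can also be called programmatically with a dictionary that mirrors
# the JSON schema sketched after the imports, e.g.:
#   train(config_dict, export=True)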