-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathdefault_trainer.py
executable file
·305 lines (249 loc) · 13.2 KB
/
default_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# --------------------------------------------------------
# X-Decoder -- Generalized Decoding for Pixel, Image, and Language
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Modified by Xueyan Zou ([email protected])
# --------------------------------------------------------
from datetime import datetime
import time
import os
import sys
import importlib
import json
import random
#import wandb
import logging
import numpy as np
import copy
import contextlib
import shutil
from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from mpi4py import MPI
from infinibatch import iterators
from .distributed_trainer import DistributedTrainer
from .utils_trainer import UtilsTrainer
from .utils.misc import *
from .utils.serialization import JSONEncoder, filter_jsonable
# Module-level logger, named after this module so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
class DefaultTrainer(UtilsTrainer, DistributedTrainer):
    """Generic epoch-based trainer that delegates model/data/loss logic to a pipeline.

    The concrete pipeline class is loaded dynamically from the project directory
    given by ``opt['base_path']`` and selected by ``opt['PIPELINE']``. Optimizers,
    LR schedulers, ``grad_scaler``, ``grad_acc_steps``, DDP setup and checkpoint
    I/O are not defined in this class -- presumably they come from
    ``UtilsTrainer`` / ``DistributedTrainer`` (TODO confirm).
    """

    def __init__(self, opt):
        """Set up the task the model is being trained for.

        Loads ``<opt['base_path']>/__init__.py`` as the top-level package
        ``base_dir`` and instantiates
        ``base_dir.pipeline.<PIPELINE>.<PIPELINE>(self.opt)`` as ``self.pipeline``.

        Args:
            opt: configuration dict; this method reads ``'base_path'`` and
                ``'PIPELINE'`` from ``self.opt`` (presumably stored by the
                base-class constructor).
        """
        # `import importlib` alone does not guarantee that the `importlib.util`
        # submodule is loaded; import it explicitly instead of relying on a side
        # effect of some other import.
        import importlib.util

        super().__init__(opt)

        base_name = 'base_dir'
        base_path = os.path.join(self.opt['base_path'], '__init__.py')
        spec = importlib.util.spec_from_file_location(base_name, base_path)
        module = importlib.util.module_from_spec(spec)
        # Register the package before executing it so that the later
        # `import_module("base_dir.pipeline...")` can resolve it.
        sys.modules[base_name] = module
        spec.loader.exec_module(module)
        logger.info(f"Imported {base_name} at base_path {self.opt['base_path']}")

        pipeline_module = importlib.import_module(f"base_dir.pipeline.{self.opt['PIPELINE']}")
        # The pipeline class is expected to share its module's name.
        pipeline_class = getattr(pipeline_module, self.opt['PIPELINE'])
        logger.info(f"Pipeline for training: {self.opt['PIPELINE']}")
        self.pipeline = pipeline_class(self.opt)

    def _fp16_autocast(self):
        """Return a CUDA autocast context when ``opt['FP16']`` is set, else a no-op context."""
        if self.opt['FP16']:
            from torch.cuda.amp import autocast
            return autocast()
        return contextlib.nullcontext()

    def eval(self):
        """Build the models, load weights from ``opt['RESUME_FROM']`` and evaluate them.

        Returns:
            Whatever ``self.pipeline.evaluate_model`` returns.

        Raises:
            ValueError: if ``opt['WEIGHT']`` is falsy/missing or ``opt['RESUME_FROM']``
                is not an existing file.
        """
        logger.info('-----------------------------------------------')
        logger.info("Evaluating model ... ")
        self.mode = "eval"

        self.raw_models = self.pipeline.initialize_model()
        # A concrete list (not a live keys view), matching `init_train`.
        self.model_names = list(self.raw_models.keys())

        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        # Bind the path before branching so the error message can always report it.
        model_path = self.opt.get('RESUME_FROM')
        if self.opt.get('WEIGHT', False) and model_path and os.path.isfile(model_path):
            self.load_model(model_path)
        else:
            raise ValueError(f"Model not found: {model_path}")

        return self._eval_on_set(self.save_folder)

    def _eval_on_set(self, save_folder):
        """Run ``pipeline.evaluate_model`` (under autocast when FP16) and log results on rank 0."""
        logger.info("Evaluation start ...")
        with self._fp16_autocast():
            results = self.pipeline.evaluate_model(self, save_folder)
        if self.opt['rank'] == 0:
            logger.info(results)
        return results

    def compute_loss(self, forward_func, batch):
        """Return ``forward_func(self, batch)``, evaluated under autocast when FP16 is enabled."""
        with self._fp16_autocast():
            return forward_func(self, batch)

    def backward_loss(self, loss, model_names=None):
        """Backpropagate ``loss`` and return the loss that was actually backpropagated.

        When gradient accumulation is active (``grad_acc_steps > 1``) the loss is
        divided by ``grad_acc_steps`` first. Under FP16 the backward pass goes
        through ``self.grad_scaler``.

        Args:
            loss: scalar loss tensor.
            model_names: unused; kept only for backward compatibility (it used to
                default to a shared mutable list).
        """
        if self.grad_acc_steps > 1:
            loss = loss / self.grad_acc_steps

        if self.opt['FP16']:
            self.grad_scaler.scale(loss).backward()
        else:
            loss.backward()
        return loss

    def update_model(self, model_name='default'):
        """Step the optimizer and LR scheduler of ``model_name`` and zero its gradients.

        Also increments ``train_params['optim_steps'][model_name]``. Under FP16
        the step goes through ``self.grad_scaler``; ``train_step`` calls
        ``grad_scaler.update()`` once per effective batch.
        """
        optimizer = self.optimizers[model_name]
        if self.opt['FP16']:
            # NOTE(review): with no gradient clipping in between, `unscale_` is
            # redundant (`GradScaler.step` unscales itself); kept as-is.
            self.grad_scaler.unscale_(optimizer)
            self.grad_scaler.step(optimizer)
        else:
            optimizer.step()

        optimizer.zero_grad()
        self.train_params['optim_steps'][model_name] += 1
        # NOTE(review): the scheduler also advances when GradScaler skipped the
        # optimizer step because of inf/NaN gradients.
        self.lr_schedulers[model_name].step()

    def train_step(self, batch):
        """Buffer ``batch``; at an accumulation boundary, train on all buffered batches.

        At the boundary every model is put in train mode and each buffered batch
        is passed to ``pipeline.forward_step``. Sample counts are summed (and
        all-reduced across ranks when ``world_size > 1``) into
        ``train_params['total_batch_size']``. ``train_params['num_updates']``
        counts every call, i.e. mini-batches.
        """
        self.grad_acc_batches.append(batch)  # support batch accumulation

        if self.is_gradient_accumulation_boundary():
            for model_name in self.model_names:
                self.models[model_name].train()

            assert len(self.grad_acc_batches) == self.grad_acc_steps

            is_distributed = self.opt['world_size'] > 1
            total_batch_sample = 0
            for batch_index, acc_batch in enumerate(self.grad_acc_batches):
                loss_info, sample_size_info, _extra_info = self.pipeline.forward_step(
                    self,
                    acc_batch,
                    self.grad_acc_batches,
                    batch_index,
                    is_distributed=is_distributed,
                )
                self.train_loss.update_iter(loss_info)
                total_batch_sample += sample_size_info['num_samples']

            if self.opt['FP16']:
                # Update GradScaler once after an effective batch.
                self.grad_scaler.update()

            # Sum the item count of this effective batch over all ranks.
            if is_distributed:
                total_batch_sample = torch.tensor(total_batch_sample, device=self.opt['device'])
                torch.distributed.all_reduce(total_batch_sample, torch.distributed.ReduceOp.SUM)
                total_batch_sample = total_batch_sample.item()

            self.train_params['total_batch_size'] += total_batch_sample
            self.grad_acc_batches = []

        self.train_params['num_updates'] += 1

    def init_train(self):
        """Prepare all state needed by ``train``.

        Builds the models and moves them to ``opt['device']``, creates the
        training dataloader(s), resets ``self.train_params`` / ``self.train_loss``,
        calls ``create_optimizer_and_scheduler()`` and ``_initialize_ddp()``,
        optionally loads weights (``WEIGHT``) and/or a checkpoint (``RESUME``)
        from ``opt['RESUME_FROM']``, and logs a summary on rank 0.
        """
        self.mode = "train"
        logger.info('-------------------------------------------------------')
        logger.info("Training on rank: {}".format(self.opt['rank']))

        self.raw_models = self.pipeline.initialize_model()
        self.model_names = list(self.raw_models.keys())

        for module_name in self.model_names:
            self.raw_models[module_name].to(self.opt['device'])

        self.train_dataloaders = self.pipeline.get_dataloaders(self, 'train', is_evaluation=False)
        self.train_params = {
            "updates_per_epoch": len(self.train_dataloaders),
            "total_batch_size": 0,
            "num_updates": 0,
            "optim_steps": {module_name: 0 for module_name in self.model_names},
            "start_epoch_idx": 0,
            "start_batch_idx": 0,
            "current_epoch_idx": 0,
            "current_batch_idx": 0,
            "resume_epoch_idx": 0,
        }

        self.train_loss = LossMeter()
        self.grad_acc_batches = []

        if self.opt['CUDA']:
            torch.cuda.empty_cache()

        self.create_optimizer_and_scheduler()
        self.models = {model_name: self.raw_models[model_name] for model_name in self.model_names}
        self._initialize_ddp()

        if self.opt.get('WEIGHT', False):
            self.load_weight(self.opt['RESUME_FROM'], must_exist=True)
        if self.opt.get('RESUME', False):
            # Presumably restores train_params such as start_epoch_idx / start_batch_idx.
            self.load_checkpoint(self.opt['RESUME_FROM'], must_exist=True)

        if self.opt['rank'] == 0:
            num_epochs = self.opt['SOLVER']['MAX_NUM_EPOCHS']
            updates_per_epoch = self.train_params['updates_per_epoch']
            logger.info("***** Running training *****")
            logger.info(f" Num of GPUs = {self.opt['world_size']}")
            logger.info(f" Num Epochs = {num_epochs}")
            logger.info(f" Num of Mini Batches per Epoch = {updates_per_epoch}")
            # NOTE(review): despite its label, this value is epochs * mini-batches
            # per epoch (total mini-batches), not a batch size.
            logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {num_epochs * updates_per_epoch}")
            logger.info(f" Gradient Accumulation steps = {self.grad_acc_steps}")
            logger.info(f" Total optimization steps = {num_epochs * updates_per_epoch // self.grad_acc_steps}")

    def _log_train_progress(self, epoch, batch_idx, current_optim_steps,
                            items_per_batch, items_per_second, epoch_start_time):
        """Log one progress line for the latest optimizer update (rank 0 only).

        Args:
            epoch: current epoch index.
            batch_idx: index of the mini-batch just consumed in this epoch.
            current_optim_steps: optimizer step count after the update.
            items_per_batch: items added to ``total_batch_size`` by the latest step.
            items_per_second: items processed per second since the previous log.
            epoch_start_time: ``datetime`` at which this epoch started (for the ETA).
        """
        if self.opt['rank'] != 0:
            return

        last_lr = {
            module_name: self.lr_schedulers[module_name].get_last_lr()[0]
            for module_name in self.model_names
        }
        memory = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
        # Naive ETA: mean time per mini-batch so far times mini-batches left.
        remaining = ((datetime.now() - epoch_start_time) / (batch_idx + 1)
                     * (self.train_params['updates_per_epoch'] - batch_idx - 1))

        # if self.opt['WANDB']:
        #     wb_loss_info = {key: obj.val for key, obj in self.train_loss.losses.items()}
        #     wandb.log(wb_loss_info, step=self.prev_optim_steps)

        lr_text = ', '.join(f'{key}: {val:.5e}' for key, val in last_lr.items())
        loss_text = ', '.join(f'{key}: {obj.val:.5f}/{obj.avg:.5f}'
                              for key, obj in self.train_loss.losses.items())
        logger.info(f"epochs[{epoch:6}] optim steps[{current_optim_steps:.0f}] "
                    f"learning rate[{lr_text}] "
                    f"train loss[{loss_text}] "
                    f"items per batch[{items_per_batch}] "
                    f"items per second[{items_per_second:.2f}] "
                    f"total items[{self.train_params['total_batch_size']}] "
                    f"mini batches[{self.train_params['num_updates']:6}] "
                    f"memory[{memory:.0f}] "
                    f"epoch remaining[{str(remaining).split('.')[0]}]")

    def train(self):
        """Training

        Calls ``init_train`` and iterates epochs ``start_epoch_idx`` ..
        ``MAX_NUM_EPOCHS - 1``. Every mini-batch goes through ``train_step``.
        After an optimizer update, a progress line is logged every ``LOG_EVERY``
        steps, and also for every step up to ``LOG_FIRST`` in epoch 0. At the last
        mini-batch of each epoch a checkpoint is saved (unless
        ``SAVE_CHECKPOINT`` is false) and the model is evaluated.
        """
        self.init_train()
        current_optim_steps = self._get_and_validate_current_optim_steps()
        num_epochs = self.opt['SOLVER']['MAX_NUM_EPOCHS']
        log_first = self.opt.get("LOG_FIRST", 10)
        log_every = self.opt.get("LOG_EVERY", 100)

        if self.opt.get('EVAL_AT_START', False):
            results = self._eval_on_set(self.save_folder)
            # if self.opt['rank'] == 0 and self.opt['WANDB']:
            #     wandb.log(results)

        train_prev_logged_time = datetime.now()
        # Item count at the previous log, so throughput covers the same window as
        # the elapsed time (which may span many optimizer steps).
        train_prev_logged_items = self.train_params['total_batch_size']

        for epoch in range(self.train_params['start_epoch_idx'], num_epochs):
            self.train_params['current_epoch_idx'] = epoch
            logger.info(f"Start epoch: {epoch} training.")
            epoch_start_time = datetime.now()

            for batch_idx, batch in enumerate(self.train_dataloaders):
                # When resuming, skip the batches already consumed in the resumed epoch.
                if (self.train_params['current_epoch_idx'] == self.train_params['start_epoch_idx']
                        and batch_idx < self.train_params['start_batch_idx']):
                    continue

                self.train_params['current_batch_idx'] = batch_idx
                prev_optim_steps = current_optim_steps
                prev_total_batch_size = self.train_params['total_batch_size']

                self.prev_optim_steps = prev_optim_steps
                self.train_step(batch)
                current_optim_steps = self._get_and_validate_current_optim_steps()

                # Log only after an actual optimizer update, at the configured cadence.
                if prev_optim_steps != current_optim_steps and (
                        current_optim_steps % log_every == 0
                        or (epoch == 0 and current_optim_steps <= log_first)):
                    now = datetime.now()
                    train_time_delta = (now - train_prev_logged_time).total_seconds()
                    train_prev_logged_time = now

                    total_items = self.train_params['total_batch_size']
                    items_per_second = ((total_items - train_prev_logged_items) / train_time_delta
                                        if train_time_delta > 0 else 0.0)
                    train_prev_logged_items = total_items

                    self._log_train_progress(epoch, batch_idx, current_optim_steps,
                                             total_items - prev_total_batch_size,
                                             items_per_second, epoch_start_time)

                # evaluate and save ckpt every epoch
                if batch_idx + 1 == self.train_params['updates_per_epoch']:
                    if self.opt.get('SAVE_CHECKPOINT', True):
                        self.save_checkpoint(self.train_params['num_updates'])
                    results = self._eval_on_set(self.save_folder)
                    # if self.opt['rank'] == 0 and self.opt['WANDB']:
                    #     wandb.log(results)
                    # Presumably guards against data iterators that do not stop
                    # after `updates_per_epoch` batches -- TODO confirm.
                    break

            logger.info(f"This epoch takes {datetime.now() - epoch_start_time}")
            logger.info(f"PROGRESS: {100.0 * (epoch + 1) / num_epochs:.2f}%")
            logger.info(f"Config files are at {self.opt['conf_files']}")