train_unet3d.py (forked from Sherry-SR/edgeDL)
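"""Training entry point for the 3D U-Net model.

Loads a YAML experiment configuration, builds the model, loss criterion,
evaluation metric, data loaders, optimizer and LR scheduler, and hands
everything to NNTrainer, optionally resuming from a checkpoint or
fine-tuning a pre-trained model.

Usage:
    python train_unet3d.py --config <path to YAML config>
"""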
import importlib
import argparse
import warnings

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils.config import load_config
from models.unet3d.losses import get_loss_criterion
from models.unet3d.metrics import get_evaluation_metric
from models.unet3d.model import get_model
from utils.trainer import NNTrainer
from utils.helper import get_logger, get_number_of_learnable_parameters
from utils.databuilder import get_train_loaders

warnings.filterwarnings("ignore")


def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']

    resume = trainer_config.get('resume', None)
    pre_trained = trainer_config.get('pre_trained', None)
    validate_iters = trainer_config.get('validate_iters', None)

    if resume is not None:
        # continue training from a given checkpoint
        return NNTrainer.from_checkpoint(resume, model,
                                         optimizer, lr_scheduler, loss_criterion,
                                         eval_criterion, loaders,
                                         logger=logger)
    elif pre_trained is not None:
        # fine-tune a given pre-trained model
        return NNTrainer.from_pretrained(pre_trained, model, optimizer, lr_scheduler, loss_criterion,
                                         eval_criterion, device=config['device'], loaders=loaders,
                                         max_num_epochs=trainer_config['epochs'],
                                         max_num_iterations=trainer_config['iters'],
                                         validate_after_iters=trainer_config['validate_after_iters'],
                                         log_after_iters=trainer_config['log_after_iters'],
                                         eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                                         logger=logger, validate_iters=validate_iters)
    else:
        # start training from scratch
        return NNTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                         config['device'], loaders, trainer_config['checkpoint_dir'],
                         max_num_epochs=trainer_config['epochs'],
                         max_num_iterations=trainer_config['iters'],
                         validate_after_iters=trainer_config['validate_after_iters'],
                         log_after_iters=trainer_config['log_after_iters'],
                         eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                         logger=logger, validate_iters=validate_iters)
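
# Illustrative 'trainer' section (key names taken from _create_trainer above; the values are
# placeholders, not from the original config):
#   trainer:
#     checkpoint_dir: <path to save checkpoints>
#     epochs: 100
#     iters: 100000
#     validate_after_iters: 100
#     log_after_iters: 100
#     eval_score_higher_is_better: true
#     # optional keys: resume, pre_trained, validate_iters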


def _create_optimizer(config, model):
    assert 'optimizer' in config, 'Cannot find optimizer configuration'
    optimizer_config = config['optimizer']
    learning_rate = optimizer_config['learning_rate']
    weight_decay = optimizer_config['weight_decay']
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optimizer
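
# The 'optimizer' section only needs 'learning_rate' and 'weight_decay'; Adam is hard-coded here.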


def _create_lr_scheduler(config, optimizer):
    lr_config = config.get('lr_scheduler', None)
    if lr_config is None:
        # use ReduceLROnPlateau as a default scheduler
        return ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, verbose=True)
    else:
        class_name = lr_config.pop('name')
        m = importlib.import_module('torch.optim.lr_scheduler')
        clazz = getattr(m, class_name)
        # add the optimizer to the scheduler kwargs
        lr_config['optimizer'] = optimizer
        return clazz(**lr_config)
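
# Any scheduler from torch.optim.lr_scheduler can be selected by name, with its remaining
# constructor arguments taken from the config. Illustrative example (not from the original config):
#   lr_scheduler:
#     name: MultiStepLR
#     milestones: [30, 60]
#     gamma: 0.2
# which resolves to torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60], gamma=0.2).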


def main():
    # Create main logger
    logger = get_logger('UNet3DTrainer')

    parser = argparse.ArgumentParser(description='UNet3D training')
    parser.add_argument('--config', type=str, help='Path to the YAML config file',
                        default='/home/SENSETIME/shenrui/Dropbox/SenseTime/edgeDL/resources/train_config_unet.yaml')
    args = parser.parse_args()

    # Load and log experiment configuration
    config = load_config(args.config)
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)
    # put the model on GPUs
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()


if __name__ == '__main__':
    main()