-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_unet.py
113 lines (87 loc) · 4.35 KB
/
train_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import sys
import time
from os.path import join

import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim

from lib.losses.loss import *
from lib.common import *
from config_train import parse_args
from lib.logger import Logger, Print_Logger
import models
from test import Test
from function import get_dataloader, train, val, get_dataloaderV2
# Default to the first GPU, but respect a CUDA_VISIBLE_DEVICES value that is
# already exported by the caller (e.g. `CUDA_VISIBLE_DEVICES=1 python train_unet.py`)
# instead of silently overwriting it. The variable only takes effect if it is
# set before CUDA is first initialised in this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
def main():
    """Train the segmentation network, checkpointing it every epoch.

    Flow:
      1. Call ``setpu_seed(2021)``, parse the training configuration and save
         it under ``<args.outf>/<args.save>`` (the run's save directory).
      2. Pick the device (CUDA only if it is available *and* ``args.cuda`` is
         set), replace ``sys.stdout`` with a ``Print_Logger`` for
         ``train_log.txt`` in the save directory, and build the network.
      3. If ``args.pre_trained`` is set, load the network weights and the
         epoch counter from ``<args.outf>/<args.pre_trained>/latest_model.pth``.
      4. For epochs ``args.start_epoch`` .. ``args.N_epochs`` (inclusive):
         train, validate, log, step the cosine LR schedule, and save
         ``latest_model.pth``. Validation runs on the validation loader, or
         on the test set through ``Test`` when ``args.val_on_test`` is set.
         Whenever ``val_log['val_auc_roc']`` improves on the best so far,
         ``best_model.pth`` is saved as well.
      5. If ``args.early_stop`` is not None, stop after that many consecutive
         epochs without an AUC improvement.
    """
    setpu_seed(2021)
    args = parse_args()
    save_path = join(args.outf, args.save)
    save_args(args, save_path)

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    # Let cuDNN autotune convolution algorithms (pays off when input sizes stay fixed).
    cudnn.benchmark = True

    log = Logger(save_path)
    # From here on, print() output goes through Print_Logger into train_log.txt.
    sys.stdout = Print_Logger(join(save_path, 'train_log.txt'))
    print('The computing device used is: ', 'GPU' if device.type == 'cuda' else 'CPU')

    # Alternative architectures from the `models` package:
    # net = models.UNetFamily.U_Net(1,2).to(device)
    # net = models.UNetFamily.U_Net_small(1,2).to(device)
    # net = models.UNetFamily.AttU_Net_small(1,2).to(device)
    # net = models.UNetFamily.Dense_Unet(1,2, 64).to(device)
    net = models.OCE_Net.OCENet(1, 2).to(device)
    print("Total number of parameters: " + str(count_parameters(net)))

    # The training speed of this task is fast, so pre-training is not recommended.
    if args.pre_trained is not None:
        print('==> Resuming from checkpoint..')
        # Build the path with join(), the same way latest_model.pth is saved
        # below, so it works whether or not args.outf ends with a separator.
        # map_location lets a GPU-saved checkpoint be resumed on a CPU-only run.
        checkpoint = torch.load(join(args.outf, args.pre_trained, 'latest_model.pth'),
                                map_location=device)
        net.load_state_dict(checkpoint['net'])
        # NOTE(review): only the weights and epoch are restored. Optimizer state
        # (checkpoint['optimizer']) is not loaded, and the LR scheduler below is
        # created fresh, so a resumed run restarts the cosine schedule at args.lr.
        args.start_epoch = checkpoint['epoch'] + 1

    # Alternative: LossMulti(jaccard_weight=0, class_weights=np.array([0.5, 0.5]))
    criterion = CrossEntropyLoss2d()
    optimizer = optim.Adam(net.parameters(), lr=args.lr)
    # Cosine decay from args.lr to 0 over N_epochs steps; stepped once per epoch below.
    lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.N_epochs, eta_min=0)

    train_loader, val_loader = get_dataloaderV2(args)  # alternative: get_dataloader(args)
    if args.val_on_test:
        print('\033[0;32m===============Validation on Testset!!!===============\033[0m')
        val_tool = Test(args)

    # A checkpoint is saved as "best" only once val AUC exceeds 0.5 (strict >).
    best = {'epoch': 0, 'AUC_roc': 0.5}
    trigger = 0  # epochs since the last AUC improvement (early-stop counter)
    for epoch in range(args.start_epoch, args.N_epochs + 1):
        # param_groups holds the live LR; no need to build a full state_dict().
        current_lr = optimizer.param_groups[0]['lr']
        print('\nEPOCH: %d/%d --(learn_rate:%.6f) | Time: %s' %
              (epoch, args.N_epochs, current_lr, time.asctime()))

        train_log = train(train_loader, net, criterion, optimizer, device)
        if args.val_on_test:
            val_tool.inference(net)
            val_log = val_tool.val()
        else:
            val_log = val(val_loader, net, criterion, device)
        log.update(epoch, train_log, val_log)
        lr_scheduler.step()

        # Always overwrite the latest checkpoint; additionally keep the best-AUC one.
        state = {'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, join(save_path, 'latest_model.pth'))
        trigger += 1
        if val_log['val_auc_roc'] > best['AUC_roc']:
            print('\033[0;33mSaving best model!\033[0m')
            torch.save(state, join(save_path, 'best_model.pth'))
            best['epoch'] = epoch
            best['AUC_roc'] = val_log['val_auc_roc']
            trigger = 0
        print('Best performance at Epoch: {} | AUC_roc: {}'.format(best['epoch'], best['AUC_roc']))

        if args.early_stop is not None and trigger >= args.early_stop:
            print("=> early stopping")
            break
        torch.cuda.empty_cache()
if __name__ == '__main__':
    # Start training only when this file is executed directly, not on import.
    main()