forked from assassint2017/MICCAI-LITS2017
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_ds.py
127 lines (96 loc) · 2.45 KB
/
train_ds.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
深度监督下的训练脚本
"""
from time import time
import os
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from loss.Dice_loss import DiceLoss
from net.DialResUNet import net
from dataset.dataset_random import train_ds
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# on_server selects between a local single-GPU configuration (False) and the
# multi-GPU server configuration (True).
on_server = True

# Restrict which GPUs are visible. This only takes effect if CUDA has not been
# initialized yet; the project imports above are assumed not to touch CUDA at
# import time -- TODO confirm for net.DialResUNet / dataset.dataset_random.
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2,3' if on_server else '0'

# Let cuDNN benchmark convolution algorithms and cache the fastest one.
cudnn.benchmark = True

Epoch = 3000
leaing_rate_base = 1e-4  # (sic) name kept unchanged for compatibility
alpha = 0.33  # initial weight of the three auxiliary deep-supervision losses
batch_size = 3 if on_server else 1
num_workers = 3 if on_server else 1
pin_memory = bool(on_server)

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
# Wrap the imported network in DataParallel so batches are split across the
# visible GPUs, then switch it to training mode.
net = torch.nn.DataParallel(net).cuda()
net.train()

# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------
train_dl = DataLoader(train_ds, batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=pin_memory)

# ---------------------------------------------------------------------------
# Loss function
# ---------------------------------------------------------------------------
loss_func = DiceLoss()

# ---------------------------------------------------------------------------
# Optimizer
# ---------------------------------------------------------------------------
opt = torch.optim.Adam(net.parameters(), lr=leaing_rate_base)

# ---------------------------------------------------------------------------
# Learning-rate decay
# ---------------------------------------------------------------------------
# MultiStepLR with its default gamma (0.1): the learning rate is multiplied
# by 0.1 once the scheduler has been stepped past epoch 1500.
lr_decay = torch.optim.lr_scheduler.MultiStepLR(opt, [1500])
# ---------------------------------------------------------------------------
# Train the network
# ---------------------------------------------------------------------------
# torch.save below fails if the checkpoint directory does not exist.
os.makedirs('./module', exist_ok=True)

start = time()
for epoch in range(Epoch):
    # loss4 (the loss on outputs[3]) of every minibatch in this epoch.
    epoch_losses = []

    for step, (ct, seg) in enumerate(train_dl):
        ct = ct.cuda()
        seg = seg.cuda()

        # Deep supervision: the net is assumed to return at least 4 outputs.
        # outputs[3] is weighted 1, the other three heads are weighted alpha.
        outputs = net(ct)
        loss1 = loss_func(outputs[0], seg)
        loss2 = loss_func(outputs[1], seg)
        loss3 = loss_func(outputs[2], seg)
        loss4 = loss_func(outputs[3], seg)
        loss = (loss1 + loss2 + loss3) * alpha + loss4

        epoch_losses.append(loss4.item())

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Compare with ==, not `is`: identity of int objects is an
        # implementation detail (and a SyntaxWarning on Python 3.8+).
        if step % 20 == 0:
            print('epoch:{}, step:{}, loss1:{:.3f}, loss2:{:.3f}, loss3:{:.3f}, loss4:{:.3f}, time:{:.3f} min'
                  .format(epoch, step, loss1.item(), loss2.item(), loss3.item(), loss4.item(), (time() - start) / 60))

    # Since PyTorch 1.1 the scheduler must be stepped AFTER the optimizer;
    # stepping it at the start of the epoch warns and applies the
    # MultiStepLR milestone one epoch early.
    lr_decay.step()

    if not epoch_losses:
        # Otherwise the mean below divides by zero and `loss` is unbound.
        raise RuntimeError('train_dl yielded no batches; check the training dataset')
    mean_loss = sum(epoch_losses) / len(epoch_losses)

    if epoch % 10 == 0 and epoch != 0:
        # Checkpoint name: epoch number + total loss of the last minibatch
        # + this epoch's mean loss4.
        # NOTE: net is DataParallel-wrapped, so state_dict keys carry the
        # 'module.' prefix.
        torch.save(net.state_dict(), './module/net{}-{:.3f}-{:.3f}.pth'.format(epoch, loss.item(), mean_loss))

    if epoch % 15 == 0 and epoch != 0:
        alpha *= 0.8

# Deep-supervision coefficient schedule: the factor 0.8**k that multiplies the
# initial alpha after the k-th decay (decays happen after epochs 15, 30, 45, ...).
# The effective auxiliary-loss weight is therefore initial_alpha * 0.8**k.
# 1.000
# 0.800
# 0.640
# 0.512
# 0.410
# 0.328
# 0.262
# 0.210
# 0.168
# 0.134
# 0.107
# 0.086
# 0.069
# 0.055
# 0.044
# 0.035
# 0.028
# 0.023
# 0.018
# 0.014
# 0.012
# 0.009
# 0.007
# 0.006
# 0.005
# 0.004
# 0.003
# 0.002
# 0.002
# 0.002
# 0.001
# 0.001
# 0.001
# 0.001
# 0.001
# 0.000
# 0.000