-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathunet.py
173 lines (124 loc) · 5.85 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Build and train a 2D U-Net segmentation model (one fold, one plane, slice-wise).
"""
import numpy as np
import sys
import subprocess
import argparse
import os
from keras.models import Model
from keras.layers import Input, Activation, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose, ZeroPadding2D, add
from keras.optimizers import Adam, SGD
from keras.callbacks import ModelCheckpoint, CSVLogger
from keras import backend as K
from keras import losses
import tensorflow as tf
import matplotlib.pyplot as plt
import pandas as pd
import csv
from utils import *
from data import load_train_data
K.set_image_data_format('channels_last') # Tensorflow dimension ordering
# ----- paths setting -----
# CLI: unet.py DATA_PATH FOLD PLANE EPOCHS INIT_LR  (parsed at import time)
data_path = sys.argv[1] + "/"
model_path = data_path + "models/"  # checkpoint .h5 files
log_path = data_path + "logs/"      # per-epoch CSV training logs
# ----- params for training and testing -----
batch_size = 1
cur_fold = sys.argv[2]    # which of the 4 cross-validation folds ('0'..'3')
plane = sys.argv[3]       # slicing plane: 'X', 'Y' or 'Z'
epoch = int(sys.argv[4])  # number of training epochs
init_lr = float(sys.argv[5])  # initial Adam learning rate
# ----- Dice Coefficient and cost function for training -----
# Laplace-style smoothing term: avoids 0/0 on empty masks and softens gradients.
smooth = 1.
def dice_coef(y_true, y_pred):
    """Soft Dice similarity between a ground-truth mask and a prediction.

    Both tensors are flattened; the coefficient is
    (2 * |A.B| + smooth) / (|A| + |B| + smooth), in (0, 1], higher is better.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2.0 * overlap + smooth) / (total + smooth)
def dice_coef_loss(y_true, y_pred):
    """Training loss: negated Dice coefficient (minimizing it maximizes Dice)."""
    score = dice_coef(y_true, y_pred)
    return -score
def get_unet(img_shape, flt=64, pool_size=(2, 2, 2), init_lr=1.0e-5):
    """Build and compile the 2D U-Net.

    Parameters
    ----------
    img_shape : tuple of int
        (img_rows, img_cols) spatial size of the input slices.
    flt : int
        filters in the first encoder level; doubled at each downsampling
        step (flt -> 2*flt -> 4*flt -> 8*flt -> 16*flt at the bottleneck).
    pool_size : tuple
        UNUSED -- pooling is hard-wired to (2, 2) below; the parameter is
        kept only so existing call sites keep working.
    init_lr : float
        initial learning rate for the Adam optimizer.

    Returns
    -------
    keras.models.Model
        compiled model: single-channel sigmoid output, Dice loss,
        Dice-coefficient metric.
    """
    # NOTE(review): the original signature used Python-2-only tuple
    # unpacking, def get_unet((img_rows, img_cols), ...), removed in
    # Python 3 by PEP 3113.  Unpacking inside the body keeps every caller
    # (which already passes a (rows, cols) tuple) unchanged.
    img_rows, img_cols = img_shape
    print("start building NN")
    inputs = Input((img_rows, img_cols, 1))

    # ---- encoder: two 3x3 relu convs per level, then 2x2 max-pool ----
    conv1 = Conv2D(flt, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # ---- bottleneck ----
    conv5 = Conv2D(flt*16, (3, 3), activation='relu', padding='same')(pool4)
    # NOTE(review): second bottleneck conv narrows to flt*8 (the classic
    # U-Net keeps flt*16); it matches the flt*8 transpose conv that follows,
    # so presumably intentional -- preserved as-is.
    conv5 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(conv5)

    # ---- decoder: 2x2 transpose-conv upsample, concat encoder skip (axis=3
    # is channels under 'channels_last'), then two 3x3 relu convs ----
    up6 = concatenate([Conv2DTranspose(flt*8, (2, 2), strides=(2, 2), padding='same')(conv5), conv4], axis=3)
    conv6 = Conv2D(flt*8, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([Conv2DTranspose(flt*4, (2, 2), strides=(2, 2), padding='same')(conv6), conv3], axis=3)
    conv7 = Conv2D(flt*4, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([Conv2DTranspose(flt*2, (2, 2), strides=(2, 2), padding='same')(conv7), conv2], axis=3)
    conv8 = Conv2D(flt*2, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([Conv2DTranspose(flt, (2, 2), strides=(2, 2), padding='same')(conv8), conv1], axis=3)
    conv9 = Conv2D(flt, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(flt, (3, 3), activation='relu', padding='same')(conv9)

    # 1x1 conv + sigmoid -> per-pixel foreground probability map
    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)

    model = Model(inputs=[inputs], outputs=[conv10])
    model.compile(optimizer=Adam(lr=init_lr), loss=dice_coef_loss, metrics=[dice_coef])
    return model
def train(fold, plane, batch_size, nb_epoch,init_lr):
    """
    train an Unet model with data from load_train_data()

    Builds the network sized to the loaded slices, then fits it with a
    CSV metrics logger and a periodic weight checkpoint.

    Parameters
    ----------
    fold : string
        which fold is experimenting in 4-fold. It should be one of 0/1/2/3
    plane : char
        which plane is experimenting. It is from 'X'/'Y'/'Z'
    batch_size : int
        size of mini-batch
    nb_epoch : int
        number of epochs to train NN
    init_lr : float
        initial learning rate
    """
    # Python-2 print statements throughout (module-wide style).
    print "number of epoch: ", nb_epoch
    print "learning rate: ", init_lr
    # --------------------- load and preprocess training data -----------------
    print '-'*80
    print ' Loading and preprocessing train data...'
    print '-'*80
    imgs_train, imgs_mask_train = load_train_data(fold, plane)
    # Spatial size of the slices; passed to get_unet so the graph matches the data.
    imgs_row = imgs_train.shape[1]
    imgs_col = imgs_train.shape[2]
    # preprocess() comes from the star import of `utils`; presumably it adds
    # the trailing channel axis / normalizes -- TODO confirm against utils.py.
    imgs_train = preprocess(imgs_train)
    imgs_mask_train = preprocess(imgs_mask_train)
    imgs_train = imgs_train.astype('float32')
    imgs_mask_train = imgs_mask_train.astype('float32')
    # ---------------------- Create, compile, and train model ------------------------
    print '-'*80
    print ' Creating and compiling model...'
    print '-'*80
    model = get_unet((imgs_row, imgs_col), pool_size=(2, 2, 2), init_lr=init_lr)
    print model.summary()
    print '-'*80
    print ' Fitting model...'
    print '-'*80
    # NOTE(review): `ver` ends in ".csv", so the checkpoint below is written
    # as "...csv.h5" -- any loader must use the same double extension.
    # Also note the filename uses the module-level cur_fold/epoch globals
    # rather than the fold/nb_epoch arguments (equivalent when called from
    # __main__, which passes those globals).
    ver = 'unet_fd%s_%s_ep%s_lr%s.csv'%(cur_fold, plane, epoch, init_lr)
    csv_logger = CSVLogger(log_path + ver)  # per-epoch loss/metric log
    # Overwrite the single checkpoint file every 10 epochs regardless of
    # loss (save_best_only=False), monitoring training loss.
    model_checkpoint = ModelCheckpoint(model_path + ver + ".h5",
                                       monitor='loss',
                                       save_best_only=False,
                                       period=10)
    history = model.fit(imgs_train, imgs_mask_train,
                        batch_size= batch_size, epochs= nb_epoch, verbose=1, shuffle=True,
                        callbacks=[model_checkpoint, csv_logger])
if __name__ == "__main__":
    # CLI entry point: all arguments were already parsed from sys.argv at
    # module import time (data_path, cur_fold, plane, epoch, init_lr).
    train(cur_fold, plane, batch_size, epoch, init_lr)
    print "training done"