main_segmentation.py
from __future__ import print_function
import os
from src.utils.HED_data_parser import DataParser
from src.networks.hed import hed
from keras.utils import plot_model
from keras import backend as K
from keras import callbacks
import numpy as np

def generate_minibatches(dataParser, train=True):
    """Yield (image, edge-map) mini-batches indefinitely for Keras' fit_generator."""
    while True:
        if train:
            batch_ids = np.random.choice(dataParser.training_ids, dataParser.batch_size_train)
        else:
            batch_ids = np.random.choice(dataParser.validation_ids, dataParser.batch_size_train*2)
        ims, ems, _ = dataParser.get_batch(batch_ids)
        # the HED model has six outputs (five side outputs plus the fused output),
        # so the same edge map is supplied as the target for each of them
        yield (ims, [ems, ems, ems, ems, ems, ems])


######
if __name__ == "__main__":
    # params
    model_name = 'HEDSeg'
    model_dir = os.path.join('checkpoints', model_name)
    csv_fn = os.path.join(model_dir, 'train_log.csv')
    # checkpoint filenames encode the epoch and validation loss, e.g. checkpoint.07-0.12.hdf5
    checkpoint_fn = os.path.join(model_dir, 'checkpoint.{epoch:02d}-{val_loss:.2f}.hdf5')
    batch_size_train = 10

    # environment
    K.set_image_data_format('channels_last')
    # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    # prepare data
    dataParser = DataParser(batch_size_train)
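
    # Optional sanity check (a sketch, assuming DataParser exposes training_ids,
    # batch_size_train and get_batch exactly as used in generate_minibatches above);
    # uncomment to inspect one mini-batch before committing to a long training run:
    # sample_ims, sample_targets = next(generate_minibatches(dataParser))
    # print(sample_ims.shape, [t.shape for t in sample_targets])
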
    # model
    model = hed()
    plot_model(model, to_file=os.path.join(model_dir, 'model.png'), show_shapes=True)
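    # A plain-text summary can be handy when graphviz/pydot (required by plot_model)
    # is not installed; model.summary() is standard Keras and prints the architecture:
    # model.summary()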
    # callbacks: keep the best model (by val_loss), log per-epoch metrics to CSV,
    # and write TensorBoard summaries into the checkpoint directory
    checkpointer = callbacks.ModelCheckpoint(filepath=checkpoint_fn, verbose=1, save_best_only=True)
    csv_logger = callbacks.CSVLogger(csv_fn, append=True, separator=';')
    tensorboard = callbacks.TensorBoard(log_dir=model_dir, histogram_freq=0, batch_size=10,
                                        write_graph=False, write_grads=True, write_images=False)
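    # An EarlyStopping callback could also be appended to the callback list below to cut
    # the very long epoch budget short once val_loss stops improving (a sketch; the
    # patience value is arbitrary and not part of the original script):
    # early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=20, verbose=1)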
    # training
    train_history = model.fit_generator(
        generate_minibatches(dataParser),
        steps_per_epoch=dataParser.steps_per_epoch,
        epochs=2048*2,
        validation_data=generate_minibatches(dataParser, train=False),
        validation_steps=dataParser.validation_steps,
        callbacks=[checkpointer, csv_logger, tensorboard])
    # fit_generator returns a History object; its .history dict holds the per-epoch metrics
    print(train_history.history)
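
# To run inference later, the saved weights can be restored into a freshly built network.
# This is a sketch, not part of the original script: it assumes the path of the best
# checkpoint is known and that hed() rebuilds exactly the architecture trained above
# (model.load_weights avoids re-registering any custom losses that keras.models.load_model
# would require via custom_objects).
#
#   model = hed()
#   model.load_weights(os.path.join('checkpoints', 'HEDSeg', 'checkpoint.XX-X.XX.hdf5'))
#   predictions = model.predict(ims)  # ims: a batch shaped like the training images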