diff --git a/create_model.py b/create_model.py index 23df2e5..1d787c0 100644 --- a/create_model.py +++ b/create_model.py @@ -1,5 +1,6 @@ from keras import Input, Model -from keras.layers import Conv2D, Conv1D, LSTMCell +from keras.initializers import RandomUniform, RandomNormal +from keras.layers import Conv2D, Conv1D from keras.optimizers import SGD from keras.utils.vis_utils import plot_model @@ -11,6 +12,18 @@ from layers.InverseGraphPropagation import InverseGraphPropagation +def get_cells(): + init = RandomUniform(minval=-0.1, maxval=0.1, seed=None) + lstm_cell1 = GraphLSTMCell(IMAGE_SHAPE[-1], + kernel_initializer=init, + recurrent_initializer=init, + bias_initializer=init) + lstm_cell2 = GraphLSTMCell(IMAGE_SHAPE[-1], + kernel_initializer=init, + recurrent_initializer=init, + bias_initializer=init) + return lstm_cell1, lstm_cell2 + def create_model(): # INPUTS image = Input(shape=IMAGE_SHAPE, name="Image", batch_shape=(TRAIN_BATCH_SIZE,) + IMAGE_SHAPE) @@ -19,10 +32,11 @@ def create_model(): neighbors = Input(shape=(N_SUPERPIXELS, N_SUPERPIXELS), name="Neighborhood", batch_shape=(TRAIN_BATCH_SIZE, N_SUPERPIXELS, N_SUPERPIXELS)) # IMAGE CONVOLUTION - conv1 = Conv2D(8, 5, padding='same')(image) - conv2 = Conv2D(16, 3, padding='same')(conv1) - conv3 = Conv2D(32, 3, padding='same')(conv2) - conv4 = Conv2D(1, 3, padding='same')(conv3) + conv_init = RandomNormal(stddev=0.001) + conv1 = Conv2D(8, 3, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(image) + conv2 = Conv2D(16, 3, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(conv1) + conv3 = Conv2D(N_FEATURES, 1, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(conv2) + conv4 = Conv2D(1, 1, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(conv3) # CONFIDENCE MAP confidence = Confidence(N_SUPERPIXELS, name="ConfidenceMap", trainable=False)([conv3, slic]) @@ -31,13 +45,12 @@ def create_model(): graph, reverse, mapping = GraphPropagation(N_SUPERPIXELS, name="GraphPath", trainable=False)([superpixels, confidence, neighbors]) # MAIN LSTM PART - lstm_cell = GraphLSTMCell(IMAGE_SHAPE[-1]) - lstm = GraphLSTM(lstm_cell, return_sequences=True, name="G-LSTM", stateful=True)([graph, superpixels, neighbors, mapping, reverse]) - # lstm = GraphLSTM(IMAGE_SHAPE[-1], return_sequences=True, name="G-LSTM", stateful=True)([graph, superpixels, neighbors, mapping]) - # lstm2 = LSTM(IMAGE_SHAPE[-1], return_sequences=True, name="G-LSTM2")(lstm) + lstm_cell1, lstm_cell2 = get_cells() + lstm1 = GraphLSTM(lstm_cell1, return_sequences=True, name="G-LSTM1", stateful=True)([graph, superpixels, neighbors, mapping, reverse]) + lstm2 = GraphLSTM(lstm_cell2, return_sequences=True, name="G-LSTM2", stateful=True)([lstm1, superpixels, neighbors, mapping, reverse]) # INVERSE GRAPH PROPAGATION - out_vertices = InverseGraphPropagation(name="InvGraphPath", trainable=False)([lstm, reverse]) + out_vertices = InverseGraphPropagation(name="InvGraphPath", trainable=False)([lstm2, reverse]) out = Conv1D(IMAGE_SHAPE[-1], 1, name="OutputConv")(out_vertices) # out = out_vertices @@ -64,7 +77,7 @@ def create_model(): plot_model(model, show_shapes=True) # OPTIMIZER - sgd = SGD(momentum=0.9, decay=0.0005) + sgd = SGD(lr=0.001, momentum=0.9, decay=0.0005, nesterov=True) model.compile(sgd, loss="mse", metrics=["acc"]) model.save(MODEL_PATH) return model diff --git a/train.py b/train.py index a3c93a4..24ca34c 100644 --- a/train.py +++ b/train.py @@ -1,16 +1,14 @@ import numpy -from skimage import io - from keras.callbacks import TerminateOnNaN, ModelCheckpoint, TensorBoard from keras.engine.saving import load_model +from skimage import io from skimage.transform import resize from config import * -from create_model import create_model +from layers.ConfidenceLayer import Confidence from layers.GraphLSTM import GraphLSTM from layers.GraphLSTMCell import GraphLSTMCell from layers.GraphPropagation import GraphPropagation -from layers.ConfidenceLayer import Confidence from layers.InverseGraphPropagation import InverseGraphPropagation from utils.utils import obtain_superpixels, get_neighbors, \ average_rgb_for_superpixels @@ -48,12 +46,6 @@ def generator(image_list, images_path, expected_images, size=1): vertices = average_rgb_for_superpixels(img, slic) neighbors = get_neighbors(slic, N_SUPERPIXELS) expected = average_rgb_for_superpixels(expected, slic) - assert not numpy.any(numpy.isnan(img)) - assert not numpy.any(numpy.isnan(expected)) - assert not numpy.any(numpy.isnan(confidence_map)) - assert not numpy.any(numpy.isnan(slic)) - assert not numpy.any(numpy.isnan(vertices)) - assert not numpy.any(numpy.isnan(neighbors)) # ADD TO BATCH batch_img += [img]