Skip to content
This repository was archived by the owner on Feb 4, 2023. It is now read-only.

Commit

Permalink
Fine tuned network
Browse files Browse the repository at this point in the history
  • Loading branch information
nullJaX committed Mar 28, 2019
1 parent e53a94a commit 078e0ff
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
35 changes: 24 additions & 11 deletions create_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from keras import Input, Model
from keras.layers import Conv2D, Conv1D, LSTMCell
from keras.initializers import RandomUniform, RandomNormal
from keras.layers import Conv2D, Conv1D
from keras.optimizers import SGD
from keras.utils.vis_utils import plot_model

Expand All @@ -11,6 +12,18 @@
from layers.InverseGraphPropagation import InverseGraphPropagation


def get_cells():
init = RandomUniform(minval=-0.1, maxval=0.1, seed=None)
lstm_cell1 = GraphLSTMCell(IMAGE_SHAPE[-1],
kernel_initializer=init,
recurrent_initializer=init,
bias_initializer=init)
lstm_cell2 = GraphLSTMCell(IMAGE_SHAPE[-1],
kernel_initializer=init,
recurrent_initializer=init,
bias_initializer=init)
return lstm_cell1, lstm_cell2

def create_model():
# INPUTS
image = Input(shape=IMAGE_SHAPE, name="Image", batch_shape=(TRAIN_BATCH_SIZE,) + IMAGE_SHAPE)
Expand All @@ -19,10 +32,11 @@ def create_model():
neighbors = Input(shape=(N_SUPERPIXELS, N_SUPERPIXELS), name="Neighborhood", batch_shape=(TRAIN_BATCH_SIZE, N_SUPERPIXELS, N_SUPERPIXELS))

# IMAGE CONVOLUTION
conv1 = Conv2D(8, 5, padding='same')(image)
conv2 = Conv2D(16, 3, padding='same')(conv1)
conv3 = Conv2D(32, 3, padding='same')(conv2)
conv4 = Conv2D(1, 3, padding='same')(conv3)
conv_init = RandomNormal(stddev=0.001)
conv1 = Conv2D(8, 3, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(image)
conv2 = Conv2D(16, 3, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(conv1)
conv3 = Conv2D(N_FEATURES, 1, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(conv2)
conv4 = Conv2D(1, 1, padding='same', kernel_initializer=conv_init, bias_initializer=conv_init)(conv3)

# CONFIDENCE MAP
confidence = Confidence(N_SUPERPIXELS, name="ConfidenceMap", trainable=False)([conv3, slic])
Expand All @@ -31,13 +45,12 @@ def create_model():
graph, reverse, mapping = GraphPropagation(N_SUPERPIXELS, name="GraphPath", trainable=False)([superpixels, confidence, neighbors])

# MAIN LSTM PART
lstm_cell = GraphLSTMCell(IMAGE_SHAPE[-1])
lstm = GraphLSTM(lstm_cell, return_sequences=True, name="G-LSTM", stateful=True)([graph, superpixels, neighbors, mapping, reverse])
# lstm = GraphLSTM(IMAGE_SHAPE[-1], return_sequences=True, name="G-LSTM", stateful=True)([graph, superpixels, neighbors, mapping])
# lstm2 = LSTM(IMAGE_SHAPE[-1], return_sequences=True, name="G-LSTM2")(lstm)
lstm_cell1, lstm_cell2 = get_cells()
lstm1 = GraphLSTM(lstm_cell1, return_sequences=True, name="G-LSTM1", stateful=True)([graph, superpixels, neighbors, mapping, reverse])
lstm2 = GraphLSTM(lstm_cell2, return_sequences=True, name="G-LSTM2", stateful=True)([lstm1, superpixels, neighbors, mapping, reverse])

# INVERSE GRAPH PROPAGATION
out_vertices = InverseGraphPropagation(name="InvGraphPath", trainable=False)([lstm, reverse])
out_vertices = InverseGraphPropagation(name="InvGraphPath", trainable=False)([lstm2, reverse])

out = Conv1D(IMAGE_SHAPE[-1], 1, name="OutputConv")(out_vertices)
# out = out_vertices
Expand All @@ -64,7 +77,7 @@ def create_model():
plot_model(model, show_shapes=True)

# OPTIMIZER
sgd = SGD(momentum=0.9, decay=0.0005)
sgd = SGD(lr=0.001, momentum=0.9, decay=0.0005, nesterov=True)
model.compile(sgd, loss="mse", metrics=["acc"])
model.save(MODEL_PATH)
return model
Expand Down
12 changes: 2 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import numpy
from skimage import io

from keras.callbacks import TerminateOnNaN, ModelCheckpoint, TensorBoard
from keras.engine.saving import load_model
from skimage import io
from skimage.transform import resize

from config import *
from create_model import create_model
from layers.ConfidenceLayer import Confidence
from layers.GraphLSTM import GraphLSTM
from layers.GraphLSTMCell import GraphLSTMCell
from layers.GraphPropagation import GraphPropagation
from layers.ConfidenceLayer import Confidence
from layers.InverseGraphPropagation import InverseGraphPropagation
from utils.utils import obtain_superpixels, get_neighbors, \
average_rgb_for_superpixels
Expand Down Expand Up @@ -48,12 +46,6 @@ def generator(image_list, images_path, expected_images, size=1):
vertices = average_rgb_for_superpixels(img, slic)
neighbors = get_neighbors(slic, N_SUPERPIXELS)
expected = average_rgb_for_superpixels(expected, slic)
assert not numpy.any(numpy.isnan(img))
assert not numpy.any(numpy.isnan(expected))
assert not numpy.any(numpy.isnan(confidence_map))
assert not numpy.any(numpy.isnan(slic))
assert not numpy.any(numpy.isnan(vertices))
assert not numpy.any(numpy.isnan(neighbors))

# ADD TO BATCH
batch_img += [img]
Expand Down

0 comments on commit 078e0ff

Please sign in to comment.