Skip to content
This repository was archived by the owner on Feb 4, 2023. It is now read-only.

Commit

Permalink
Contextual training
Browse files Browse the repository at this point in the history
  • Loading branch information
nullJaX committed Aug 13, 2019
1 parent 0978ae3 commit aea4292
Show file tree
Hide file tree
Showing 298 changed files with 303 additions and 223 deletions.
4 changes: 2 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import keras.backend as K

MODEL_PATH = "glstm.hdf5"
EPOCHS = 400
EPOCHS = 250
RAW_MODEL_PATH = "glstm_raw.hdf5"
VALIDATION_MODEL = "../data/checkpoints/model_340_0.15_0.13_0.83_0.85.hdf5"
IMAGE_SHAPE = (250, 250, 3)
SLIC_SHAPE = (IMAGE_SHAPE[0], IMAGE_SHAPE[1])
N_SUPERPIXELS = 100
N_FEATURES = 3
N_FEATURES = 4
INPUT_PATHS = 2

SLIC_SIGMA = 0
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed data/checkpoints/model_05_0.54_0.56_0.75_0.64.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_100_0.31_0.32_0.72_0.80.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_105_0.25_0.23_0.79_0.82.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_10_0.46_0.46_0.68_0.70.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_110_0.27_0.28_0.77_0.73.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_115_0.28_0.29_0.70_0.78.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_120_0.26_0.33_0.75_0.80.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_125_0.28_0.28_0.70_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_130_0.26_0.26_0.75_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_135_0.24_0.24_0.77_0.80.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_140_0.26_0.22_0.76_0.82.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_145_0.24_0.24_0.76_0.78.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_150_0.22_0.24_0.75_0.79.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_155_0.25_0.25_0.76_0.73.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_15_0.38_0.42_0.76_0.70.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_160_0.23_0.24_0.75_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_165_0.21_0.25_0.82_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_170_0.24_0.25_0.74_0.74.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_175_0.22_0.23_0.77_0.79.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_180_0.23_0.25_0.73_0.74.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_185_0.20_0.21_0.84_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_190_0.21_0.23_0.76_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_195_0.21_0.22_0.74_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_200_0.21_0.21_0.77_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_205_0.19_0.21_0.79_0.78.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_20_0.39_0.38_0.68_0.68.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_210_0.21_0.21_0.78_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_215_0.18_0.18_0.81_0.76.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_220_0.20_0.21_0.76_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_225_0.20_0.21_0.72_0.81.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_230_0.19_0.22_0.76_0.85.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_235_0.17_0.20_0.78_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_240_0.17_0.20_0.78_0.82.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_245_0.17_0.21_0.81_0.75.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_250_0.19_0.17_0.77_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_255_0.17_0.17_0.75_0.76.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_25_0.38_0.55_0.74_0.60.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_260_0.17_0.17_0.81_0.81.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_265_0.16_0.20_0.84_0.78.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_270_0.16_0.18_0.81_0.78.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_275_0.18_0.15_0.73_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_280_0.16_0.20_0.83_0.85.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_285_0.17_0.16_0.76_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_290_0.14_0.17_0.82_0.79.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_295_0.15_0.17_0.84_0.77.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_300_0.16_0.19_0.79_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_305_0.15_0.17_0.80_0.82.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_30_0.36_0.41_0.72_0.68.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_310_0.16_0.17_0.79_0.80.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_315_0.17_0.20_0.75_0.82.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_320_0.14_0.17_0.80_0.93.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_325_0.16_0.18_0.76_0.81.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_330_0.14_0.16_0.84_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_335_0.16_0.16_0.78_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_340_0.15_0.13_0.83_0.85.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_345_0.15_0.15_0.80_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_350_0.14_0.15_0.84_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_355_0.14_0.15_0.79_0.85.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_35_0.36_0.38_0.74_0.75.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_360_0.14_0.16_0.84_0.81.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_365_0.13_0.14_0.84_0.84.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_375_0.14_0.15_0.74_0.79.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_380_0.13_0.14_0.83_0.82.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_385_0.13_0.13_0.83_0.85.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_395_0.13_0.16_0.76_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_400_0.13_0.13_0.80_0.81.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_40_0.37_0.44_0.76_0.70.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_45_0.36_0.35_0.72_0.69.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_50_0.32_0.44_0.76_0.76.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_55_0.33_0.39_0.75_0.72.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_60_0.38_0.35_0.63_0.70.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_65_0.36_0.35_0.70_0.76.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_70_0.33_0.42_0.76_0.79.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_75_0.29_0.40_0.74_0.83.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_80_0.26_0.37_0.78_0.75.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_85_0.31_0.34_0.74_0.74.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_90_0.31_0.33_0.74_0.79.hdf5
Binary file not shown.
Binary file removed data/checkpoints/model_95_0.25_0.32_0.77_0.76.hdf5
Binary file not shown.
296 changes: 174 additions & 122 deletions test/evaluate.py

Large diffs are not rendered by default.

105 changes: 54 additions & 51 deletions test/generate_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from keras import Input, Model
from keras.initializers import RandomUniform
from keras.layers import Softmax, Dense, Concatenate, Dropout
from keras.optimizers import SGD
from keras.optimizers import SGD, RMSprop
from keras.utils import plot_model

from config import IMAGE_SHAPE, TRAIN_BATCH_SIZE, N_SUPERPIXELS, N_FEATURES, \
Expand All @@ -12,64 +12,67 @@


def generate_model():
init = RandomUniform(minval=-0.1, maxval=0.1, seed=None)

vertices = [Input(shape=(N_SUPERPIXELS, IMAGE_SHAPE[2]),
name="Vertices",
batch_shape=(TRAIN_BATCH_SIZE,
N_SUPERPIXELS,
IMAGE_SHAPE[2]))]
neighbors = [Input(shape=(N_SUPERPIXELS, N_SUPERPIXELS),
name="Neighborhood",
batch_shape=(TRAIN_BATCH_SIZE,
N_SUPERPIXELS,
N_SUPERPIXELS))]
inputs = [Input(shape=(N_SUPERPIXELS, IMAGE_SHAPE[2]),
name="Vertices_{0!s}".format(i),
batch_shape=(TRAIN_BATCH_SIZE,
N_SUPERPIXELS,
IMAGE_SHAPE[2]))
for i in range(INPUT_PATHS)]
indexes = [Input(shape=(N_SUPERPIXELS,),
batch_shape=(TRAIN_BATCH_SIZE, N_SUPERPIXELS),
name="Index_{0!s}".format(i), dtype="int32")
for i in range(INPUT_PATHS)]
r_indexes = [Input(shape=(N_SUPERPIXELS,),
batch_shape=(TRAIN_BATCH_SIZE, N_SUPERPIXELS),
name="ReverseIndex_{0!s}".format(i), dtype="int32")
for k in [2, 3, 5, 7, 9]:
INPUT_PATHS = k
init = RandomUniform(minval=-0.1, maxval=0.1, seed=None)

vertices = [Input(shape=(N_SUPERPIXELS, IMAGE_SHAPE[2]),
name="Vertices",
batch_shape=(TRAIN_BATCH_SIZE,
N_SUPERPIXELS,
IMAGE_SHAPE[2]))]
neighbors = [Input(shape=(N_SUPERPIXELS, N_SUPERPIXELS),
name="Neighborhood",
batch_shape=(TRAIN_BATCH_SIZE,
N_SUPERPIXELS,
N_SUPERPIXELS))]
inputs = [Input(shape=(N_SUPERPIXELS, IMAGE_SHAPE[2]),
name="Vertices_{0!s}".format(i),
batch_shape=(TRAIN_BATCH_SIZE,
N_SUPERPIXELS,
IMAGE_SHAPE[2]))
for i in range(INPUT_PATHS)]
indexes = [Input(shape=(N_SUPERPIXELS,),
batch_shape=(TRAIN_BATCH_SIZE, N_SUPERPIXELS),
name="Index_{0!s}".format(i), dtype="int32")
for i in range(INPUT_PATHS)]
r_indexes = [Input(shape=(N_SUPERPIXELS,),
batch_shape=(TRAIN_BATCH_SIZE, N_SUPERPIXELS),
name="ReverseIndex_{0!s}".format(i), dtype="int32")
for i in range(INPUT_PATHS)]

cells = [GraphLSTMCell(N_FEATURES, kernel_initializer=init,
recurrent_initializer=init,
bias_initializer=init) for _ in range(INPUT_PATHS)]
lstms = [GraphLSTM(cells[i], return_sequences=True,
name="G-LSTM_{0!s}".format(i), stateful=False)
([inputs[i], vertices[0], neighbors[0], indexes[i], r_indexes[i]])
for i in range(INPUT_PATHS)]

cells = [GraphLSTMCell(N_FEATURES, kernel_initializer=init,
recurrent_initializer=init,
bias_initializer=init) for _ in range(INPUT_PATHS)]
lstms = [GraphLSTM(cells[i], return_sequences=True,
name="G-LSTM_{0!s}".format(i), stateful=False)
([inputs[i], vertices[0], neighbors[0], indexes[i], r_indexes[i]])
for i in range(INPUT_PATHS)]
inverse = [InverseGraphPropagation()([lstms[i], r_indexes[i]]) for i in range(INPUT_PATHS)]

inverse = [InverseGraphPropagation()([lstms[i], r_indexes[i]]) for i in range(INPUT_PATHS)]
concat = Concatenate(axis=-1)(inverse)
# drop0 = Dropout(0.5)(lstms[0])
# d1 = Dense(int(INPUT_PATHS * N_FEATURES))(drop0)
# drop1 = Dropout(0.4)(concat)
# d2 = Dense(int(INPUT_PATHS * N_FEATURES))(drop1)
d3 = Dense(N_FEATURES)(concat)

concat = Concatenate(axis=-1)(inverse)
# drop0 = Dropout(0.5)(lstms[0])
# d1 = Dense(int(INPUT_PATHS * N_FEATURES))(drop0)
# drop1 = Dropout(0.4)(concat)
# d2 = Dense(int(INPUT_PATHS * N_FEATURES))(drop1)
d3 = Dense(N_FEATURES)(concat)
soft = Softmax()(d3)

soft = Softmax()(d3)
model = Model(inputs=vertices + neighbors + inputs + indexes + r_indexes,
outputs=[soft])

model = Model(inputs=vertices + neighbors + inputs + indexes + r_indexes,
outputs=[soft])
model.summary()

model.summary()
# PLOT
# plot_model(model, show_shapes=True)

# PLOT
plot_model(model, show_shapes=True)

# OPTIMIZER
sgd = SGD(lr=0.01, momentum=0.9, decay=0.005, nesterov=False)
model.compile(sgd, loss="categorical_crossentropy", metrics=["acc"])
model.save(RAW_MODEL_PATH)
# OPTIMIZER
# sgd = SGD(lr=0.001, momentum=0.9, decay=0.005, nesterov=False)
rms = RMSprop()
model.compile(rms, loss="categorical_crossentropy", metrics=["acc"])
model.save("glstm_raw{0!s}.hdf5".format(k))


# return model
Expand Down
Binary file removed test/glstm.hdf5
Binary file not shown.
Binary file not shown.
Binary file added test/glstm3.hdf5
Binary file not shown.
Binary file added test/glstm5.hdf5
Binary file not shown.
Binary file added test/glstm7.hdf5
Binary file not shown.
Binary file added test/glstm9.hdf5
Binary file not shown.
Binary file renamed test/glstm_raw.hdf5 → test/glstm_raw2.hdf5
Binary file not shown.
Binary file added test/glstm_raw3.hdf5
Binary file not shown.
Binary file added test/glstm_raw5.hdf5
Binary file not shown.
Binary file added test/glstm_raw7.hdf5
Binary file not shown.
Binary file added test/glstm_raw9.hdf5
Binary file not shown.
Binary file modified test/model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed test/test_0.png
Binary file not shown.
Binary file added test/test_0_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_0_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_0_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_0_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_0_9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed test/test_1.png
Binary file not shown.
Binary file removed test/test_10.png
Binary file not shown.
Binary file added test/test_10_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_10_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_10_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_10_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_10_9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed test/test_11.png
Binary file not shown.
Binary file added test/test_11_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_11_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_11_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_11_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_11_9.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed test/test_12.png
Binary file not shown.
Binary file added test/test_12_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_12_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_12_5.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_12_7.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added test/test_12_9.png
Binary file removed test/test_13.png
Diff not rendered.
Binary file added test/test_13_2.png
Binary file added test/test_13_3.png
Binary file added test/test_13_5.png
Binary file added test/test_13_7.png
Binary file added test/test_13_9.png
Binary file removed test/test_14.png
Diff not rendered.
Binary file added test/test_14_2.png
Binary file added test/test_14_3.png
Binary file added test/test_14_5.png
Binary file added test/test_14_7.png
Binary file added test/test_14_9.png
Binary file removed test/test_15.png
Diff not rendered.
Binary file added test/test_15_2.png
Binary file added test/test_15_3.png
Binary file added test/test_15_5.png
Binary file added test/test_15_7.png
Binary file added test/test_15_9.png
Binary file removed test/test_16.png
Diff not rendered.
Binary file added test/test_16_2.png
Binary file added test/test_16_3.png
Binary file added test/test_16_5.png
Binary file added test/test_16_7.png
Binary file added test/test_16_9.png
Binary file removed test/test_17.png
Diff not rendered.
Binary file added test/test_17_2.png
Binary file added test/test_17_3.png
Binary file added test/test_17_5.png
Binary file added test/test_17_7.png
Binary file added test/test_17_9.png
Binary file removed test/test_18.png
Diff not rendered.
Binary file added test/test_18_2.png
Binary file added test/test_18_3.png
Binary file added test/test_18_5.png
Binary file added test/test_18_7.png
Binary file added test/test_18_9.png
Binary file removed test/test_19.png
Diff not rendered.
Binary file added test/test_19_2.png
Binary file added test/test_19_3.png
Binary file added test/test_19_5.png
Binary file added test/test_19_7.png
Binary file added test/test_19_9.png
Binary file added test/test_1_2.png
Binary file added test/test_1_3.png
Binary file added test/test_1_5.png
Binary file added test/test_1_7.png
Binary file added test/test_1_9.png
Binary file removed test/test_2.png
Diff not rendered.
Binary file removed test/test_20.png
Diff not rendered.
Binary file added test/test_20_2.png
Binary file added test/test_20_3.png
Binary file added test/test_20_5.png
Binary file added test/test_20_7.png
Binary file added test/test_20_9.png
Binary file removed test/test_21.png
Diff not rendered.
Binary file added test/test_21_2.png
Binary file added test/test_21_3.png
Binary file added test/test_21_5.png
Binary file added test/test_21_7.png
Binary file added test/test_21_9.png
Binary file removed test/test_22.png
Diff not rendered.
Binary file added test/test_22_2.png
Binary file added test/test_22_3.png
Binary file added test/test_22_5.png
Binary file added test/test_22_7.png
Binary file added test/test_22_9.png
Binary file removed test/test_23.png
Diff not rendered.
Binary file added test/test_23_2.png
Binary file added test/test_23_3.png
Binary file added test/test_23_5.png
Binary file added test/test_23_7.png
Binary file added test/test_23_9.png
Binary file removed test/test_24.png
Diff not rendered.
Binary file added test/test_24_2.png
Binary file added test/test_24_3.png
Binary file added test/test_24_5.png
Binary file added test/test_24_7.png
Binary file added test/test_24_9.png
Binary file removed test/test_25.png
Diff not rendered.
Binary file added test/test_25_2.png
Binary file added test/test_25_3.png
Binary file added test/test_25_5.png
Binary file added test/test_25_7.png
Binary file added test/test_25_9.png
Binary file removed test/test_26.png
Diff not rendered.
Binary file added test/test_26_2.png
Binary file added test/test_26_3.png
Binary file added test/test_26_5.png
Binary file added test/test_26_7.png
Binary file added test/test_26_9.png
Binary file removed test/test_27.png
Diff not rendered.
Binary file added test/test_27_2.png
Binary file added test/test_27_3.png
Binary file added test/test_27_5.png
Binary file added test/test_27_7.png
Binary file added test/test_27_9.png
Binary file removed test/test_28.png
Diff not rendered.
Binary file added test/test_28_2.png
Binary file added test/test_28_3.png
Binary file added test/test_28_5.png
Binary file added test/test_28_7.png
Binary file added test/test_28_9.png
Binary file removed test/test_29.png
Diff not rendered.
Binary file added test/test_29_2.png
Binary file added test/test_29_3.png
Binary file added test/test_29_5.png
Binary file added test/test_29_7.png
Binary file added test/test_29_9.png
Binary file added test/test_2_2.png
Binary file added test/test_2_3.png
Binary file added test/test_2_5.png
Binary file added test/test_2_7.png
Binary file added test/test_2_9.png
Binary file removed test/test_3.png
Diff not rendered.
Binary file removed test/test_30.png
Diff not rendered.
Binary file added test/test_30_2.png
Binary file added test/test_30_3.png
Binary file added test/test_30_5.png
Binary file added test/test_30_7.png
Binary file added test/test_30_9.png
Binary file removed test/test_31.png
Diff not rendered.
Binary file added test/test_31_2.png
Binary file added test/test_31_3.png
Binary file added test/test_31_5.png
Binary file added test/test_31_7.png
Binary file added test/test_31_9.png
Binary file removed test/test_32.png
Diff not rendered.
Binary file added test/test_32_2.png
Binary file added test/test_32_3.png
Binary file added test/test_32_5.png
Binary file added test/test_32_7.png
Binary file added test/test_32_9.png
Binary file added test/test_3_2.png
Binary file added test/test_3_3.png
Binary file added test/test_3_5.png
Binary file added test/test_3_7.png
Binary file added test/test_3_9.png
Binary file removed test/test_4.png
Diff not rendered.
Binary file added test/test_4_2.png
Binary file added test/test_4_3.png
Binary file added test/test_4_5.png
Binary file added test/test_4_7.png
Binary file added test/test_4_9.png
Binary file removed test/test_5.png
Diff not rendered.
Binary file added test/test_5_2.png
Binary file added test/test_5_3.png
Binary file added test/test_5_5.png
Binary file added test/test_5_7.png
Binary file added test/test_5_9.png
Binary file removed test/test_6.png
Diff not rendered.
Binary file added test/test_6_2.png
Binary file added test/test_6_3.png
Binary file added test/test_6_5.png
Binary file added test/test_6_7.png
Binary file added test/test_6_9.png
Binary file removed test/test_7.png
Diff not rendered.
Binary file added test/test_7_2.png
Binary file added test/test_7_3.png
Binary file added test/test_7_5.png
Binary file added test/test_7_7.png
Binary file added test/test_7_9.png
Binary file removed test/test_8.png
Diff not rendered.
Binary file added test/test_8_2.png
Binary file added test/test_8_3.png
Binary file added test/test_8_5.png
Binary file added test/test_8_7.png
Binary file added test/test_8_9.png
Binary file removed test/test_9.png
Diff not rendered.
Binary file added test/test_9_2.png
Binary file added test/test_9_3.png
Binary file added test/test_9_5.png
Binary file added test/test_9_7.png
Binary file added test/test_9_9.png
118 changes: 70 additions & 48 deletions test/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import random
import matplotlib.pyplot as plt
import numpy
Expand All @@ -15,24 +16,25 @@
get_superpixels_index_for_hot_areas


def init_callbacks():
def init_callbacks(k):
terminator = TerminateOnNaN()
checkpointer = ModelCheckpoint(
"../data/checkpoints/model_{epoch:02d}_{loss:.2f}_{val_loss:.2f}_{acc:.2f}_{val_acc:.2f}.hdf5",
"../data/checkpoints/model"+str(k)+"_{epoch:02d}_{loss:.2f}_{val_loss:.2f}_{acc:.2f}_{val_acc:.2f}.hdf5",
monitor="val_loss", save_weights_only=False, mode="min", period=5)
return [terminator, checkpointer]


def generator(image_list, images_path, expected_images, size=1):
def generator(image_list, images_path, expected_images, k, size=1):
while True:
batch_names = numpy.random.choice(image_list, size=size)

batch_vertices = []
batch_neighbors = []
batch_expected = []

batch_inputs = [[] for _ in range(INPUT_PATHS)]
batch_indexes = [[] for _ in range(INPUT_PATHS)]
batch_r_indexes = [[] for _ in range(INPUT_PATHS)]
batch_inputs = [[] for _ in range(k)]
batch_indexes = [[] for _ in range(k)]
batch_r_indexes = [[] for _ in range(k)]

for image_name in batch_names:
# LOAD IMAGES
Expand All @@ -43,8 +45,13 @@ def generator(image_list, images_path, expected_images, size=1):
slic = obtain_superpixels(img, N_SUPERPIXELS, SLIC_SIGMA)
vertices = average_rgb_for_superpixels(img, slic)
neighbors = get_neighbors(slic, N_SUPERPIXELS)

expected = copy.deepcopy(vertices)
for i in expected:
i.append(0.0)

areas = get_superpixels_index_for_hot_areas(slic)
for paths in range(INPUT_PATHS):
for paths in range(k):
vertex_index = areas[paths]
path, mapping, r_mapping = sort_values(vertices, neighbors,
vertex_index,
Expand All @@ -53,12 +60,26 @@ def generator(image_list, images_path, expected_images, size=1):
batch_indexes[paths].append(mapping)
batch_r_indexes[paths].append(r_mapping)

green_near_red_indices = []
for i, v in enumerate(vertices):
if v[0] != 1.0:
continue
neighborhood_indexes = numpy.where(neighbors[i] == 1)[0]
if any(vertices[n][1] == 1.0 for n in neighborhood_indexes):
green_near_red_indices.append(i)

for i in green_near_red_indices:
expected[i] = [0.0] * len(expected[i])
expected[i][-1] = 1.0

# ADD TO BATCH
batch_vertices += [vertices]
batch_neighbors += [neighbors]
batch_expected += [expected]

batch_vertices = numpy.array(batch_vertices)
batch_neighbors = numpy.array(batch_neighbors)
batch_expected = numpy.array(batch_expected)
batch_inputs = [numpy.array(i) for i in batch_inputs]
batch_indexes = [numpy.array(i) for i in batch_indexes]
batch_r_indexes = [numpy.array(i) for i in batch_r_indexes]
Expand All @@ -67,47 +88,48 @@ def generator(image_list, images_path, expected_images, size=1):
batch_inputs +
batch_indexes +
batch_r_indexes,
[batch_vertices])
[batch_expected])


if __name__ == '__main__':
callbacks = init_callbacks()
val_image_list = image_list[:VALIDATION_ELEMS]
train_image_list = image_list[VALIDATION_ELEMS:]

model = load_model(RAW_MODEL_PATH,
custom_objects={'GraphLSTM': GraphLSTM,
'GraphLSTMCell': GraphLSTMCell,
'InverseGraphPropagation': InverseGraphPropagation})
# model = create_model()
history = model.fit_generator(
generator(train_image_list, "../data/test/", "../data/test/",
TRAIN_BATCH_SIZE),
steps_per_epoch=numpy.ceil(TRAIN_ELEMS / TRAIN_BATCH_SIZE),
epochs=EPOCHS,
verbose=1,
callbacks=callbacks,
validation_data=generator(val_image_list, "../data/test/",
"../data/test/", VALIDATION_BATCH_SIZE),
validation_steps=numpy.ceil(VALIDATION_ELEMS / VALIDATION_BATCH_SIZE),
max_queue_size=10,
shuffle=True)
model.save(MODEL_PATH)

plt.plot(history.history['acc'], color="#FF3864")
plt.plot(history.history['val_acc'], color="#261447")
plt.title('Accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
plt.savefig('model_accuracy.png')
# summarize history for loss
plt.plot(history.history['loss'], color="#FF3864")
plt.plot(history.history['val_loss'], color="#261447")
plt.title('Loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
plt.savefig('model_loss.png')
for k in [2, 3, 5, 7, 9]:
callbacks = init_callbacks(k)
val_image_list = image_list[:VALIDATION_ELEMS]
train_image_list = image_list[VALIDATION_ELEMS:]

model = load_model("glstm_raw{0!s}.hdf5".format(k),
custom_objects={'GraphLSTM': GraphLSTM,
'GraphLSTMCell': GraphLSTMCell,
'InverseGraphPropagation': InverseGraphPropagation})
# model = create_model()
history = model.fit_generator(
generator(train_image_list, "../data/test/", "../data/test/", k,
TRAIN_BATCH_SIZE),
steps_per_epoch=numpy.ceil(TRAIN_ELEMS / TRAIN_BATCH_SIZE),
epochs=EPOCHS,
verbose=1,
callbacks=callbacks,
validation_data=generator(val_image_list, "../data/test/",
"../data/test/", k, VALIDATION_BATCH_SIZE),
validation_steps=numpy.ceil(VALIDATION_ELEMS / VALIDATION_BATCH_SIZE),
max_queue_size=10,
shuffle=True)
model.save("glstm{0!s}.hdf5".format(k))

plt.plot(history.history['acc'], color="#FF3864")
plt.plot(history.history['val_acc'], color="#261447")
plt.title('Accuracy - {0!s} paths'.format(k))
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
plt.savefig('model_accuracy_{0!s}_paths.png'.format(k))
# summarize history for loss
plt.plot(history.history['loss'], color="#FF3864")
plt.plot(history.history['val_loss'], color="#261447")
plt.title('Loss - {0!s} paths'.format(k))
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()
plt.savefig('model_loss_{0!s}_paths.png'.format(k))
3 changes: 3 additions & 0 deletions utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def average_rgb_for_superpixels(image, segments):
av_local = []
for c in range(filtered.shape[-1]):
av_local.append(np.sum(filtered[:, :, c], axis=1).sum() / non_zero_pixels_amount)
i = av_local.index(max(av_local))
av_local = [0.0] * len(av_local)
av_local[i] = 1.0
averages.append(av_local)
while len(averages) != N_SUPERPIXELS:
averages.append([0.0] * IMAGE_SHAPE[-1])
Expand Down

0 comments on commit aea4292

Please sign in to comment.