Skip to content
This repository was archived by the owner on Feb 4, 2023. It is now read-only.

Commit

Permalink
Validation batches fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
nullJaX committed Mar 28, 2019
1 parent 599fba6 commit e53a94a
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 76 deletions.
5 changes: 2 additions & 3 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
MODEL_PATH = "glstm.hdf5"
# MODEL_PATH = "./data/checkpoints/model_10_0.88.hdf5"

IMAGE_SHAPE = (500, 500, 3)
IMAGE_SHAPE = (250, 250, 3)
SLIC_SHAPE = (IMAGE_SHAPE[0], IMAGE_SHAPE[1])
N_SUPERPIXELS = 1000
N_FEATURES = 5
Expand All @@ -19,7 +19,6 @@


TRAIN_ELEMS = 1464
TRAIN_BATCH_SIZE = 2
TRAIN_BATCH_SIZE = VALIDATION_BATCH_SIZE = PREDICT_BATCH_SIZE = 2

VALIDATION_ELEMS = 2913
VALIDATION_BATCH_SIZE = 2
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import matplotlib.pyplot as plt

from config import N_SUPERPIXELS
from utils.utils import obtain_superpixels, get_confidence_map, get_neighbors, average_rgb_for_superpixels, sort_values
from utils.utils import obtain_superpixels, get_neighbors, average_rgb_for_superpixels

if __name__ == '__main__':
image = io.imread("lena.png")
Expand Down
38 changes: 0 additions & 38 deletions utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import operator
import random

from cv2 import cv2
import numpy as np
from skimage import img_as_float
Expand Down Expand Up @@ -29,10 +26,6 @@ def obtain_superpixels(image, n_segments, sigma):
return slic(img_as_float(image), n_segments=n_segments, sigma=sigma)


def get_confidence_map(image, segments):
return {i: random.uniform(a=0, b=1) for i in np.unique(segments)}


def get_neighbors(segments, n_segments):
# get unique labels
vertices = np.unique(segments)
Expand All @@ -57,34 +50,3 @@ def get_neighbors(segments, n_segments):
matrix[start_node, neighbor] = 1
matrix[neighbor, start_node] = 1
return matrix


def sort_values(values: dict, neighbors: dict, confidence_map: dict,
mode="dfs"):
assert len(confidence_map) == len(values) == len(neighbors)
if mode == "confidence":
sorted_vertices = sorted(confidence_map.items(),
key=operator.itemgetter(1), reverse=True)
sorted_vertices = [x[0] for x in sorted_vertices]
return [values[v] for v in sorted_vertices]
else:
start_vertex = \
sorted(confidence_map.items(), key=operator.itemgetter(1),
reverse=True)[0][0]
visited, queue = set(), [start_vertex]
result = []
if mode == "bfs":
while queue:
vertex = queue.pop(0)
if vertex not in visited:
visited.add(vertex)
result.append(values[vertex])
queue.extend(neighbors[vertex] - visited)
elif mode == "dfs":
while queue:
vertex = queue.pop()
if vertex not in visited:
visited.add(vertex)
result.append(values[vertex])
queue.extend(neighbors[vertex] - visited)
return result
97 changes: 63 additions & 34 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from tqdm import tqdm

from config import MODEL_PATH, VALSET_FILE, IMAGES_PATH, N_SUPERPIXELS, \
SLIC_SIGMA, OUTPUT_PATH, IMAGE_SHAPE, VALIDATION_IMAGES
SLIC_SIGMA, OUTPUT_PATH, IMAGE_SHAPE, VALIDATION_IMAGES, \
PREDICT_BATCH_SIZE, SLIC_SHAPE
from layers.ConfidenceLayer import Confidence
from layers.GraphLSTM import GraphLSTM
from layers.GraphLSTMCell import GraphLSTMCell
Expand All @@ -18,42 +19,70 @@
with open(VALSET_FILE) as f:
image_list = [line.replace("\n", "") for line in f]

while len(image_list) % PREDICT_BATCH_SIZE != 0:
image_list.append(None)

model = load_model(MODEL_PATH, custom_objects={'Confidence': Confidence,
'GraphPropagation': GraphPropagation,
'InverseGraphPropagation': InverseGraphPropagation,
'GraphLSTM': GraphLSTM,
'GraphLSTMCell': GraphLSTMCell})

# TODO: MAKE VALIDATION IN BATCHES THE SAME AS TRAIN_BATCH_SIZE
for image_name in tqdm(image_list):
image = io.imread(IMAGES_PATH + image_name + ".jpg")
shape = image.shape
img = resize(image, IMAGE_SHAPE, anti_aliasing=True)
slic = obtain_superpixels(img, N_SUPERPIXELS, SLIC_SIGMA)
vertices = average_rgb_for_superpixels(img, slic)
neighbors = get_neighbors(slic, N_SUPERPIXELS)

# TO NUMPIES
img = numpy.expand_dims(numpy.array(img), axis=0)
slic = numpy.expand_dims(numpy.array(slic), axis=0)
vertices = numpy.expand_dims(numpy.array(vertices), axis=0)
neighbors = numpy.expand_dims(numpy.array(neighbors), axis=0)

output_vertices, _ = model.predict_on_batch([img, slic, vertices, neighbors])

slic_out = slic
output_image = numpy.zeros(img[0].shape, dtype="uint8")
for segment_num in range(output_vertices.shape[1]):
if segment_num not in numpy.unique(slic):
break
mask = numpy.zeros(slic[0, :, :].shape + (3,), dtype="uint8")
mask[slic[0, :, :] == segment_num] = 255 * output_vertices[0, segment_num, :]
output_image += mask

output_image = resize(output_image, shape, anti_aliasing=True)
output_image = numpy.clip(output_image * 255, 0, 255)
expected_image = io.imread(VALIDATION_IMAGES + image_name + ".png")
i = numpy.concatenate((image, expected_image, output_image), axis=1)
output = numpy.clip(i, 0, 255)
output = output.astype(numpy.uint8)
io.imsave(OUTPUT_PATH + image_name + ".png", output)
for img_batch_start in tqdm(range(int(numpy.ceil(len(image_list) / PREDICT_BATCH_SIZE)))):
batch_img = []
batch_slic = []
batch_vertices = []
batch_neighbors = []
scale_list = []
image_names = []
images_list = []
for img_index in range(PREDICT_BATCH_SIZE):
real_index = PREDICT_BATCH_SIZE * img_batch_start + img_index
image_name = image_list[real_index]
if image_name is not None:
# LOAD IMAGES
image = io.imread(IMAGES_PATH + image_name + ".jpg")
images_list.append(image)
scale_list.append(image.shape)
image_names.append(image_name)
img = resize(image, IMAGE_SHAPE, anti_aliasing=True)

# OBTAIN OTHER USEFUL DATA
slic = obtain_superpixels(img, N_SUPERPIXELS, SLIC_SIGMA)
vertices = average_rgb_for_superpixels(img, slic)
neighbors = get_neighbors(slic, N_SUPERPIXELS)
else:
img = numpy.zeros(IMAGE_SHAPE, dtype=float)
slic = numpy.zeros(SLIC_SHAPE, dtype=float)
vertices = average_rgb_for_superpixels(img, slic)
neighbors = get_neighbors(slic, N_SUPERPIXELS)

# ADD TO BATCH
batch_img += [img]
batch_slic += [slic]
batch_vertices += [vertices]
batch_neighbors += [neighbors]
batch_img = numpy.array(batch_img)
batch_slic = numpy.array(batch_slic)
batch_vertices = numpy.array(batch_vertices)
batch_neighbors = numpy.array(batch_neighbors)

output_vertices, _ = model.predict_on_batch([batch_img, batch_slic, batch_vertices, batch_neighbors])

for index, shape in enumerate(scale_list):
slic_out = batch_slic[index]
output_image = numpy.zeros(batch_img[index].shape, dtype="uint8")
for segment_num in range(output_vertices[index].shape[1]):
if segment_num not in numpy.unique(batch_slic[index]):
break
mask = numpy.zeros(batch_slic[index, :, :].shape + (3,), dtype="uint8")
mask[batch_slic[index, :, :] == segment_num] = 255 * output_vertices[index, segment_num, :]
output_image += mask

output_image = resize(output_image, shape, anti_aliasing=True)
output_image = numpy.clip(output_image * 255, 0, 255)
expected_image = io.imread(VALIDATION_IMAGES + image_names[index] + ".png")
i = numpy.concatenate((images_list[index], expected_image, output_image), axis=1)
output = numpy.clip(i, 0, 255)
output = output.astype(numpy.uint8)
io.imsave(OUTPUT_PATH + image_names[index] + ".png", output)

0 comments on commit e53a94a

Please sign in to comment.