diff --git a/config.py b/config.py index 526a58b..ed7458e 100644 --- a/config.py +++ b/config.py @@ -1,7 +1,7 @@ MODEL_PATH = "glstm.hdf5" # MODEL_PATH = "./data/checkpoints/model_10_0.88.hdf5" -IMAGE_SHAPE = (500, 500, 3) +IMAGE_SHAPE = (250, 250, 3) SLIC_SHAPE = (IMAGE_SHAPE[0], IMAGE_SHAPE[1]) N_SUPERPIXELS = 1000 N_FEATURES = 5 @@ -19,7 +19,6 @@ TRAIN_ELEMS = 1464 -TRAIN_BATCH_SIZE = 2 +TRAIN_BATCH_SIZE = VALIDATION_BATCH_SIZE = PREDICT_BATCH_SIZE = 2 VALIDATION_ELEMS = 2913 -VALIDATION_BATCH_SIZE = 2 diff --git a/main.py b/main.py index acfd643..b198b7f 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from config import N_SUPERPIXELS -from utils.utils import obtain_superpixels, get_confidence_map, get_neighbors, average_rgb_for_superpixels, sort_values +from utils.utils import obtain_superpixels, get_neighbors, average_rgb_for_superpixels if __name__ == '__main__': image = io.imread("lena.png") diff --git a/utils/utils.py b/utils/utils.py index ee06958..9f7c24b 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,6 +1,3 @@ -import operator -import random - from cv2 import cv2 import numpy as np from skimage import img_as_float @@ -29,10 +26,6 @@ def obtain_superpixels(image, n_segments, sigma): return slic(img_as_float(image), n_segments=n_segments, sigma=sigma) -def get_confidence_map(image, segments): - return {i: random.uniform(a=0, b=1) for i in np.unique(segments)} - - def get_neighbors(segments, n_segments): # get unique labels vertices = np.unique(segments) @@ -57,34 +50,3 @@ def get_neighbors(segments, n_segments): matrix[start_node, neighbor] = 1 matrix[neighbor, start_node] = 1 return matrix - - -def sort_values(values: dict, neighbors: dict, confidence_map: dict, - mode="dfs"): - assert len(confidence_map) == len(values) == len(neighbors) - if mode == "confidence": - sorted_vertices = sorted(confidence_map.items(), - key=operator.itemgetter(1), reverse=True) - sorted_vertices = [x[0] for x in sorted_vertices] - return [values[v] for v in sorted_vertices] - else: - start_vertex = \ - sorted(confidence_map.items(), key=operator.itemgetter(1), - reverse=True)[0][0] - visited, queue = set(), [start_vertex] - result = [] - if mode == "bfs": - while queue: - vertex = queue.pop(0) - if vertex not in visited: - visited.add(vertex) - result.append(values[vertex]) - queue.extend(neighbors[vertex] - visited) - elif mode == "dfs": - while queue: - vertex = queue.pop() - if vertex not in visited: - visited.add(vertex) - result.append(values[vertex]) - queue.extend(neighbors[vertex] - visited) - return result diff --git a/val.py b/val.py index 1be7470..9595b6a 100644 --- a/val.py +++ b/val.py @@ -5,7 +5,8 @@ from tqdm import tqdm from config import MODEL_PATH, VALSET_FILE, IMAGES_PATH, N_SUPERPIXELS, \ - SLIC_SIGMA, OUTPUT_PATH, IMAGE_SHAPE, VALIDATION_IMAGES + SLIC_SIGMA, OUTPUT_PATH, IMAGE_SHAPE, VALIDATION_IMAGES, \ + PREDICT_BATCH_SIZE, SLIC_SHAPE from layers.ConfidenceLayer import Confidence from layers.GraphLSTM import GraphLSTM from layers.GraphLSTMCell import GraphLSTMCell @@ -18,42 +19,70 @@ with open(VALSET_FILE) as f: image_list = [line.replace("\n", "") for line in f] + while len(image_list) % PREDICT_BATCH_SIZE != 0: + image_list.append(None) + model = load_model(MODEL_PATH, custom_objects={'Confidence': Confidence, 'GraphPropagation': GraphPropagation, 'InverseGraphPropagation': InverseGraphPropagation, 'GraphLSTM': GraphLSTM, 'GraphLSTMCell': GraphLSTMCell}) - # TODO: MAKE VALIDATION IN BATCHES THE SAME AS TRAIN_BATCH_SIZE - for image_name in tqdm(image_list): - image = io.imread(IMAGES_PATH + image_name + ".jpg") - shape = image.shape - img = resize(image, IMAGE_SHAPE, anti_aliasing=True) - slic = obtain_superpixels(img, N_SUPERPIXELS, SLIC_SIGMA) - vertices = average_rgb_for_superpixels(img, slic) - neighbors = get_neighbors(slic, N_SUPERPIXELS) - - # TO NUMPIES - img = numpy.expand_dims(numpy.array(img), axis=0) - slic = numpy.expand_dims(numpy.array(slic), axis=0) - vertices = numpy.expand_dims(numpy.array(vertices), axis=0) - neighbors = numpy.expand_dims(numpy.array(neighbors), axis=0) - - output_vertices, _ = model.predict_on_batch([img, slic, vertices, neighbors]) - - slic_out = slic - output_image = numpy.zeros(img[0].shape, dtype="uint8") - for segment_num in range(output_vertices.shape[1]): - if segment_num not in numpy.unique(slic): - break - mask = numpy.zeros(slic[0, :, :].shape + (3,), dtype="uint8") - mask[slic[0, :, :] == segment_num] = 255 * output_vertices[0, segment_num, :] - output_image += mask - - output_image = resize(output_image, shape, anti_aliasing=True) - output_image = numpy.clip(output_image * 255, 0, 255) - expected_image = io.imread(VALIDATION_IMAGES + image_name + ".png") - i = numpy.concatenate((image, expected_image, output_image), axis=1) - output = numpy.clip(i, 0, 255) - output = output.astype(numpy.uint8) - io.imsave(OUTPUT_PATH + image_name + ".png", output) + for img_batch_start in tqdm(range(int(numpy.ceil(len(image_list) / PREDICT_BATCH_SIZE)))): + batch_img = [] + batch_slic = [] + batch_vertices = [] + batch_neighbors = [] + scale_list = [] + image_names = [] + images_list = [] + for img_index in range(PREDICT_BATCH_SIZE): + real_index = PREDICT_BATCH_SIZE * img_batch_start + img_index + image_name = image_list[real_index] + if image_name is not None: + # LOAD IMAGES + image = io.imread(IMAGES_PATH + image_name + ".jpg") + images_list.append(image) + scale_list.append(image.shape) + image_names.append(image_name) + img = resize(image, IMAGE_SHAPE, anti_aliasing=True) + + # OBTAIN OTHER USEFUL DATA + slic = obtain_superpixels(img, N_SUPERPIXELS, SLIC_SIGMA) + vertices = average_rgb_for_superpixels(img, slic) + neighbors = get_neighbors(slic, N_SUPERPIXELS) + else: + img = numpy.zeros(IMAGE_SHAPE, dtype=float) + slic = numpy.zeros(SLIC_SHAPE, dtype=float) + vertices = average_rgb_for_superpixels(img, slic) + neighbors = get_neighbors(slic, N_SUPERPIXELS) + + # ADD TO BATCH + batch_img += [img] + batch_slic += [slic] + batch_vertices += [vertices] + batch_neighbors += [neighbors] + batch_img = numpy.array(batch_img) + batch_slic = numpy.array(batch_slic) + batch_vertices = numpy.array(batch_vertices) + batch_neighbors = numpy.array(batch_neighbors) + + output_vertices, _ = model.predict_on_batch([batch_img, batch_slic, batch_vertices, batch_neighbors]) + + for index, shape in enumerate(scale_list): + slic_out = batch_slic[index] + output_image = numpy.zeros(batch_img[index].shape, dtype="uint8") + for segment_num in range(output_vertices[index].shape[1]): + if segment_num not in numpy.unique(batch_slic[index]): + break + mask = numpy.zeros(batch_slic[index, :, :].shape + (3,), dtype="uint8") + mask[batch_slic[index, :, :] == segment_num] = 255 * output_vertices[index, segment_num, :] + output_image += mask + + output_image = resize(output_image, shape, anti_aliasing=True) + output_image = numpy.clip(output_image * 255, 0, 255) + expected_image = io.imread(VALIDATION_IMAGES + image_names[index] + ".png") + i = numpy.concatenate((images_list[index], expected_image, output_image), axis=1) + output = numpy.clip(i, 0, 255) + output = output.astype(numpy.uint8) + io.imsave(OUTPUT_PATH + image_names[index] + ".png", output)