Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
bnsreenu authored Aug 25, 2021
1 parent 0f952cf commit 56a59e6
Showing 1 changed file with 279 additions and 0 deletions.
279 changes: 279 additions & 0 deletions smooth_tiled_predictions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
# https://youtu.be/0W6MKZqSke8

"""
Author: Dr. Sreenivas Bhattiprolu
Original code is from the following source. It comes with MIT License so please mention
the original reference when sharing.
The original code has been modified to fix a couple of bugs and chunks of code
unnecessary for smooth tiling are removed.
# MIT License
# Copyright (c) 2017 Vooban Inc.
# Coded by: Guillaume Chevalier
# Source to original code and license:
# https://github.com/Vooban/Smoothly-Blend-Image-Patches
# https://github.com/Vooban/Smoothly-Blend-Image-Patches/blob/master/LICENSE
"""
"""Perform smooth predictions on an image from tiled prediction patches."""


import numpy as np
import scipy.signal
from tqdm import tqdm

import gc


if __name__ == '__main__':
import matplotlib.pyplot as plt
PLOT_PROGRESS = True
# See end of file for the rest of the __main__.
else:
PLOT_PROGRESS = False


def _spline_window(window_size, power=2):
"""
Squared spline (power=2) window function:
https://www.wolframalpha.com/input/?i=y%3Dx**2,+y%3D-(x-2)**2+%2B2,+y%3D(x-4)**2,+from+y+%3D+0+to+2
"""
intersection = int(window_size/4)
wind_outer = (abs(2*(scipy.signal.triang(window_size))) ** power)/2
wind_outer[intersection:-intersection] = 0

wind_inner = 1 - (abs(2*(scipy.signal.triang(window_size) - 1)) ** power)/2
wind_inner[:intersection] = 0
wind_inner[-intersection:] = 0

wind = wind_inner + wind_outer
wind = wind / np.average(wind)
return wind


cached_2d_windows = dict()
def _window_2D(window_size, power=2):
"""
Make a 1D window function, then infer and return a 2D window function.
Done with an augmentation, and self multiplication with its transpose.
Could be generalized to more dimensions.
"""
# Memoization
global cached_2d_windows
key = "{}_{}".format(window_size, power)
if key in cached_2d_windows:
wind = cached_2d_windows[key]
else:
wind = _spline_window(window_size, power)
wind = np.expand_dims(np.expand_dims(wind, 1), 1) #SREENI: Changed from 3, 3, to 1, 1
wind = wind * wind.transpose(1, 0, 2)
if PLOT_PROGRESS:
# For demo purpose, let's look once at the window:
plt.imshow(wind[:, :, 0], cmap="viridis")
plt.title("2D Windowing Function for a Smooth Blending of "
"Overlapping Patches")
plt.show()
cached_2d_windows[key] = wind
return wind


def _pad_img(img, window_size, subdivisions):
"""
Add borders to img for a "valid" border pattern according to "window_size" and
"subdivisions".
Image is an np array of shape (x, y, nb_channels).
"""
aug = int(round(window_size * (1 - 1.0/subdivisions)))
more_borders = ((aug, aug), (aug, aug), (0, 0))
ret = np.pad(img, pad_width=more_borders, mode='reflect')
# gc.collect()

if PLOT_PROGRESS:
# For demo purpose, let's look once at the window:
plt.imshow(ret)
plt.title("Padded Image for Using Tiled Prediction Patches\n"
"(notice the reflection effect on the padded borders)")
plt.show()
return ret


def _unpad_img(padded_img, window_size, subdivisions):
"""
Undo what's done in the `_pad_img` function.
Image is an np array of shape (x, y, nb_channels).
"""
aug = int(round(window_size * (1 - 1.0/subdivisions)))
ret = padded_img[
aug:-aug,
aug:-aug,
:
]
# gc.collect()
return ret


def _rotate_mirror_do(im):
"""
Duplicate an np array (image) of shape (x, y, nb_channels) 8 times, in order
to have all the possible rotations and mirrors of that image that fits the
possible 90 degrees rotations.
It is the D_4 (D4) Dihedral group:
https://en.wikipedia.org/wiki/Dihedral_group
"""
mirrs = []
mirrs.append(np.array(im))
mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=1))
mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=2))
mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=3))
im = np.array(im)[:, ::-1]
mirrs.append(np.array(im))
mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=1))
mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=2))
mirrs.append(np.rot90(np.array(im), axes=(0, 1), k=3))
return mirrs


def _rotate_mirror_undo(im_mirrs):
"""
merges a list of 8 np arrays (images) of shape (x, y, nb_channels) generated
from the `_rotate_mirror_do` function. Each images might have changed and
merging them implies to rotated them back in order and average things out.
It is the D_4 (D4) Dihedral group:
https://en.wikipedia.org/wiki/Dihedral_group
"""
origs = []
origs.append(np.array(im_mirrs[0]))
origs.append(np.rot90(np.array(im_mirrs[1]), axes=(0, 1), k=3))
origs.append(np.rot90(np.array(im_mirrs[2]), axes=(0, 1), k=2))
origs.append(np.rot90(np.array(im_mirrs[3]), axes=(0, 1), k=1))
origs.append(np.array(im_mirrs[4])[:, ::-1])
origs.append(np.rot90(np.array(im_mirrs[5]), axes=(0, 1), k=3)[:, ::-1])
origs.append(np.rot90(np.array(im_mirrs[6]), axes=(0, 1), k=2)[:, ::-1])
origs.append(np.rot90(np.array(im_mirrs[7]), axes=(0, 1), k=1)[:, ::-1])
return np.mean(origs, axis=0)


def _windowed_subdivs(padded_img, window_size, subdivisions, nb_classes, pred_func):
"""
Create tiled overlapping patches.
Returns:
5D numpy array of shape = (
nb_patches_along_X,
nb_patches_along_Y,
patches_resolution_along_X,
patches_resolution_along_Y,
nb_output_channels
)
Note:
patches_resolution_along_X == patches_resolution_along_Y == window_size
"""
WINDOW_SPLINE_2D = _window_2D(window_size=window_size, power=2)

step = int(window_size/subdivisions)
padx_len = padded_img.shape[0]
pady_len = padded_img.shape[1]
subdivs = []

for i in range(0, padx_len-window_size+1, step):
subdivs.append([])
for j in range(0, pady_len-window_size+1, step): #SREENI: Changed padx to pady (Bug in original code)
patch = padded_img[i:i+window_size, j:j+window_size, :]
subdivs[-1].append(patch)

# Here, `gc.collect()` clears RAM between operations.
# It should run faster if they are removed, if enough memory is available.
gc.collect()
subdivs = np.array(subdivs)
gc.collect()
a, b, c, d, e = subdivs.shape
subdivs = subdivs.reshape(a * b, c, d, e)
gc.collect()

subdivs = pred_func(subdivs)
gc.collect()
subdivs = np.array([patch * WINDOW_SPLINE_2D for patch in subdivs])
gc.collect()

# Such 5D array:
subdivs = subdivs.reshape(a, b, c, d, nb_classes)
gc.collect()

return subdivs


def _recreate_from_subdivs(subdivs, window_size, subdivisions, padded_out_shape):
"""
Merge tiled overlapping patches smoothly.
"""
step = int(window_size/subdivisions)
padx_len = padded_out_shape[0]
pady_len = padded_out_shape[1]

y = np.zeros(padded_out_shape)

a = 0
for i in range(0, padx_len-window_size+1, step):
b = 0
for j in range(0, pady_len-window_size+1, step): #SREENI: Changed padx to pady (Bug in original code)
windowed_patch = subdivs[a, b]
y[i:i+window_size, j:j+window_size] = y[i:i+window_size, j:j+window_size] + windowed_patch
b += 1
a += 1
return y / (subdivisions ** 2)


def predict_img_with_smooth_windowing(input_img, window_size, subdivisions, nb_classes, pred_func):
"""
Apply the `pred_func` function to square patches of the image, and overlap
the predictions to merge them smoothly.
See 6th, 7th and 8th idea here:
http://blog.kaggle.com/2017/05/09/dstl-satellite-imagery-competition-3rd-place-winners-interview-vladimir-sergey/
"""
pad = _pad_img(input_img, window_size, subdivisions)
pads = _rotate_mirror_do(pad)

# Note that the implementation could be more memory-efficient by merging
# the behavior of `_windowed_subdivs` and `_recreate_from_subdivs` into
# one loop doing in-place assignments to the new image matrix, rather than
# using a temporary 5D array.

# It would also be possible to allow different (and impure) window functions
# that might not tile well. Adding their weighting to another matrix could
# be done to later normalize the predictions correctly by dividing the whole
# reconstructed thing by this matrix of weightings - to normalize things
# back from an impure windowing function that would have badly weighted
# windows.

# For example, since the U-net of Kaggle's DSTL satellite imagery feature
# prediction challenge's 3rd place winners use a different window size for
# the input and output of the neural net's patches predictions, it would be
# possible to fake a full-size window which would in fact just have a narrow
# non-zero dommain. This may require to augment the `subdivisions` argument
# to 4 rather than 2.

res = []
for pad in tqdm(pads):
# For every rotation:
sd = _windowed_subdivs(pad, window_size, subdivisions, nb_classes, pred_func)
one_padded_result = _recreate_from_subdivs(
sd, window_size, subdivisions,
padded_out_shape=list(pad.shape[:-1])+[nb_classes])

res.append(one_padded_result)

# Merge after rotations:
padded_results = _rotate_mirror_undo(res)

prd = _unpad_img(padded_results, window_size, subdivisions)

prd = prd[:input_img.shape[0], :input_img.shape[1], :]

if PLOT_PROGRESS:
plt.imshow(prd)
plt.title("Smoothly Merged Patches that were Tiled Tighter")
plt.show()
return prd


0 comments on commit 56a59e6

Please sign in to comment.