Skip to content

Commit

Permalink
Merge pull request #2 from rishikksh20/pytest
Browse files Browse the repository at this point in the history
Pytest
  • Loading branch information
rishikksh20 authored Aug 31, 2020
2 parents 88488a2 + 67f5629 commit 01566c9
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
Empty file added tests/__init__.py
Empty file.
20 changes: 20 additions & 0 deletions tests/test_res_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from core.res_unet import ResUnet, ResidualConv, Upsample

def test_resunet():
img = torch.ones(1, 3, 224, 224)
resunet = ResUnet(3)
assert resunet(img).shape == torch.Size([1, 1, 224, 224])


def test_residual_conv():
x = torch.ones(1, 64, 224, 224)
res_conv = ResidualConv(64, 128, 2, 1)
assert res_conv(x).shape == torch.Size([1, 128, 112, 112])


def test_upsample():
x = torch.ones(1, 512, 28, 28)
upsample = Upsample(512, 512, 2, 2)
assert upsample(x).shape == torch.Size([1, 512, 56, 56])

112 changes: 112 additions & 0 deletions utils/augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import warnings

warnings.simplefilter("ignore", UserWarning)

from skimage import transform
from torchvision import transforms

import numpy as np
import torch


class RescaleTarget(object):
"""Rescale the image in a sample to a given size.
Args:
output_size (tuple or int): Desired output size. If tuple, output is
matched to output_size. If int, smaller of image edges is matched
to output_size keeping aspect ratio the same.
"""

def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, tuple):
self.output_size = int(np.random.uniform(output_size[0], output_size[1]))
else:
self.output_size = output_size

def __call__(self, sample):
sat_img, map_img = sample["sat_img"], sample["map_img"]

h, w = sat_img.shape[:2]

if h > w:
new_h, new_w = self.output_size * h / w, self.output_size
else:
new_h, new_w = self.output_size, self.output_size * w / h

new_h, new_w = int(new_h), int(new_w)

# change the range to 0-1 rather than 0-255, makes it easier to use sigmoid later
sat_img = transform.resize(sat_img, (new_h, new_w))

map_img = transform.resize(map_img, (new_h, new_w))

return {"sat_img": sat_img, "map_img": map_img}


class RandomRotationTarget(object):
"""Rotate the image and target randomly in a sample.
Args:
degrees (tuple or int): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees).
resize (boolean): Expand the image to fit
"""

def __init__(self, degrees, resize=False):
if isinstance(degrees, int):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
if isinstance(degrees, tuple):
raise ValueError("Degrees needs to be either an int or tuple")
self.degrees = degrees

assert isinstance(resize, bool)

self.resize = resize
self.angle = np.random.uniform(self.degrees[0], self.degrees[1])

def __call__(self, sample):

sat_img = transform.rotate(sample["sat_img"], self.angle, self.resize)
map_img = transform.rotate(sample["map_img"], self.angle, self.resize)

return {"sat_img": sat_img, "map_img": map_img}


class RandomCropTarget(object):
"""
Crop the image and target randomly in a sample.
Args:
output_size (tuple or int): Desired output size. If int, square crop
is made.
"""

def __init__(self, output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size

def __call__(self, sample):

sat_img, map_img = sample["sat_img"], sample["map_img"]

h, w = sat_img.shape[:2]
new_h, new_w = self.output_size

top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)

sat_img = sat_img[top : top + new_h, left : left + new_w]
map_img = map_img[top : top + new_h, left : left + new_w]

return {"sat_img": sat_img, "map_img": map_img}

0 comments on commit 01566c9

Please sign in to comment.