-
Notifications
You must be signed in to change notification settings - Fork 1
/
gen_patches.py
67 lines (60 loc) · 2.35 KB
/
gen_patches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import random
import numpy as np
def get_rand_patch(img, mask, sz=160):
    """Cut one random ``sz`` x ``sz`` patch from an image and its mask.

    The same spatial window is taken from ``img`` and ``mask``. The same
    random augmentation is applied to both, so image and mask pixels stay
    aligned.

    Augmentation: one integer is drawn uniformly from 1..7 with
    ``np.random.randint(1, 8)``. The mapping is:

    - 1 flips the first axis.
    - 2 flips the second axis.
    - 5 rotates the patch by 180 degrees in the (x, y) plane.
    - Every other value (3, 4, 6, 7) leaves the patch unchanged.

    Transpose and 90/270-degree rotations are intentionally disabled. The
    1..7 range is kept so the probabilities stay the same: 4/7 identity and
    1/7 for each of the three transforms.

    Note: the patch position is drawn with the stdlib ``random`` module,
    while the augmentation is drawn with ``np.random``. Both must be seeded
    for reproducible output.

    :param img: ndarray with shape (x_sz, y_sz, num_channels)
    :param mask: ndarray whose first two dimensions equal those of ``img``,
        e.g. (x_sz, y_sz, num_classes); a 2-D (x_sz, y_sz) mask also works
    :param sz: side length of the square patch
    :return: tuple ``(patch_img, patch_mask)``. ``patch_img`` has shape
        (sz, sz, num_channels); ``patch_mask`` has shape (sz, sz) plus the
        mask's trailing dims. Both are views into the inputs, not copies.
    :raises ValueError: if ``img`` is not 3-D, if ``img`` and ``mask``
        differ in their first two dimensions, or if ``img`` is smaller than
        ``sz`` in either of them
    """
    if len(img.shape) != 3:
        raise ValueError(
            'img must be 3-D (x, y, channels), got shape {}'.format(img.shape))
    if img.shape[0:2] != mask.shape[0:2]:
        raise ValueError(
            'img and mask spatial shapes differ: {} vs {}'.format(
                img.shape[0:2], mask.shape[0:2]))
    # An image exactly `sz` wide is valid: randint(0, 0) simply returns 0.
    if img.shape[0] < sz or img.shape[1] < sz:
        raise ValueError(
            'patch size {} exceeds image size {}'.format(sz, img.shape[0:2]))

    xc = random.randint(0, img.shape[0] - sz)
    yc = random.randint(0, img.shape[1] - sz)
    patch_img = img[xc:xc + sz, yc:yc + sz]
    patch_mask = mask[xc:xc + sz, yc:yc + sz]

    transformation = np.random.randint(1, 8)  # uniform over 1..7 (upper bound exclusive)
    if transformation == 1:  # flip along the first axis
        patch_img, patch_mask = patch_img[::-1], patch_mask[::-1]
    elif transformation == 2:  # flip along the second axis
        patch_img, patch_mask = patch_img[:, ::-1], patch_mask[:, ::-1]
    elif transformation == 5:  # rotate 180 degrees in the (x, y) plane
        patch_img, patch_mask = np.rot90(patch_img, 2), np.rot90(patch_mask, 2)
    # 3, 4, 6 (formerly transpose / 90 / 270-degree rotation) and 7: identity.
    return patch_img, patch_mask
def get_patches(x_dict, y_dict, n_patches, sz=160):
    """Draw ``n_patches`` random, augmented image/mask patches.

    For each patch, an image id is chosen uniformly at random, with
    replacement, from the keys of ``x_dict``. The image and the mask stored
    under that id in ``y_dict`` are passed to ``get_rand_patch``, which crops
    and augments them identically.

    :param x_dict: mapping of image id -> image ndarray
        (x_sz, y_sz, num_channels)
    :param y_dict: mapping of image id -> mask ndarray, keyed by the same ids
        as ``x_dict`` (a missing id raises ``KeyError``)
    :param n_patches: number of patches to generate
    :param sz: side length of each square patch
    :return: tuple ``(x, y)`` of ndarrays stacking the image patches and the
        mask patches along a new leading axis
    :raises ValueError: if patches are requested but ``x_dict`` is empty
    """
    # Materialize the ids once instead of on every iteration.
    image_ids = list(x_dict)
    if n_patches > 0 and not image_ids:
        raise ValueError('x_dict is empty; cannot sample patches')

    x = []
    y = []
    while len(x) < n_patches:
        # random.sample() no longer accepts dict views (Python 3.11+).
        # random.choice() on a list draws the same single index from the RNG.
        img_id = random.choice(image_ids)
        img_patch, mask_patch = get_rand_patch(x_dict[img_id], y_dict[img_id], sz)
        x.append(img_patch)
        y.append(mask_patch)
    print('Generated {} patches'.format(len(x)))
    return np.array(x), np.array(y)