# patches.py
import torch
import torch.nn as nn

def extract_patches_2d(img, patch_shape, step=[1.0, 1.0], batch_first=False):
    """
    img - the image to be split into patches, a tensor of shape (batch, C, H, W)
    patch_shape - [patch_H, patch_W], the size of the extracted patches
    step - the vertical and horizontal stride between neighbouring patches;
           a float is interpreted as a fraction of the patch size, an int as pixels
    batch_first - if True, return (batch, num_patches, C, patch_H, patch_W);
                  otherwise return (num_patches, batch, C, patch_H, patch_W)
    Example: img of shape (batch, 3, 4000, 6000), patch_shape=[256, 256], step=[236, 236]
    returns a tensor of shape (batch, num_patches, 3, 256, 256) when batch_first=True.
    A usage sketch follows the function.
    """
    patch_H, patch_W = patch_shape[0], patch_shape[1]
    # Zero-pad the image if it is smaller than a single patch.
    if img.size(2) < patch_H:
        num_padded_H_Top = (patch_H - img.size(2)) // 2
        num_padded_H_Bottom = patch_H - img.size(2) - num_padded_H_Top
        padding_H = nn.ConstantPad2d((0, 0, num_padded_H_Top, num_padded_H_Bottom), 0)
        img = padding_H(img)
    if img.size(3) < patch_W:
        num_padded_W_Left = (patch_W - img.size(3)) // 2
        num_padded_W_Right = patch_W - img.size(3) - num_padded_W_Left
        padding_W = nn.ConstantPad2d((num_padded_W_Left, num_padded_W_Right, 0, 0), 0)
        img = padding_W(img)
    # Resolve the step: floats are fractions of the patch size, ints are pixels.
    step_int = [0, 0]
    step_int[0] = int(patch_H * step[0]) if isinstance(step[0], float) else step[0]
    step_int[1] = int(patch_W * step[1]) if isinstance(step[1], float) else step[1]
    # Unfold along H; if the stride does not divide the height evenly, append one
    # extra row of patches aligned to the bottom border.
    patches_fold_H = img.unfold(2, patch_H, step_int[0])
    if (img.size(2) - patch_H) % step_int[0] != 0:
        patches_fold_H = torch.cat((patches_fold_H, img[:, :, -patch_H:].permute(0, 1, 3, 2).unsqueeze(2)), dim=2)
    # Unfold along W, with the same handling for the right border.
    patches_fold_HW = patches_fold_H.unfold(3, patch_W, step_int[1])
    if (img.size(3) - patch_W) % step_int[1] != 0:
        patches_fold_HW = torch.cat((patches_fold_HW, patches_fold_H[:, :, :, -patch_W:, :].permute(0, 1, 2, 4, 3).unsqueeze(3)), dim=3)
    # Collapse the patch grid into one dimension: (num_patches, batch, C, patch_H, patch_W).
    patches = patches_fold_HW.permute(2, 3, 0, 1, 4, 5)
    patches = patches.reshape(-1, img.size(0), img.size(1), patch_H, patch_W)
    if batch_first:
        patches = patches.permute(1, 0, 2, 3, 4)
    return patches
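
# A minimal usage sketch for extract_patches_2d (added for illustration, not part of
# the original module; the shapes below are assumptions):
#
#   x = torch.randn(2, 3, 500, 700)                              # (batch, C, H, W)
#   p = extract_patches_2d(x, [256, 256], step=[0.5, 0.5], batch_first=True)
#   # p has shape (2, num_patches, 3, 256, 256); the float step of 0.5 gives a
#   # stride of 128 pixels, and border-aligned patches are appended along any
#   # edge the stride does not divide evenly.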

def reconstruct_from_patches_2d(patches, img_shape, step=[1.0, 1.0], batch_first=False):
    """
    patches - the stack of patches to be stitched back together;
              shape (num_patches, batch, C, patch_H, patch_W), or
              (batch, num_patches, C, patch_H, patch_W) if batch_first=True
    img_shape - [H, W], the spatial size of the reconstructed image
    step - the vertical and horizontal stride that was used when the patches were
           extracted; overlapping pixels are averaged during the reconstruction
    Example: patches of shape (batch, num_patches, 3, 256, 256), img_shape=[4000, 6000],
    step=[236, 236] returns a tensor of shape (batch, 3, 4000, 6000).
    """
    if batch_first:
        patches = patches.permute(1, 0, 2, 3, 4)
    patch_H, patch_W = patches.size(3), patches.size(4)
    img_size = (patches.size(1), patches.size(2), max(img_shape[0], patch_H), max(img_shape[1], patch_W))
    # Resolve the step exactly as in extract_patches_2d.
    step_int = [0, 0]
    step_int[0] = int(patch_H * step[0]) if isinstance(step[0], float) else step[0]
    step_int[1] = int(patch_W * step[1]) if isinstance(step[1], float) else step[1]
    # Regular grid size, plus one extra row/column of border-aligned patches when
    # the stride does not divide the image evenly (mirroring extract_patches_2d).
    nrow, ncol = 1 + (img_size[-2] - patch_H) // step_int[0], 1 + (img_size[-1] - patch_W) // step_int[1]
    r_nrow = nrow + 1 if (img_size[2] - patch_H) % step_int[0] != 0 else nrow
    r_ncol = ncol + 1 if (img_size[3] - patch_W) % step_int[1] != 0 else ncol
    patches = patches.reshape(r_nrow, r_ncol, img_size[0], img_size[1], patch_H, patch_W)
    # Accumulate each patch at its original position and count how many patches
    # cover every pixel, so overlapping regions can be averaged afterwards.
    img = torch.zeros(img_size, device=patches.device)
    overlap_counter = torch.zeros(img_size, device=patches.device)
    for i in range(nrow):
        for j in range(ncol):
            img[:, :, i*step_int[0]:i*step_int[0]+patch_H, j*step_int[1]:j*step_int[1]+patch_W] += patches[i, j]
            overlap_counter[:, :, i*step_int[0]:i*step_int[0]+patch_H, j*step_int[1]:j*step_int[1]+patch_W] += 1
    # Border-aligned patches along the bottom edge.
    if (img_size[2] - patch_H) % step_int[0] != 0:
        for j in range(ncol):
            img[:, :, -patch_H:, j*step_int[1]:j*step_int[1]+patch_W] += patches[-1, j]
            overlap_counter[:, :, -patch_H:, j*step_int[1]:j*step_int[1]+patch_W] += 1
    # Border-aligned patches along the right edge.
    if (img_size[3] - patch_W) % step_int[1] != 0:
        for i in range(nrow):
            img[:, :, i*step_int[0]:i*step_int[0]+patch_H, -patch_W:] += patches[i, -1]
            overlap_counter[:, :, i*step_int[0]:i*step_int[0]+patch_H, -patch_W:] += 1
    # Bottom-right corner patch.
    if (img_size[2] - patch_H) % step_int[0] != 0 and (img_size[3] - patch_W) % step_int[1] != 0:
        img[:, :, -patch_H:, -patch_W:] += patches[-1, -1]
        overlap_counter[:, :, -patch_H:, -patch_W:] += 1
    # Average the overlapping regions.
    img /= overlap_counter
    # Undo the zero padding applied in extract_patches_2d for images smaller than a patch.
    if img_shape[0] < patch_H:
        num_padded_H_Top = (patch_H - img_shape[0]) // 2
        num_padded_H_Bottom = patch_H - img_shape[0] - num_padded_H_Top
        img = img[:, :, num_padded_H_Top:-num_padded_H_Bottom]
    if img_shape[1] < patch_W:
        num_padded_W_Left = (patch_W - img_shape[1]) // 2
        num_padded_W_Right = patch_W - img_shape[1] - num_padded_W_Left
        img = img[:, :, :, num_padded_W_Left:-num_padded_W_Right]
    return img
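
# A self-contained round-trip check (added for illustration, not part of the original
# module). It extracts overlapping patches from a random image and reconstructs it;
# since every pixel is averaged over identical copies of itself, the result should
# match the input up to floating-point error. The shapes below are assumptions.
if __name__ == "__main__":
    img = torch.randn(2, 3, 500, 700)   # (batch, C, H, W)
    patch_shape = [256, 256]
    step = [0.5, 0.5]                   # 50% of the patch size, i.e. a 128-pixel stride
    patches = extract_patches_2d(img, patch_shape, step=step, batch_first=True)
    rec = reconstruct_from_patches_2d(patches, [img.size(2), img.size(3)], step=step, batch_first=True)
    print("patches:", tuple(patches.shape), "reconstruction:", tuple(rec.shape))
    print("max abs error:", (img - rec).abs().max().item())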