forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSample.h
250 lines (228 loc) · 6.96 KB
/
UpSample.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
#include <math.h>
#include <ATen/ATen.h>
namespace at {
namespace native {
// Corresponds to THNN_CHECK_DIM_SIZE
static inline void check_dim_size(
const Tensor& data,
int64_t dim,
int64_t dim_size,
int64_t size) {
/* Check dimension size of a tensor */
AT_CHECK(
data.dim() == dim && data.size(dim_size) == size,
"Expected tensor of dimension ",
dim,
" and tensor.size[",
dim_size,
"] == ",
size,
" but got: dimension ",
data.dim(),
" and tensor.size[",
dim_size,
"] = ",
data.size(dim_size));
}
static inline void upsample_1d_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t nbatch,
int64_t nchannels,
int64_t input_width,
int64_t output_width) {
AT_CHECK(
input_width > 0 && output_width > 0,
"Input and output sizes should be greater than 0, but got input (W: ",
input_width,
") and output (W: ",
output_width,
")");
if (input.defined()) {
AT_CHECK(
input.numel() != 0 && input.dim() == 3,
"Non-empty 3D data tensor expected but got a tensor with sizes ",
input.sizes());
} else if (grad_output.defined()) {
check_dim_size(grad_output, 3, 0, nbatch);
check_dim_size(grad_output, 3, 1, nchannels);
check_dim_size(grad_output, 3, 2, output_width);
}
}
static inline void upsample_2d_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t nbatch,
int64_t nchannels,
int64_t input_height,
int64_t input_width,
int64_t output_height,
int64_t output_width) {
AT_CHECK(
input_height > 0 && input_width > 0 && output_height > 0 &&
output_width > 0,
"Input and output sizes should be greater than 0,"
" but got input (H: ",
input_height,
", W: ",
input_width,
") output (H: ",
output_height,
", W: ",
output_width,
")");
if (input.defined()) {
AT_CHECK(
input.numel() != 0 && input.dim() == 4,
"Non-empty 4D data tensor expected but got a tensor with sizes ",
input.sizes());
} else if (grad_output.defined()) {
check_dim_size(grad_output, 4, 0, nbatch);
check_dim_size(grad_output, 4, 1, nchannels);
check_dim_size(grad_output, 4, 2, output_height);
check_dim_size(grad_output, 4, 3, output_width);
}
}
static inline void upsample_3d_shape_check(
const Tensor& input,
const Tensor& grad_output,
int64_t nbatch,
int64_t nchannels,
int64_t input_depth,
int64_t input_height,
int64_t input_width,
int64_t output_depth,
int64_t output_height,
int64_t output_width) {
AT_CHECK(
input_depth > 0 && input_height > 0 && input_width > 0 &&
output_depth > 0 && output_height > 0 && output_width > 0,
"Input and output sizes should be greater than 0, but got input (D: ",
input_depth,
", H: ",
input_height,
", W: ",
input_width,
") output (D: ",
output_depth,
", H: ",
output_height,
", W: ",
output_width,
")");
if (input.defined()) {
AT_CHECK(
input.numel() != 0 && input.dim() == 5,
"Non-empty 5D data tensor expected but got a tensor with sizes ",
input.sizes());
} else if (grad_output.defined()) {
check_dim_size(grad_output, 5, 0, nbatch);
check_dim_size(grad_output, 5, 1, nchannels);
check_dim_size(grad_output, 5, 2, output_depth);
check_dim_size(grad_output, 5, 3, output_height);
check_dim_size(grad_output, 5, 4, output_width);
}
}
template <typename scalar_t>
static inline scalar_t linear_upsample_compute_scale(
int64_t input_size,
int64_t output_size,
bool align_corners) {
/* We view each pixel as an area, idx + 0.5 as its center index.
* Here is an example formula in 1D case.
* if align_corners: center of two corner pixel areas are preserved,
* (0.5, 0.5) -> (0.5, 0.5),
* (input_size - 0.5, 0.5) -> (output_size - 0.5)
* scale = (input_size - 0.5 - 0.5) / (output_size - 0.5 - 0.5)
* src_index + 0.5 - 0.5 = scale * (dst_index + 0.5 - 0.5)
* if not align_corners: the whole range is scaled accordingly
* scale = input_size / output_size
* src_idx + 0.5 = scale * (dst_index + 0.5)
*/
if (output_size > 1) {
return align_corners
? static_cast<scalar_t>(input_size - 1) / (output_size - 1)
: static_cast<scalar_t>(input_size) / output_size;
} else {
return scalar_t(0);
}
}
template <typename scalar_t>
static inline scalar_t linear_upsample_compute_source_index(
scalar_t scale,
int64_t dst_index,
bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
scalar_t src_idx = scale * (dst_index + 0.5) - 0.5;
return src_idx < 0 ? scalar_t(0) : src_idx;
}
}
static inline int64_t nearest_neighbor_compute_source_index(
const float scale,
int64_t dst_index,
int64_t input_size) {
const int64_t src_index =
std::min(static_cast<int64_t>(floorf(dst_index * scale)), input_size - 1);
return src_index;
}
template <typename scalar_t>
static scalar_t upsample_get_value_bounded(
scalar_t* data,
int64_t width,
int64_t height,
int64_t x,
int64_t y) {
int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
return data[access_y * width + access_x];
}
template <typename scalar_t>
static void upsample_increment_value_bounded(
scalar_t* data,
int64_t width,
int64_t height,
int64_t x,
int64_t y,
scalar_t value) {
int64_t access_x = std::max(std::min(x, width - 1), static_cast<int64_t>(0));
int64_t access_y = std::max(std::min(y, height - 1), static_cast<int64_t>(0));
data[access_y * width + access_x] += value;
}
// Based on
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
template <typename scalar_t>
static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
return ((A + 2) * x - (A + 3)) * x * x + 1;
}
template <typename scalar_t>
static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
}
template <typename scalar_t>
static inline void get_cubic_upsample_coefficients(
scalar_t coeffs[4],
scalar_t t) {
scalar_t A = -0.75;
scalar_t x1 = t;
coeffs[0] = cubic_convolution2<scalar_t>(x1 + 1.0, A);
coeffs[1] = cubic_convolution1<scalar_t>(x1, A);
// opposite coefficients
scalar_t x2 = 1.0 - t;
coeffs[2] = cubic_convolution1<scalar_t>(x2, A);
coeffs[3] = cubic_convolution2<scalar_t>(x2 + 1.0, A);
}
template <typename scalar_t>
static inline scalar_t cubic_interp1d(
scalar_t x0,
scalar_t x1,
scalar_t x2,
scalar_t x3,
scalar_t t) {
scalar_t coeffs[4];
get_cubic_upsample_coefficients<scalar_t>(coeffs, t);
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
} // namespace native
} // namespace at