forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathUpSampleNearest1d.cpp
222 lines (187 loc) · 5.55 KB
/
UpSampleNearest1d.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/UpSample.h>
namespace at {
namespace native {
namespace {
/**
 * Forward pass of nearest-neighbour 1-D upsampling over raw buffers.
 *
 * Both buffers are addressed as `nbatch * channels` consecutive rows: a row
 * of `idata` holds `input_width` elements and a row of `odata` holds
 * `output_width` elements. For every output column the source column is
 * obtained from nearest_neighbor_compute_source_index() and the value is
 * copied for every row.
 *
 * @param odata         destination buffer, written at
 *                      [row * output_width + w2] for every row / column.
 * @param idata         source buffer, read at [row * input_width + w1];
 *                      never written.
 * @param input_width   number of elements per input row.
 * @param output_width  number of elements per output row.
 * @param nbatch        batch count; only used multiplied with `channels`.
 * @param channels      channel count; only used multiplied with `nbatch`.
 */
template <typename scalar_t>
static void upsample_nearest1d_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_width,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels) {
  const float scale =
      static_cast<float>(input_width) / static_cast<float>(output_width);

  // Batch and channel are interchangeable here: treat them as one row count.
  channels = channels * nbatch;

  // special case: just copy
  if (input_width == output_width) {
    // Rows have the same length on both sides, so the whole buffer maps
    // element-for-element and can be copied in one linear sweep.
    const int64_t numel = channels * output_width;
    for (int64_t i = 0; i < numel; ++i) {
      odata[i] = idata[i];
    }
    return;
  }

  for (int64_t w2 = 0; w2 < output_width; ++w2) {
    // Keep the source index integral. Routing it through scalar_t would
    // round large indices when scalar_t is a narrow floating type (half is
    // only exact for integers up to 2048) and read the wrong input column.
    const int64_t w1 =
        nearest_neighbor_compute_source_index(scale, w2, input_width);

    for (int64_t c = 0; c < channels; ++c) {
      odata[c * output_width + w2] = idata[c * input_width + w1];
    }
  }
}
/**
 * Backward pass of nearest-neighbour 1-D upsampling over raw buffers.
 *
 * Note the naming: `odata` is the gradient w.r.t. the output (read only)
 * and `idata` is the gradient w.r.t. the input (accumulated into with `+=`,
 * so the caller is expected to have initialised it).
 *
 * Both buffers are addressed as `nbatch * channels` rows of `output_width`
 * (odata) and `input_width` (idata) elements respectively. Each output
 * column contributes to the input column chosen by
 * nearest_neighbor_compute_source_index(), or to the same column when the
 * widths are equal.
 */
template <typename scalar_t>
static void upsample_nearest1d_backward_out_frame(
    scalar_t* odata,
    scalar_t* idata,
    int64_t input_width,
    int64_t output_width,
    int64_t nbatch,
    int64_t channels) {
  const float scale =
      static_cast<float>(input_width) / static_cast<float>(output_width);
  const int64_t num_rows = channels * nbatch;

  // special case: same-size matching grids (column i maps to column i)
  const bool same_width = (input_width == output_width);

  for (int64_t out_x = 0; out_x < output_width; ++out_x) {
    int64_t in_x = out_x;
    if (!same_width) {
      in_x = nearest_neighbor_compute_source_index(scale, out_x, input_width);
    }

    // Walk every (batch, channel) row and add this output column's gradient
    // to the matching input column.
    for (int64_t row = 0; row < num_rows; ++row) {
      idata[row * input_width + in_x] += odata[row * output_width + out_x];
    }
  }
}
// Shared CPU implementation of the nearest-neighbour 1-D upsample forward
// pass.
//
// Reads the batch / channel / width extents from dimensions 0, 1 and 2 of
// `input_`, validates them via upsample_1d_shape_check(), resizes `output`
// to {nbatch, channels, output_size[0]}, zero-fills it and then dispatches
// on the input's scalar type to upsample_nearest1d_out_frame().
//
// `output_size` must contain exactly one entry (the target width).
static void upsample_nearest1d_out_cpu_template(
    Tensor& output,
    const Tensor& input_,
    IntArrayRef output_size) {
  AT_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  int64_t output_width = output_size[0];

  int64_t nbatch = input_.size(0);
  int64_t channels = input_.size(1);
  int64_t input_width = input_.size(2);

  upsample_1d_shape_check(
      input_, Tensor(), nbatch, channels, input_width, output_width);

  // The frame kernel indexes raw memory, so work on a contiguous copy/view.
  Tensor input = input_.contiguous();

  output.resize_({nbatch, channels, output_width});
  output.zero_();

  AT_ASSERT(input_width > 0 && output_width > 0);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "upsample_nearest1d", [&] {
        scalar_t* src = input.data<scalar_t>();
        scalar_t* dst = output.data<scalar_t>();

        upsample_nearest1d_out_frame<scalar_t>(
            dst, src, input_width, output_width, nbatch, channels);
      });
}
// Shared CPU implementation of the nearest-neighbour 1-D upsample backward
// pass.
//
// `output_size` must hold one entry (the upsampled width) and `input_size`
// three entries (batch, channels, input width). After
// upsample_1d_shape_check() validates `grad_output_`, `grad_input` is
// resized to {nbatch, channels, input_width} and zero-filled, because the
// frame kernel accumulates into it. Dispatch is on grad_output's scalar type.
static void upsample_nearest1d_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output_,
    IntArrayRef output_size,
    IntArrayRef input_size) {
  AT_CHECK(
      output_size.size() == 1,
      "It is expected output_size equals to 1, but got size ",
      output_size.size());

  AT_CHECK(
      input_size.size() == 3,
      "It is expected input_size equals to 3, but got size ",
      input_size.size());

  int64_t output_width = output_size[0];

  int64_t nbatch = input_size[0];
  int64_t channels = input_size[1];
  int64_t input_width = input_size[2];

  upsample_1d_shape_check(
      Tensor(), grad_output_, nbatch, channels, input_width, output_width);

  // The frame kernel indexes raw memory, so work on a contiguous copy/view.
  Tensor grad_output = grad_output_.contiguous();

  grad_input.resize_({nbatch, channels, input_width});
  grad_input.zero_();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "upsample_nearest1d_backward", [&] {
        // Frame naming: "idata" is grad_input, "odata" is grad_output.
        scalar_t* grad_in_ptr = grad_input.data<scalar_t>();
        scalar_t* grad_out_ptr = grad_output.data<scalar_t>();

        upsample_nearest1d_backward_out_frame<scalar_t>(
            grad_out_ptr,
            grad_in_ptr,
            input_width,
            output_width,
            nbatch,
            channels);
      });
}
} // namespace
// Out-variant of the CPU nearest-neighbour 1-D upsample forward pass.
// Delegates to upsample_nearest1d_out_cpu_template(), which resizes `output`
// to {input.size(0), input.size(1), output_size[0]} and fills it, then
// returns a reference to that same `output` tensor.
Tensor& upsample_nearest1d_out_cpu(
Tensor& output,
const Tensor& input,
IntArrayRef output_size) {
upsample_nearest1d_out_cpu_template(output, input, output_size);
return output;
}
// Functional variant of the CPU nearest-neighbour 1-D upsample forward pass.
// Returns a freshly allocated tensor created with `input`'s options.
Tensor upsample_nearest1d_cpu(const Tensor& input, IntArrayRef output_size) {
// Start from a zero-element placeholder; the template resizes it to the
// final {batch, channels, output_width} shape before writing.
auto output = at::empty({0}, input.options());
upsample_nearest1d_out_cpu_template(output, input, output_size);
return output;
}
// Out-variant of the CPU nearest-neighbour 1-D upsample backward pass.
// Delegates to upsample_nearest1d_backward_out_cpu_template(), which resizes
// `grad_input` to the three extents in `input_size`, zero-fills it and
// accumulates `grad_output` into it. Returns a reference to `grad_input`.
Tensor& upsample_nearest1d_backward_out_cpu(
Tensor& grad_input,
const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size) {
upsample_nearest1d_backward_out_cpu_template(
grad_input, grad_output, output_size, input_size);
return grad_input;
}
// Functional variant of the CPU nearest-neighbour 1-D upsample backward pass.
// Returns a new gradient tensor of shape `input_size`, created with
// `grad_output`'s options.
Tensor upsample_nearest1d_backward_cpu(
const Tensor& grad_output,
IntArrayRef output_size,
IntArrayRef input_size) {
// NOTE(review): the template below calls resize_() and zero_() on
// grad_input itself, so the zero-fill done by at::zeros here looks
// redundant; confirm before switching to at::empty.
auto grad_input = at::zeros(input_size, grad_output.options());
upsample_nearest1d_backward_out_cpu_template(
grad_input, grad_output, output_size, input_size);
return grad_input;
}
} // namespace native
} // namespace at