forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpool_op_rtc_gpu.cc
313 lines (283 loc) · 8.63 KB
/
pool_op_rtc_gpu.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
#include <cstdio>
#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/pool_op.h"
#include "caffe2/cuda_rtc/common_rtc.h"
namespace caffe2 {
namespace {
// Empty tag types, one per pooling flavor. They carry no data or behavior;
// the code shown in this file does not reference either of them.
struct AveragePool {};
struct MaxPool {};
} // namespace
namespace {
// Source template for the NVRTC max-pool forward kernel (NCHW layout).
// It is rendered with snprintf: %s is the kernel name and the %d placeholders
// are, in order: nthreads, channels, height, width, pooled_height,
// pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l. All of
// them are baked into the kernel as const int. "%%" renders as the C modulo
// operator.
//
// Each iteration of the grid-stride loop produces one output element: the max
// over its (padding-clipped) input window.
//
// The running max is seeded with -FLT_MAX, written as the literal
// -3.402823466e+38f so the generated source needs no headers under NVRTC. A
// finite sentinel such as -1e37 would be wrong for inputs below it. It would
// also break the backward kernel, which routes gradient by testing
// bottom_data == top_data.
const char kMaxPoolForwardNCHWSource[] = R"(
extern "C"
__global__ void %s(const float* bottom_data, float* top_data) {
const int nthreads = %d;
const int channels = %d;
const int height = %d;
const int width = %d;
const int pooled_height = %d;
const int pooled_width = %d;
const int kernel_h = %d;
const int kernel_w = %d;
const int stride_h = %d;
const int stride_w = %d;
const int pad_t = %d;
const int pad_l = %d;
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < nthreads; index += blockDim.x * gridDim.x) {
int pw = index %% pooled_width;
int ph = (index / pooled_width) %% pooled_height;
int c = (index / (pooled_width * pooled_height)) %% channels;
int n = index / (pooled_width * pooled_height * channels);
int hstart = ph * stride_h - pad_t;
int wstart = pw * stride_w - pad_l;
int hend = min(hstart + kernel_h, height);
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
float maxval = -3.402823466e+38f;
const float* bdata_offset = bottom_data + n * channels * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
maxval = fmaxf(
bdata_offset[c * height * width + h * width + w], maxval);
}
}
top_data[index] = maxval;
}
}
)";
// The max pool backward (gradient) function, with parameters written in
// const int.
//
// Source template for the NVRTC max-pool backward kernel (NCHW layout).
// It is rendered with snprintf: %s is the kernel name and the %d placeholders
// are, in order: nthreads, num, channels, height, width, pooled_height,
// pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l. "%%"
// renders as the C modulo operator. `num` is declared in the generated kernel
// but never read by it.
//
// Each iteration of the grid-stride loop owns one input element `index` and
// writes bottom_diff[index] exactly once. The element's (h, w) is shifted into
// padded coordinates. The loop then walks the pooled positions
// [phstart, phend) x [pwstart, pwend) whose windows can cover that position.
// For each of those windows where the input value equals the pooled output,
// it adds the window's top_diff. Ties therefore pass gradient to every equal
// element in a window. top_data and top_diff are read at the same offset.
const char kMaxPoolBackwardNCHWSource[] = R"(
extern "C"
__global__ void %s(
const float* const bottom_data, const float* const top_data,
const float* const top_diff, float* const bottom_diff) {
const int nthreads = %d;
const int num = %d;
const int channels = %d;
const int height = %d;
const int width = %d;
const int pooled_height = %d;
const int pooled_width = %d;
const int kernel_h = %d;
const int kernel_w = %d;
const int stride_h = %d;
const int stride_w = %d;
const int pad_t = %d;
const int pad_l = %d;
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < nthreads; index += blockDim.x * gridDim.x) {
const int w = index %% width + pad_l;
const int h = (index / width) %% height + pad_t;
const int c = (index / width / height) %% channels;
const int n = index / width / height / channels;
const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
const int phend = min(h / stride_h + 1, pooled_height);
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwend = min(w / stride_w + 1, pooled_width);
const int top_offset =
(n * channels + c) * pooled_height * pooled_width;
bottom_diff[index] = 0;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int top_local_offset = top_offset + ph * pooled_width + pw;
if (bottom_data[index] == top_data[top_local_offset]) {
bottom_diff[index] += top_diff[top_local_offset];
}
}
}
}
}
)";
// CRTP wrapper around CudaRTCFunction for the max-pool forward kernel.
// Each instance takes a unique kernel name from GetUniqueName() at
// construction and reports it from KernelName(), whatever arguments are passed.
class MaxPoolRTCFunction : public CudaRTCFunction<MaxPoolRTCFunction> {
 public:
  MaxPoolRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}

  // The kernel name is fixed per instance; the arguments are ignored.
  template <typename... Args>
  auto KernelName(Args... /*ignored*/) -> string {
    return name_;
  }

  // Produces the kernel source for the given parameters. Only declared here;
  // the definition is supplied as an explicit specialization.
  template <typename... Args>
  auto GetSource(Args... args) -> string;

 private:
  string name_;
};
// CRTP wrapper around CudaRTCFunction for the max-pool backward kernel.
// Each instance takes a unique kernel name from GetUniqueName() at
// construction and reports it from KernelName(), whatever arguments are passed.
class MaxPoolGradientRTCFunction
    : public CudaRTCFunction<MaxPoolGradientRTCFunction> {
 public:
  MaxPoolGradientRTCFunction()
      : CudaRTCFunction(), name_(GetUniqueName()) {}

  // The kernel name is fixed per instance; the arguments are ignored.
  template <typename... Args>
  auto KernelName(Args... /*ignored*/) -> string {
    return name_;
  }

  // Produces the kernel source for the given parameters. Only declared here;
  // the definition is supplied as an explicit specialization.
  template <typename... Args>
  auto GetSource(Args... args) -> string;

 private:
  string name_;
};
// Renders kMaxPoolForwardNCHWSource with this instance's kernel name and the
// given shape/pooling parameters substituted for the placeholders.
template <>
string MaxPoolRTCFunction::GetSource(
    const int output_size,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int kernel_h,
    const int kernel_w,
    const int stride_h,
    const int stride_w,
    const int pad_t,
    const int pad_l) {
  constexpr int kSourceCapacity = 65536;
  char rendered[kSourceCapacity];
  const int written = snprintf(
      rendered,
      kSourceCapacity,
      kMaxPoolForwardNCHWSource,
      name_.c_str(),
      output_size,
      channels,
      height,
      width,
      pooled_height,
      pooled_width,
      kernel_h,
      kernel_w,
      stride_h,
      stride_w,
      pad_t,
      pad_l);
  // snprintf returns the untruncated length (or a negative value on error).
  // These debug-only checks flag an error or a source that did not fit.
  DCHECK_GE(written, 0);
  DCHECK_LT(written, kSourceCapacity);
  return string(rendered);
}
// Renders kMaxPoolBackwardNCHWSource with this instance's kernel name and the
// given shape/pooling parameters substituted for the placeholders.
template <>
string MaxPoolGradientRTCFunction::GetSource(
    const int output_size,
    const int num,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int kernel_h,
    const int kernel_w,
    const int stride_h,
    const int stride_w,
    const int pad_t,
    const int pad_l) {
  constexpr int kSourceCapacity = 65536;
  char rendered[kSourceCapacity];
  const int written = snprintf(
      rendered,
      kSourceCapacity,
      kMaxPoolBackwardNCHWSource,
      name_.c_str(),
      output_size,
      num,
      channels,
      height,
      width,
      pooled_height,
      pooled_width,
      kernel_h,
      kernel_w,
      stride_h,
      stride_w,
      pad_t,
      pad_l);
  // snprintf returns the untruncated length (or a negative value on error).
  // These debug-only checks flag an error or a source that did not fit.
  DCHECK_GE(written, 0);
  DCHECK_LT(written, kSourceCapacity);
  return string(rendered);
}
} // namespace
// MaxPool operator whose kernel is compiled at run time with NVRTC.
// Only NCHW order and 4-D input are supported. The kernel has every shape
// parameter baked in, so it is recompiled whenever the input shape differs
// from the one the cached kernel was built for.
class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
 public:
  MaxPoolRTCOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CUDAContext>(operator_def, ws) {
    CAFFE_ENFORCE_EQ(
        order_, StorageOrder::NCHW, "Currently only NCHW is supported.");
  }
  ~MaxPoolRTCOp() override {}

  bool RunOnDeviceWithOrderNCHW() override {
    auto& X = Input(0);
    // The generated kernel pools over exactly two spatial dims (H, W) of an
    // N x C x H x W tensor, so inputs of any other rank are rejected here
    // instead of producing wrong results.
    CAFFE_ENFORCE_EQ(
        X.dim(), 4, "MaxPool NVRTC engine only supports 4-D NCHW input.");
    auto output_sizes =
        ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
    auto* Y = Output(0, output_sizes, at::dtype<float>());
    if (input_dims_ != X.sizes()) {
      // Shapes are compile-time constants in the kernel: recompile.
      VLOG(1) << "MaxPool RTC recompiling";
      CAFFE_ENFORCE_LT(Y->numel(), std::numeric_limits<int>::max());
      func_.Compile(
          static_cast<int>(Y->numel()),
          X.dim32(1),
          X.dim32(2),
          X.dim32(3),
          Y->dim32(2),
          Y->dim32(3),
          kernel_h(),
          kernel_w(),
          stride_h(),
          stride_w(),
          pad_t(),
          pad_l());
      input_dims_ = X.sizes().vec();
    }
    // Carry out the pooling computation: 1-D grid/block, no shared memory,
    // on this operator's CUDA stream.
    func_.Launch(
        CAFFE_GET_BLOCKS(Y->numel()),
        1,
        1,
        CAFFE_CUDA_NUM_THREADS,
        1,
        1,
        0,
        context_.cuda_stream(),
        X.data<float>(),
        Y->mutable_data<float>());
    return true;
  }

  bool RunOnDeviceWithOrderNHWC() override {
    LOG(FATAL) << "Not implemented.";
    return false;
  }

 private:
  MaxPoolRTCFunction func_;
  // Input shape the cached kernel was compiled for (empty until first run).
  vector<int64_t> input_dims_;
};
// MaxPoolGradient operator whose kernel is compiled at run time with NVRTC.
// Inputs: X (forward input), Y (forward output), dY (gradient w.r.t. Y).
// Output: dX, shaped like X. Only NCHW order and 4-D tensors are supported.
// The kernel bakes in X's shape and dY's pooled H/W, so it is recompiled
// whenever either of those differs from the cached kernel's.
class MaxPoolGradientRTCOp final : public ConvPoolOpBase<CUDAContext> {
 public:
  MaxPoolGradientRTCOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CUDAContext>(operator_def, ws) {
    CAFFE_ENFORCE_EQ(
        order_, StorageOrder::NCHW, "Currently only NCHW is supported.");
  }
  ~MaxPoolGradientRTCOp() override {}

  bool RunOnDeviceWithOrderNCHW() override {
    auto& X = Input(0);
    auto& Y = Input(1);
    auto& dY = Input(2);
    CAFFE_ENFORCE_EQ(
        X.dim(), 4, "MaxPoolGradient NVRTC engine only supports 4-D input.");
    CAFFE_ENFORCE_EQ(dY.dim(), 4);
    // The kernel reads Y and dY at the same offsets, computed from X's batch
    // and channel counts and dY's pooled H/W. Any mismatch would make it read
    // past the end of Y or dY on the device, so validate up front.
    CAFFE_ENFORCE(
        Y.sizes() == dY.sizes(), "Y and dY must have the same shape.");
    CAFFE_ENFORCE_EQ(
        dY.dim32(0), X.dim32(0), "dY and X must have the same batch size.");
    CAFFE_ENFORCE_EQ(
        dY.dim32(1), X.dim32(1), "dY and X must have the same channel count.");
    auto* dX = Output(0, X.sizes(), at::dtype<float>());
    ConvPoolOpBase<CUDAContext>::ComputePads({X.dim32(2), X.dim32(3)});
    // The pooled sizes come from dY, so the cache key includes dY's shape as
    // well as X's.
    if (input_dims_ != X.sizes() || output_dims_ != dY.sizes()) {
      VLOG(1) << "MaxPoolGradient RTC recompiling";
      CAFFE_ENFORCE_LT(X.numel(), std::numeric_limits<int>::max());
      func_.Compile(
          static_cast<int>(X.numel()),
          X.dim32(0),
          X.dim32(1),
          X.dim32(2),
          X.dim32(3),
          dY.dim32(2),
          dY.dim32(3),
          kernel_h(),
          kernel_w(),
          stride_h(),
          stride_w(),
          pad_t(),
          pad_l());
      input_dims_ = X.sizes().vec();
      output_dims_ = dY.sizes().vec();
    }
    // One loop iteration per element of dX: 1-D grid/block, no shared
    // memory, on this operator's CUDA stream.
    func_.Launch(
        CAFFE_GET_BLOCKS(X.numel()),
        1,
        1,
        CAFFE_CUDA_NUM_THREADS,
        1,
        1,
        0,
        context_.cuda_stream(),
        X.data<float>(),
        Y.data<float>(),
        dY.data<float>(),
        dX->mutable_data<float>());
    return true;
  }

  bool RunOnDeviceWithOrderNHWC() override {
    LOG(FATAL) << "Not implemented.";
    return false;
  }

 private:
  MaxPoolGradientRTCFunction func_;
  // X shape the cached kernel was compiled for (empty until first run).
  vector<int64_t> input_dims_;
  // dY shape the cached kernel was compiled for (empty until first run).
  vector<int64_t> output_dims_;
};
namespace {
// Register the NVRTC-compiled forward and gradient ops under engine "NVRTC".
REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPool, NVRTC, MaxPoolRTCOp);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(
    MaxPoolGradient,
    NVRTC,
    MaxPoolGradientRTCOp);
} // namespace
} // namespace caffe2