forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvol2col.cuh
270 lines (257 loc) · 8.01 KB
/
vol2col.cuh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
#pragma once
#include <THC/THCGeneral.h>
#include <THC/THCDeviceUtils.cuh>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <c10/macros/Macros.h>
#include <cstdint>
#include <limits>
namespace at {
namespace native {
using namespace at::cuda::detail;
// Kernel for fast unfold+copy on volumes (the 3-D analogue of im2col).
//
// Each CUDA_KERNEL_LOOP iteration handles one flat index in [0, n). The index
// is decomposed as (channel_in, t_out, h_out, w_out), with w_out varying
// fastest; as launched by vol2col, n = channels * depth_col * height_col *
// width_col.
//
// For that output position the kernel copies the ksize_t x ksize_h x ksize_w
// window of data_vol into data_col. The window starts at
// (t_out*stride - pad, ...) and its taps are spaced by the dilation. Taps
// outside [0, depth) x [0, height) x [0, width) are written as zero.
//
// Layouts implied by the indexing below:
//   data_vol[channel][depth][height][width]
//   data_col[channel_in*ksize_t*ksize_h*ksize_w + (i*ksize_h + j)*ksize_w + k]
//           [t_out][h_out][w_out]
//
// All memory offsets are computed as 64-bit integers. The column buffer is
// ksize_t*ksize_h*ksize_w times larger than n, so its offsets can exceed the
// int range even when n does not. The kernel parameters data_col / data_vol
// are never advanced, so every loop iteration starts from the buffer bases.
template <typename T>
__global__ void vol2col_kernel(
    const int n,
    const T* data_vol,
    const int depth,
    const int height,
    const int width,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    const int depth_col,
    const int height_col,
    const int width_col,
    T* data_col) {
  // Distance in data_col between consecutive kernel taps of one position.
  const int64_t col_plane =
      static_cast<int64_t>(depth_col) * height_col * width_col;
  CUDA_KERNEL_LOOP(index, n) {
    // Decompose a copy so `index` itself is left untouched; this keeps the
    // decomposition independent of how CUDA_KERNEL_LOOP advances its variable.
    int rem = index;
    const int w_out = rem % width_col;
    rem /= width_col;
    const int h_out = rem % height_col;
    rem /= height_col;
    const int t_out = rem % depth_col;
    const int channel_in = rem / depth_col;

    // Top-front-left corner of the sampling window (may be negative: padding).
    const int t_in = t_out * stride_t - pad_t;
    const int h_in = h_out * stride_h - pad_h;
    const int w_in = w_out * stride_w - pad_w;

    // First kernel tap of this position in data_col; later taps follow at
    // multiples of col_plane.
    const int64_t channel_out =
        static_cast<int64_t>(channel_in) * ksize_t * ksize_h * ksize_w;
    int64_t col_offset =
        ((channel_out * depth_col + t_out) * height_col + h_out) * width_col +
        w_out;
    const int64_t vol_channel = static_cast<int64_t>(channel_in) * depth;

    for (int i = 0; i < ksize_t; ++i) {
      const int t = t_in + i * dilation_t;
      // Plain integers, not pointers: out-of-range rows are never formed as
      // addresses and only dereferenced after the bounds check below.
      const int64_t vol_t = (vol_channel + t) * height;
      for (int j = 0; j < ksize_h; ++j) {
        const int h = h_in + j * dilation_h;
        const int64_t vol_th = (vol_t + h) * width;
        for (int k = 0; k < ksize_w; ++k) {
          const int w = w_in + k * dilation_w;
          const bool inside = t >= 0 && h >= 0 && w >= 0 && t < depth &&
              h < height && w < width;
          data_col[col_offset] =
              inside ? data_vol[vol_th + w] : static_cast<T>(0);
          col_offset += col_plane;
        }
      }
    }
  }
}
// Host launcher for vol2col_kernel: unfolds the volume data_vol
// ([channels][depth][height][width]) into the column buffer data_col on the
// given stream.
//
// Launches one loop iteration per (channel, t_col, h_col, w_col) position of
// the output grid; each iteration copies that position's ksize_t x ksize_h x
// ksize_w window.
//
// Throws (via TORCH_CHECK) if the number of positions does not fit in the
// kernel's 32-bit signed `n`. Launch errors surface through AT_CUDA_CHECK.
template <typename T>
void vol2col(
    cudaStream_t stream,
    const T* data_vol,
    const int channels,
    const int depth,
    const int height,
    const int width,
    const int depth_col,
    const int height_col,
    const int width_col,
    const int ksize_t,
    const int ksize_h,
    const int ksize_w,
    const int pad_t,
    const int pad_h,
    const int pad_w,
    const int stride_t,
    const int stride_h,
    const int stride_w,
    const int dilation_t,
    const int dilation_h,
    const int dilation_w,
    T* data_col) {
  // Form the product in 64 bits so an oversized problem is rejected below
  // instead of silently wrapping around in int.
  const int64_t num_kernels = static_cast<int64_t>(channels) * depth_col *
      height_col * width_col;
  if (num_kernels == 0) {
    // Nothing to copy; a launch with zero blocks would be rejected by CUDA.
    return;
  }
  TORCH_CHECK(
      num_kernels > 0 && num_kernels <= std::numeric_limits<int>::max(),
      "vol2col: channels * depth_col * height_col * width_col (",
      num_kernels,
      ") must fit in a 32-bit signed value");
  const int n = static_cast<int>(num_kernels);
  // Launch
  vol2col_kernel<<<GET_BLOCKS(n), CUDA_NUM_THREADS, 0, stream>>>(
      n,
      data_vol,
      depth,
      height,
      width,
      ksize_t,
      ksize_h,
      ksize_w,
      pad_t,
      pad_h,
      pad_w,
      stride_t,
      stride_h,
      stride_w,
      dilation_t,
      dilation_h,
      dilation_w,
      depth_col,
      height_col,
      width_col,
      data_col);
  AT_CUDA_CHECK(cudaGetLastError());
}
// Kernel for col2vol: folds the column buffer data_col back into the volume
// data_vol, the adjoint of the vol2col copy.
//
// Each grid-stride iteration owns one element of data_vol. The element is
// addressed by a flat index over [channel][depth][height][width], with width
// varying fastest. The iteration sums (in accT) every data_col entry whose
// sliding-window tap lands on that element, then writes the sum to
// data_vol[index], overwriting the previous value. Gathering per output
// element means no atomics are needed.
//
// data_col is read as
//   [(c*kernel_t + t_k)*kernel_h + h_k)*kernel_w + w_k][t_col][h_col][w_col].
// Spatial coordinates are shifted by +pad so they stay non-negative in
// unsigned arithmetic. The `channels` argument is not read.
template <typename T, typename accT>
__global__ void vol2im_kernel(
    const unsigned n,
    const T* data_col,
    const unsigned depth,
    const unsigned height,
    const unsigned width,
    const unsigned channels,
    const unsigned kernel_t,
    const unsigned kernel_h,
    const unsigned kernel_w,
    const unsigned pad_t,
    const unsigned pad_h,
    const unsigned pad_w,
    const unsigned stride_t,
    const unsigned stride_h,
    const unsigned stride_w,
    const unsigned dilation_t,
    const unsigned dilation_h,
    const unsigned dilation_w,
    const unsigned depth_col,
    const unsigned height_col,
    const unsigned width_col,
    T* data_vol) {
  // Span of one dilated kernel window; identical for every element.
  const unsigned kernel_extent_w = (kernel_w - 1) * dilation_w + 1;
  const unsigned kernel_extent_h = (kernel_h - 1) * dilation_h + 1;
  const unsigned kernel_extent_t = (kernel_t - 1) * dilation_t + 1;

  // Grid-stride loop with a 64-bit counter. n may be as large as UINT_MAX
  // (col2vol only checks that it fits in 32 unsigned bits), so the element
  // index must not pass through a signed 32-bit int. A value >= 2^31 would
  // turn negative and data_vol[index] would write out of bounds.
  // NOTE(review): CUDA_KERNEL_LOOP is avoided here because ATen has declared
  // its loop variable as `int`; re-check before switching back to the macro.
  const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t linear =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       linear < n;
       linear += grid_stride) {
    // Exact narrowing: 0 <= linear < n <= UINT_MAX.
    const unsigned index = static_cast<unsigned>(linear);
    accT val = static_cast<accT>(0);
    const unsigned w_im = index % width + pad_w;
    const unsigned h_im = (index / width) % height + pad_h;
    const unsigned t_im = (index / width / height) % depth + pad_t;
    const unsigned c_im = index / (width * height * depth);
    // compute the start and end of the output
    const unsigned w_col_start =
        (w_im < kernel_extent_w) ? 0 : (w_im - kernel_extent_w) / stride_w + 1;
    const unsigned w_col_end = std::min(w_im / stride_w + 1, width_col);
    const unsigned h_col_start =
        (h_im < kernel_extent_h) ? 0 : (h_im - kernel_extent_h) / stride_h + 1;
    const unsigned h_col_end = std::min(h_im / stride_h + 1, height_col);
    const unsigned t_col_start =
        (t_im < kernel_extent_t) ? 0 : (t_im - kernel_extent_t) / stride_t + 1;
    const unsigned t_col_end = std::min(t_im / stride_t + 1, depth_col);
    // TODO: use LCM of stride and dilation to avoid unnecessary loops
    for (unsigned t_col = t_col_start; t_col < t_col_end; t_col += 1) {
      for (unsigned h_col = h_col_start; h_col < h_col_end; h_col += 1) {
        for (unsigned w_col = w_col_start; w_col < w_col_end; w_col += 1) {
          // Offset of this element inside the window anchored at
          // (t_col, h_col, w_col); contributes only when it lies on a tap.
          unsigned t_k = (t_im - t_col * stride_t);
          unsigned h_k = (h_im - h_col * stride_h);
          unsigned w_k = (w_im - w_col * stride_w);
          if (t_k % dilation_t == 0 && h_k % dilation_h == 0 &&
              w_k % dilation_w == 0) {
            t_k /= dilation_t;
            h_k /= dilation_h;
            w_k /= dilation_w;
            // Row index fits in unsigned (checked by col2vol); the full
            // column-buffer index is formed in 64 bits.
            const int64_t idx_k =
                ((c_im * kernel_t + t_k) * kernel_h + h_k) * kernel_w + w_k;
            const int64_t data_col_index =
                ((idx_k * depth_col + t_col) * height_col + h_col) *
                    width_col +
                w_col;
            val += data_col[data_col_index];
          }
        }
      }
    }
    data_vol[index] = static_cast<T>(val);
  }
}
// Host launcher for vol2im_kernel: folds the column buffer data_col into
// data_vol ([channels][depth][height][width]) on the given stream.
//
// output_depth / output_height / output_width are the column-grid extents,
// and patch_* are the kernel sizes. Every element of data_vol is overwritten
// with the sum of the column entries that map onto it.
//
// The kernel uses 32-bit unsigned arithmetic, so the element count and
// channels * kernel volume are range-checked first (TORCH_CHECK throws).
// Launch errors surface through AT_CUDA_CHECK.
template <typename T, typename accT>
void col2vol(
    cudaStream_t stream,
    const T* data_col,
    const int64_t channels,
    const int64_t depth,
    const int64_t height,
    const int64_t width,
    const int64_t output_depth,
    const int64_t output_height,
    const int64_t output_width,
    const int64_t patch_t,
    const int64_t patch_h,
    const int64_t patch_w,
    const int64_t pad_t,
    const int64_t pad_h,
    const int64_t pad_w,
    const int64_t stride_t,
    const int64_t stride_h,
    const int64_t stride_w,
    const int64_t dilation_t,
    const int64_t dilation_h,
    const int64_t dilation_w,
    T* data_vol) {
  const auto num_kernels = channels * depth * height * width;
  auto check_fits_in_unsigned =
      [](int64_t val, const char * name) {
        constexpr auto umax = std::numeric_limits<unsigned>::max();
        TORCH_CHECK(val >= 0 && val <= umax,
                    name, " must fit in a 32-bit unsigned value");
      };
  check_fits_in_unsigned(num_kernels, "input size");
  check_fits_in_unsigned(
      channels * patch_t * patch_h * patch_w, "channels x kernel size");
  if (num_kernels == 0) {
    // Empty volume: nothing to write, and a launch with zero blocks would be
    // rejected by CUDA.
    return;
  }
  // To avoid atomic operations, launch one thread per element of data_vol.
  // Each thread gathers and sums every column entry that maps onto it.
  vol2im_kernel<T, accT>
      <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
          num_kernels,
          data_col,
          depth,
          height,
          width,
          channels,
          patch_t,
          patch_h,
          patch_w,
          pad_t,
          pad_h,
          pad_w,
          stride_t,
          stride_h,
          stride_w,
          dilation_t,
          dilation_h,
          dilation_w,
          output_depth,
          output_height,
          output_width,
          data_vol);
  AT_CUDA_CHECK(cudaGetLastError());
}
} // namespace native
} // namespace at