forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathScanKernels.cu
666 lines (594 loc) · 25.6 KB
/
ScanKernels.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <THC/THCNumerics.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.h>
#include <THC/THCThrustAllocator.cuh>
#include <thrust/execution_policy.h>
#include <thrust/device_ptr.h>
#include <thrust/scan.h>
#include <cub/device/device_scan.cuh>
namespace at { namespace native {
// Returns n / m rounded up to the next whole number (for non-negative n and
// positive m). The sum n + m - 1 is formed first, so the operands must be
// small enough for that sum not to overflow.
template <typename integer>
constexpr inline integer ceil_div(integer n, integer m) {
  return static_cast<integer>((n + m - 1) / m);
}
// Conditionally overwrites the pair (rhs, rhs_idx) with (lhs, lhs_idx):
//  - if rhs is already NaN it is left untouched;
//  - otherwise a NaN lhs always replaces rhs;
//  - otherwise rhs is replaced unless binary_op(rhs, lhs) holds.
template<typename scalar_t, typename idx_t, typename BinaryOperation>
__device__ void binary_op_update(const scalar_t lhs, scalar_t& rhs, const idx_t lhs_idx, idx_t& rhs_idx, BinaryOperation binary_op) {
  if (THCNumerics<scalar_t>::isnan(rhs)) {
    return;
  }
  // binary_op is only consulted when lhs is not NaN.
  const bool keep_rhs = !THCNumerics<scalar_t>::isnan(lhs) && binary_op(rhs, lhs);
  if (keep_rhs) {
    return;
  }
  rhs = lhs;
  rhs_idx = lhs_idx;
}
/* Perform an inclusive scan along the innermost dimension of a tensor, writing
 * both the running selected value and the column index it was taken from.
 *
 * - self_ is the input, values_ / indices_ are the outputs; all three are
 *   addressed as 'num_rows' contiguous rows of 'row_size' elements;
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 * - init seeds the running total and pads columns past the end of a row;
 * - binary_op is the comparison forwarded to binary_op_update.
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 *
 * The shared buffers are sized from the template parameters (num_threads_y
 * rows of 2 * num_threads_x entries) and every thread loads two columns per
 * tile, so the launch is expected to use a (num_threads_x, num_threads_y)
 * thread block.
 */
template<typename scalar_t, int num_threads_x, int num_threads_y, class BinaryFunction>
__global__ void tensor_kernel_scan_innermost_dim_with_indices(const scalar_t *self_, scalar_t *values_, int64_t *indices_,
                                                int num_rows, int row_size,
                                                scalar_t init, BinaryFunction binary_op) {
  __shared__ scalar_t vbuf[num_threads_y][2 * num_threads_x];
  __shared__ int64_t ibuf[num_threads_y][2 * num_threads_x];
  scalar_t* row_buf = vbuf[threadIdx.y];
  int64_t* row_idx_buf = ibuf[threadIdx.y];

  for (int block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    int row = block_row + threadIdx.y;
    // The element offset of a row can exceed 2^31 even when num_rows and
    // row_size each fit in an int, so form it in 64-bit arithmetic.
    const int64_t row_offset = static_cast<int64_t>(row) * row_size;
    const scalar_t *row_self = self_ + row_offset;
    scalar_t *row_values = values_ + row_offset;
    int64_t *row_indices = indices_ + row_offset;
    scalar_t block_total = init;
    int64_t block_idx_final = 0;
    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (int block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      int col1 = block_col + threadIdx.x;
      int col2 = block_col + num_threads_x + threadIdx.x;
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_self[col1];
          row_idx_buf[threadIdx.x] = col1;
        } else {
          row_buf[threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }
        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_self[col2];
          row_idx_buf[num_threads_x + threadIdx.x] = col2;
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
          // No need to set the index here as the value in init will never be selected
        }
        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          binary_op_update(block_total, row_buf[0], block_idx_final, row_idx_buf[0], binary_op);
        }
      }
      // Barriers stay outside the row < num_rows guards so that every thread
      // of the block reaches them.
      __syncthreads();
      // Parallel reduction (up-sweep).
      for (int s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          int offset = (2 * threadIdx.x + 1) * d - 1;
          binary_op_update(row_buf[offset], row_buf[offset + d], row_idx_buf[offset], row_idx_buf[offset + d], binary_op);
        }
        __syncthreads();
      }
      // Down-sweep.
      for (int s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          int offset = 2 * (threadIdx.x + 1) * d - 1;
          binary_op_update(row_buf[offset], row_buf[offset + d], row_idx_buf[offset], row_idx_buf[offset + d], binary_op);
        }
        __syncthreads();
      }
      // Write back to output.
      if (row < num_rows) {
        if (col1 < row_size) {
          row_values[col1] = row_buf[threadIdx.x];
          row_indices[col1] = row_idx_buf[threadIdx.x];
        }
        if (col2 < row_size) {
          row_values[col2] = row_buf[num_threads_x + threadIdx.x];
          row_indices[col2] = row_idx_buf[num_threads_x + threadIdx.x];
        }
      }
      // Carry the last scanned entry of this tile into the next tile.
      block_total = row_buf[2 * num_threads_x - 1];
      block_idx_final = row_idx_buf[2 * num_threads_x - 1];
      // Make sure every thread has read the carry before the buffers are reloaded.
      __syncthreads();
    }
  }
}
/* Perform an inclusive scan along an outer dimension of a tensor, writing both
 * the running selected value and the position along the scanned dimension it
 * was taken from.
 *
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 * - init is the starting value of the running result;
 * - binary_op(candidate, current) returning true makes the candidate the new
 *   running result.
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.x process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 *
 * NaN handling: a NaN input always becomes the running result, and once the
 * running result is NaN it is never replaced by a non-NaN value.
 */
template<typename scalar_t, class BinaryFunction>
__global__ void tensor_kernel_scan_outer_dim_with_indices(scalar_t *self_, scalar_t *values_, int64_t *indices_,
                  int num_orows, int num_irows, int row_size, scalar_t init, BinaryFunction binary_op) {
  for (int orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (int irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      // orow * row_size * num_irows can exceed 2^31 even when every factor
      // fits in an int, so the element offset is formed in 64-bit arithmetic.
      const int64_t base = static_cast<int64_t>(orow) * row_size * num_irows + irow;
      scalar_t *self = self_ + base;
      scalar_t *values = values_ + base;
      int64_t *indices = indices_ + base;
      scalar_t out = init;
      int64_t out_idx = 0;

      for (int64_t col = 0; col < row_size; ++col) {
        if (THCNumerics<scalar_t>::isnan(*self) || (!THCNumerics<scalar_t>::isnan(out) && binary_op(*self, out))) {
          out = *self;
          out_idx = col;
        }
        *values = out;
        *indices = out_idx;
        // Consecutive elements along the scanned dimension are num_irows apart.
        self += num_irows;
        values += num_irows;
        indices += num_irows;
      }
    }
  }
}
// Host launcher for tensor_kernel_scan_outer_dim_with_indices.
//
// Views `self` as [num_orows, row_size, num_irows], where row_size is the
// extent of `dim`, and scans along the middle axis. The tensors are handed to
// the kernel through raw data pointers, so the caller is expected to pass
// contiguous tensors.
//
// The kernel takes its extents as int; they are computed in 64-bit here and
// rejected with an error if they do not fit, instead of being truncated.
template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices,
                                       int dim, scalar_t init, BinaryFunction binary_op) {
  // Nothing to scan; this also avoids a zero-sized block and the division by
  // zero it would cause in ceil_div below.
  if (self.numel() == 0) {
    return;
  }
  const int64_t row_size64 = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dim_ < dim) as one.
  // The init value is int64_t so that std::accumulate accumulates in 64 bits.
  const int64_t num_orows64 = std::accumulate(sizes.begin(), sizes.begin() + dim, int64_t{1}, std::multiplies<int64_t>());

  // Treat all inner dimensions (i.e. dim > dimension) as one.
  const int64_t num_irows64 = std::accumulate(sizes.begin() + dim + 1, sizes.end(), int64_t{1}, std::multiplies<int64_t>());

  constexpr int64_t int_max = std::numeric_limits<int>::max();
  TORCH_CHECK(row_size64 <= int_max, "row_size must fit in a 32-bit signed value");
  TORCH_CHECK(num_orows64 <= int_max, "num_orows must fit in a 32-bit signed value");
  TORCH_CHECK(num_irows64 <= int_max, "num_irows must fit in a 32-bit signed value");
  const int row_size = static_cast<int>(row_size64);
  const int num_orows = static_cast<int>(num_orows64);
  const int num_irows = static_cast<int>(num_irows64);

  dim3 threads(std::min(512, num_irows));
  int maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int(threads.x))));
  tensor_kernel_scan_outer_dim_with_indices<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host launcher for tensor_kernel_scan_innermost_dim_with_indices.
//
// Views `self` as num_rows rows of row_size elements, where row_size is the
// extent of the last dimension, and scans each row. The tensors are handed to
// the kernel through raw data pointers, so the caller is expected to pass
// contiguous tensors.
//
// The kernel takes its extents as int; they are computed in 64-bit here and
// rejected with an error if they do not fit, instead of being truncated.
template <typename scalar_t, class BinaryFunction>
__host__ void scan_innermost_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices, scalar_t init, BinaryFunction binary_op) {
  // Nothing to scan; this also avoids dividing by a zero row_size below.
  if (self.numel() == 0) {
    return;
  }
  int ndim = self.dim();
  // Treat all outer dimensions as a single dimension.
  const int64_t row_size64 = self.size(ndim - 1);
  const int64_t num_rows64 = self.numel() / row_size64;

  constexpr int64_t int_max = std::numeric_limits<int>::max();
  TORCH_CHECK(row_size64 <= int_max, "row_size must fit in a 32-bit signed value");
  TORCH_CHECK(num_rows64 <= int_max, "num_rows must fit in a 32-bit signed value");
  const int row_size = static_cast<int>(row_size64);
  const int num_rows = static_cast<int>(num_rows64);

  // The block shape must match the <16, 32> template arguments of the kernel.
  dim3 threads(16, 32);
  dim3 grid(std::min(at::cuda::getCurrentDeviceProperties()->maxGridSize[0], ceil_div(num_rows, int(threads.y))));
  tensor_kernel_scan_innermost_dim_with_indices<scalar_t, 16, 32><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    self.data_ptr<scalar_t>(), values.data_ptr<scalar_t>(), indices.data_ptr<int64_t>(),
    num_rows, row_size, init, binary_op);
  AT_CUDA_CHECK(cudaGetLastError());
}
// Runs a value-and-index scan of `self` along `dim`, choosing the innermost
// or outer-dimension launcher. The launchers operate on contiguous copies;
// when `values` or `indices` was not contiguous to begin with, the computed
// data is copied back into it afterwards.
template<typename scalar_t, typename BinaryFunction>
void scan_dim_with_indices(const Tensor& self, Tensor& values, Tensor& indices,
     int64_t dim, scalar_t init, BinaryFunction binary_op) {
  const bool values_need_copy = !values.is_contiguous();
  const bool indices_need_copy = !indices.is_contiguous();

  Tensor self_c = self.contiguous();
  Tensor values_c = values.contiguous();
  Tensor indices_c = indices.contiguous();

  const int last_dim = self.dim() - 1;
  const bool scan_is_innermost = (dim == last_dim);
  if (scan_is_innermost) {
    scan_innermost_dim_with_indices<scalar_t>(self_c, values_c, indices_c, init, binary_op);
  } else {
    scan_outer_dim_with_indices<scalar_t>(self_c, values_c, indices_c, dim, init, binary_op);
  }

  if (values_need_copy) {
    values.copy_(values_c);
  }
  if (indices_need_copy) {
    indices.copy_(indices_c);
  }
}
// Cumulative maximum of `self` along `dim`, writing the running maxima to
// `values` and their positions to `indices`. The scan starts from -infinity
// for floating-point inputs and from the lowest representable value otherwise.
void cummax_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
  TensorArg input_arg{ self, "input", 3 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg output_arg{ values, "output", 1 };
  checkAllSameGPU("cummax", {output_arg, indices_arg, input_arg});

  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half,
    self.scalar_type(), "cummax_cuda", [&]() {
      using limits_t = std::numeric_limits<scalar_t>;
      scalar_t init = self.is_floating_point() ? (-1*limits_t::infinity()) : limits_t::lowest();
      scan_dim_with_indices<scalar_t>(self, values, indices, dim, init, std::greater_equal<scalar_t>());
  });
}
// Cumulative minimum of `self` along `dim`, writing the running minima to
// `values` and their positions to `indices`. The scan starts from +infinity
// for floating-point inputs and from the largest representable value otherwise.
void cummin_helper_cuda(const Tensor& self, Tensor& values, Tensor& indices, int64_t dim) {
  TensorArg input_arg{ self, "input", 3 };
  TensorArg indices_arg{ indices, "indices", 2 };
  TensorArg output_arg{ values, "output", 1 };
  checkAllSameGPU("cummin", {output_arg, indices_arg, input_arg});

  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Bool, at::ScalarType::Half,
    self.scalar_type(), "cummin_cuda", [&]() {
      using limits_t = std::numeric_limits<scalar_t>;
      scalar_t init = self.is_floating_point() ? limits_t::infinity() : limits_t::max();
      scan_dim_with_indices<scalar_t>(self, values, indices, dim, init, std::less_equal<scalar_t>());
  });
}
// TODO: The implementations of `tensor_kernel_scan_outer_dim` and
// `tensor_kernel_scan_innermost_dim` are similar to
// `tensor_kernel_scan_outer_dim_with_indices` and
// `tensor_kernel_scan_innermost_dim_with_indices`, and should be refactored to
// remove the duplication.
/* Perform an inclusive scan along an outer dimension of a tensor.
 *
 * - tgt_ is the output and src_ the input; both are addressed as
 *   [num_orows, row_size, num_irows];
 * - num_orows is the size of the flattened outer dimensions;
 * - num_irows is the size of the flattened inner dimensions;
 * - row_size is the size of the dimension along which to scan;
 * - init is the starting accumulator and binary_op(acc, element) yields the
 *   next accumulator.
 *
 * The dimensions to the outside and inside of the specified dimension are considered as flattened.
 * Thread blocks with the same blockIdx.x process an "outer row" (i.e. an element of the flattened
 * outer dimensions, which contains several "inner rows").
 * Each thread processes a single inner row at a time.
 */
template<typename scalar_t, class BinaryOp>
__global__ void tensor_kernel_scan_outer_dim(scalar_t *tgt_, scalar_t *src_,
                                              unsigned num_orows, unsigned num_irows, unsigned row_size,
                                              scalar_t init, BinaryOp binary_op)
{
  for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
    for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
      // Each extent fits in 32 bits, but their product may not; form the
      // element offset in 64-bit arithmetic so it cannot wrap around.
      const int64_t base = static_cast<int64_t>(orow) * row_size * num_irows + irow;
      scalar_t *src = src_ + base;
      scalar_t *tgt = tgt_ + base;
      scalar_t acc = init;

      for (unsigned col = 0; col < row_size; ++col) {
        acc = binary_op(acc, *src);
        *tgt = acc;

        // Consecutive elements along the scanned dimension are num_irows apart.
        src += num_irows;
        tgt += num_irows;
      }
    }
  }
}
/* Perform an inclusive scan along the innermost dimension of a tensor.
 *
 * - row_buf is this thread row's scratch buffer of 2 * num_threads_x elements,
 *   provided by the calling kernel;
 * - tgt_ is the output and src_ the input, both addressed as 'num_rows'
 *   contiguous rows of 'row_size' elements;
 * - num_rows is the size of the flattened outer dimensions;
 * - row_size is the size of the innermost dimension;
 * - init seeds the running total and pads columns past the end of a row;
 * - binary_op combines two values.
 *
 * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
 * considered as having 'num_rows' rows of size 'row_size'.
 * Each thread block processes one or more sets of contiguous rows (processing multiple rows
 * per thread block is quicker than processing a single row, especially for short rows).
 *
 * Every thread loads two columns per tile, so the launch is expected to use
 * num_threads_x threads in x.
 */
template<typename T, int num_threads_x, int num_threads_y, class BinaryFunction>
__device__ void tensor_kernel_scan_innermost_dim_impl(T* row_buf, T *tgt_, T *src_,
                                      unsigned num_rows, unsigned row_size,
                                      T init, BinaryFunction binary_op){
  for (unsigned block_row = blockIdx.x * blockDim.y;
       block_row < num_rows;
       block_row += blockDim.y * gridDim.x) {
    unsigned row = block_row + threadIdx.y;
    T block_total = init;

    // row and row_size each fit in 32 bits, but their product may not; form
    // the element offset in 64-bit arithmetic so it cannot wrap around.
    const int64_t row_offset = static_cast<int64_t>(row) * row_size;
    T *row_src = src_ + row_offset;
    T *row_tgt = tgt_ + row_offset;

    // Perform scan on one block at a time, keeping track of the total value of
    // all blocks processed so far.
    for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) {
      // Load data into shared memory (two values per thread).
      unsigned col1 = block_col + threadIdx.x;
      unsigned col2 = block_col + num_threads_x + threadIdx.x;
      if (row < num_rows) {
        if (col1 < row_size) {
          row_buf[threadIdx.x] = row_src[col1];
        } else {
          row_buf[threadIdx.x] = init;
        }

        if (col2 < row_size) {
          row_buf[num_threads_x + threadIdx.x] = row_src[col2];
        } else {
          row_buf[num_threads_x + threadIdx.x] = init;
        }

        // Add the total value of all previous blocks to the first value of this block.
        if (threadIdx.x == 0) {
          row_buf[0] = binary_op(row_buf[0], block_total);
        }
      }
      // Barriers stay outside the row < num_rows guards so that every thread
      // of the block reaches them.
      __syncthreads();

      // Parallel reduction (up-sweep).
      for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) {
        if (row < num_rows && threadIdx.x < s) {
          unsigned offset = (2 * threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Down-sweep.
      for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) {
        if (row < num_rows && threadIdx.x < s - 1) {
          unsigned offset = 2 * (threadIdx.x + 1) * d - 1;
          row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]);
        }
        __syncthreads();
      }

      // Write back to output.
      if (row < num_rows) {
        if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x];
        if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x];
      }
      // Carry the last scanned entry of this tile into the next tile.
      block_total = row_buf[2 * num_threads_x - 1];
      // Make sure every thread has read the carry before the buffer is reloaded.
      __syncthreads();
    }
  }
}
// Kernel entry point for non-complex element types: declares the per-block
// shared scratch array directly as T and forwards this thread row's slice to
// the shared implementation.
template <
    typename T,
    int num_threads_x,
    int num_threads_y,
    class BinaryFunction>
__global__ typename std::enable_if<!c10::is_complex_t<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
    T* tgt_,
    T* src_,
    unsigned num_rows,
    unsigned row_size,
    T init,
    BinaryFunction binary_op) {
  // One scratch row of 2 * num_threads_x elements per threadIdx.y.
  __shared__ T scratch[num_threads_y][2 * num_threads_x];
  T* const my_row_scratch = scratch[threadIdx.y];

  tensor_kernel_scan_innermost_dim_impl<T, num_threads_x, num_threads_y>(
      my_row_scratch, tgt_, src_, num_rows, row_size, init, binary_op);
}
// Kernel entry point for complex element types.
template <
    typename T,
    int num_threads_x,
    int num_threads_y,
    class BinaryFunction>
__global__ typename std::enable_if<c10::is_complex_t<T>::value, void>::type
tensor_kernel_scan_innermost_dim(
    T* tgt_,
    T* src_,
    unsigned num_rows,
    unsigned row_size,
    T init,
    BinaryFunction binary_op) {
  // A __shared__ array of a complex type cannot be declared directly
  // (`error: initializer not allowed for __shared__ variable`), so the
  // scratch space is declared with the underlying real type, using twice as
  // many elements, and then viewed as T.
  using base_t = typename scalar_value_type<T>::type;
  __shared__ base_t scratch[num_threads_y][4 * num_threads_x];
  T* const my_row_scratch = reinterpret_cast<T*>(scratch[threadIdx.y]);

  tensor_kernel_scan_innermost_dim_impl<T, num_threads_x, num_threads_y>(
      my_row_scratch, tgt_, src_, num_rows, row_size, init, binary_op);
}
// Raises (through TORCH_CHECK) unless `val` lies in [0, max of unsigned];
// `name` is the leading part of the error message.
void check_fits_in_unsigned(int64_t val, const char* name) {
  constexpr auto umax = std::numeric_limits<unsigned>::max();
  const bool in_range = (val >= 0) && (val <= umax);
  TORCH_CHECK(in_range, name, " must fit in a 32-bit unsigned value");
}
// Host launcher for tensor_kernel_scan_outer_dim.
//
// Views `self` as [num_orows, row_size, num_irows], where row_size is the
// extent of `dim`, and scans along the middle axis into `result`. The tensors
// are handed to the kernel through raw data pointers, so the caller is
// expected to pass contiguous tensors.
//
// The kernel takes its extents as unsigned; each one is validated with
// check_fits_in_unsigned before anything is narrowed.
template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const Tensor& self, Tensor& result,
                                       int dim, scalar_t init, BinaryFunction binary_op) {
  const int64_t row_size = self.size(dim);
  auto sizes = self.sizes();

  // Treat all outer dimensions (i.e. dim_ < dim) as one.
  // std::accumulate accumulates in the type of its init argument, so the init
  // must be int64_t; a plain `1` would truncate every partial product to int.
  const int64_t num_orows = std::accumulate(sizes.begin(), sizes.begin() + dim, int64_t{1}, std::multiplies<int64_t>());

  // Treat all inner dimensions (i.e. dim > dimension) as one.
  const int64_t num_irows = std::accumulate(sizes.begin() + dim + 1, sizes.end(), int64_t{1}, std::multiplies<int64_t>());

  // Validate before the extents are used to size the launch.
  check_fits_in_unsigned(num_irows, "num_irows");
  check_fits_in_unsigned(num_orows, "num_orows");
  check_fits_in_unsigned(row_size, "row_size");

  // Clamp in 64-bit so a large num_irows is never narrowed to a negative int.
  dim3 threads(std::min<int64_t>(512, num_irows));
  int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));

  tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    result.data_ptr<scalar_t>(), self.data_ptr<scalar_t>(),
    num_orows, num_irows, row_size, init, binary_op);
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host launcher for tensor_kernel_scan_innermost_dim: views `self` as
// num_rows rows of row_size elements (row_size being the extent of the last
// dimension) and scans each row into `result`.
template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction binary_op) {
  // Block shape; shared with the kernel's template arguments below.
  constexpr int kThreadsX = 16;
  constexpr int kThreadsY = 32;

  // Treat all outer dimensions as a single dimension.
  const int64_t last_dim = self.dim() - 1;
  const int64_t row_size = self.size(last_dim);
  const int64_t num_rows = self.numel() / row_size;

  dim3 threads(kThreadsX, kThreadsY);
  const int64_t max_grid_x = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
  const int64_t blocks_needed = ceil_div(num_rows, int64_t{threads.y});
  dim3 grid(std::min(max_grid_x, blocks_needed));

  // The kernel receives both extents as unsigned.
  check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");
  check_fits_in_unsigned(row_size, "row_size");

  tensor_kernel_scan_innermost_dim<scalar_t, kThreadsX, kThreadsY><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
    result.data_ptr<scalar_t>(), self.data_ptr<scalar_t>(),
    num_rows, row_size, init, binary_op);
  AT_CUDA_CHECK(cudaGetLastError());
}
// Single-element kernel: stores binary_op(*a, *b) into *out. Both inputs are
// read before the store, so `out` may alias `a` or `b`.
template<typename scalar_t, class func_t>
__global__ void transform_vals(scalar_t * a, scalar_t * b, scalar_t * out, func_t binary_op){
  const scalar_t combined = binary_op(*a, *b);
  *out = combined;
}
#ifdef __HIP_PLATFORM_HCC__
// Plain byte container with the same size as T, compiled only for the HIP
// platform. scan_thrust_or_cub reinterprets tensor data as ROCm_Bug<scalar_t>
// so that thrust scans this wrapper type instead of scalar_t itself.
// NOTE(review): the name suggests this works around a ROCm/thrust problem; the
// specific problem is not recorded in this file - confirm before removing.
template<typename T>
struct ROCm_Bug {
char bytes[sizeof(T)];
};
#endif
// Inclusive scan over all elements of `self`, treated as one flat sequence,
// written to `result`. Both tensors are accessed through raw data pointers, so
// the caller is expected to pass contiguous tensors.
//
// - HIP build: a single thrust::inclusive_scan over ROCm_Bug<scalar_t>
//   wrappers of the data.
// - CUDA build: cub::DeviceScan::InclusiveScan, run in chunks of at most 2^30
//   elements. For every chunk after the first, the chunk's first input element
//   is temporarily overwritten in `self`'s storage with
//   binary_op(self[i], result[i - 1]) so the running total carries over, and
//   is restored afterwards unless the scan is in place.
//
// `init` is not used by either path.
template<typename scalar_t, typename BinaryFunction>
void scan_thrust_or_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction binary_op) {
  #ifdef __HIP_PLATFORM_HCC__
  auto allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
  using rocm_bug_t = ROCm_Bug<scalar_t>;
  thrust::device_ptr<rocm_bug_t> src_data(reinterpret_cast<rocm_bug_t *>(self.data_ptr<scalar_t>()));
  thrust::device_ptr<rocm_bug_t> dst_data(reinterpret_cast<rocm_bug_t *>(result.data_ptr<scalar_t>()));
  ptrdiff_t size = self.numel();
  // Adapts binary_op to the byte-wrapper type by viewing the bytes as scalar_t.
  auto rocm_bug_binary_op = [=]C10_HOST_DEVICE(const rocm_bug_t a, const rocm_bug_t b) -> rocm_bug_t {
    auto result = binary_op((*reinterpret_cast<const scalar_t*>(&a)),
    (*reinterpret_cast<const scalar_t*>(&b)));
    return *reinterpret_cast<rocm_bug_t *>(&result);
  };
  thrust::inclusive_scan(
    thrust::cuda::par(allocator).on(c10::cuda::getCurrentCUDAStream()),
    src_data, src_data + size, dst_data,
    rocm_bug_binary_op);
  #else
  int64_t size = self.numel();
  // non synchronizing cub call
  // even though cub is supposed to support tensors with int_max elements, in reality it doesn't,
  // so split at int_max/2
  constexpr int max_cub_size = std::numeric_limits<int>::max() / 2 + 1; // 2**30
  for (int64_t i = 0; i < size; i += max_cub_size) {
    int size_cub = std::min<int64_t>(size - i, max_cub_size);
    Tensor first_elem; // need to save it for all iterations other than first
    if (i > 0) {
      // need to temporarily transform first element of the range we are
      // operating on; self might be multi-d, but we need to index a single
      // element
      auto self_view = at::_unsafe_view(self, -1);
      first_elem = self_view[i].clone();
      transform_vals<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
          self.data_ptr<scalar_t>() + i,
          result.data_ptr<scalar_t>() + i - 1,
          self.data_ptr<scalar_t>() + i,
          binary_op);
      // Surface launch failures here, like every other kernel launch in this file.
      AT_CUDA_CHECK(cudaGetLastError());
    }
    // First call only queries the required temporary storage size.
    size_t temp_storage_bytes = 0;
    AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan(
        nullptr,
        temp_storage_bytes,
        self.data_ptr<scalar_t>() + i,
        result.data_ptr<scalar_t>() + i,
        binary_op,
        size_cub,
        at::cuda::getCurrentCUDAStream()));
    auto temp_storage = at::native::empty_cuda(
        {static_cast<int64_t>(temp_storage_bytes)},
        self.options().dtype(kByte));
    // Second call performs the scan of this chunk.
    AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan(
        temp_storage.data_ptr(),
        temp_storage_bytes,
        self.data_ptr<scalar_t>() + i,
        result.data_ptr<scalar_t>() + i,
        binary_op,
        size_cub,
        at::cuda::getCurrentCUDAStream()));
    if (i > 0) {
      if (self.data_ptr<scalar_t>() != result.data_ptr<scalar_t>()) {
        // restore modified first element only if it's not an inplace operation
        auto self_view = at::_unsafe_view(self, -1);
        self_view[i].copy_(first_elem, /*non_blocking=*/true);
      }
    }
  }
  #endif
}
// Runs an inclusive scan of `self` along `dim` into `result`, picking one of
// three implementations:
//  - the flat thrust/cub scan when `dim` holds every element of the tensor,
//  - the innermost-dimension kernel when `dim` is the last dimension,
//  - the outer-dimension kernel otherwise.
// The implementations operate on contiguous copies; if `result` was not
// contiguous, the computed data is copied back into it afterwards.
template<typename scalar_t, typename BinaryFunction>
void scan_dim(const Tensor& self, Tensor& result,
     int64_t dim, scalar_t init, BinaryFunction binary_op) {
  const bool result_needs_copy = !result.is_contiguous();
  Tensor src = self.contiguous();
  Tensor dst = result.contiguous();

  const bool scan_covers_all_elements = (self.numel() == self.size(dim));
  const int last_dim = self.dim() - 1;

  if (scan_covers_all_elements) {
    scan_thrust_or_cub<scalar_t>(src, dst, init, binary_op);
  } else if (dim == last_dim) {
    scan_innermost_dim<scalar_t>(src, dst, init, binary_op);
  } else {
    scan_outer_dim<scalar_t>(src, dst, dim, init, binary_op);
  }

  if (result_needs_copy) {
    result.copy_(dst);
  }
}
// Computes log(cumsum(exp(self))) along `dim` into `result` and returns
// `result`.
//
// `result` is resized to self's shape. A 0-dim input is copied through
// unchanged and an empty input yields a zero-filled (empty) result; otherwise
// the work is done by scan_dim with a log-add-exp combiner seeded with
// -infinity.
Tensor& _logcumsumexp_out_cuda(Tensor& result, const Tensor& self, int64_t dim) {
  result.resize_(self.sizes());
  if (self.dim() == 0) {
    result.fill_(self);
    return result;
  }
  if (self.numel() == 0) {
    result.zero_();
    return result;
  }
  auto wrap_dim = maybe_wrap_dim(dim, self.dim());

  TensorArg output_arg{ result, "output", 1 };
  TensorArg input_arg{ self, "input", 2 };
  checkAllSameGPU("logcumsumexp", {output_arg, input_arg});

  AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half,
    self.scalar_type(), "logcumsumexp_cuda", [&]() {
      scalar_t init = -std::numeric_limits<scalar_t>::infinity();
      // log(exp(x) + exp(y)), evaluated as max + log1p(exp(min - max)).
      auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x, const scalar_t y) -> scalar_t {
        // std::min / std::max return their first argument when the second is
        // NaN, which would silently drop a NaN in y; propagate it explicitly.
        // (A NaN in x already propagates through the arithmetic below.)
        if (THCNumerics<scalar_t>::isnan(y)) {
          return y;
        }
        const scalar_t lo = std::min(x, y);
        const scalar_t hi = std::max(x, y);
        const scalar_t diff = lo - hi;
        // Two equal infinities give inf - inf = NaN, although the true result
        // is that same infinity (e.g. combining -inf with the -inf seed).
        if (lo == hi && THCNumerics<scalar_t>::isnan(diff)) {
          return x;
        }
        return ::log1p(std::exp(diff)) +
            hi;
      };
      scan_dim<scalar_t>(self, result, wrap_dim, init, log_add_exp);
    });

  return result;
}
// Functional variant: allocates a contiguous output shaped like `self` and
// fills it through _logcumsumexp_out_cuda.
Tensor _logcumsumexp_cuda(const Tensor& self, int64_t dim) {
  Tensor out = at::empty_like(self, MemoryFormat::Contiguous);
  _logcumsumexp_out_cuda(out, self, dim);
  return out;
}
// Cumulative sum of `self` along `dim`, written into `result` (which is
// resized to self's shape) and returned. `result` must be on the same GPU and
// have the same type as `self`.
Tensor& _cumsum_out_cuda(Tensor& result, const Tensor& self, int64_t dim) {
  TensorArg input_arg{self, "input", 2};
  TensorArg output_arg{result, "output", 1};
  checkAllSameGPU("cumsum", {output_arg, input_arg});
  checkSameType("cumsum", output_arg, input_arg);

  result.resize_(self.sizes());

  // Trivial inputs: a 0-dim tensor is its own cumulative sum, and an empty
  // tensor has nothing to scan.
  const bool is_zero_dim = (self.dim() == 0);
  if (is_zero_dim || self.numel() == 0) {
    if (is_zero_dim) {
      result.fill_(self);
    } else {
      result.zero_();
    }
    return result;
  }

  auto scan_axis = maybe_wrap_dim(dim, self.dim());
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
      at::ScalarType::Half, self.scalar_type(), "cumsum_cuda", [&]() {
        // Additive identity as the starting value of the scan.
        scalar_t identity = 0;
        scan_dim<scalar_t>(self, result, scan_axis, identity, std::plus<scalar_t>());
      });

  return result;
}
// Functional variant: allocates a contiguous output shaped like `self` and
// fills it through _cumsum_out_cuda.
Tensor _cumsum_cuda(const Tensor& self, int64_t dim) {
  Tensor out = at::empty_like(self, MemoryFormat::Contiguous);
  _cumsum_out_cuda(out, self, dim);
  return out;
}
// Cumulative product of `self` along `dim`, written into `result` (which is
// resized to self's shape) and returned. `result` must be on the same GPU and
// have the same type as `self`.
Tensor& _cumprod_out_cuda(Tensor& result, const Tensor& self, int64_t dim) {
  TensorArg input_arg{self, "input", 2};
  TensorArg output_arg{result, "output", 1};
  checkAllSameGPU("cumprod", {output_arg, input_arg});
  checkSameType("cumprod", output_arg, input_arg);

  result.resize_(self.sizes());

  // Trivial inputs: a 0-dim tensor is its own cumulative product, and an
  // empty tensor has nothing to scan.
  const bool is_zero_dim = (self.dim() == 0);
  if (is_zero_dim || self.numel() == 0) {
    if (is_zero_dim) {
      result.fill_(self);
    } else {
      result.zero_();
    }
    return result;
  }

  auto scan_axis = maybe_wrap_dim(dim, self.dim());
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(
      at::ScalarType::Half, self.scalar_type(), "cumprod_cuda", [&]() {
        // Multiplicative identity as the starting value of the scan.
        scalar_t identity = 1;
        scan_dim<scalar_t>(self, result, scan_axis, identity, std::multiplies<scalar_t>());
      });

  return result;
}
// Functional variant: allocates a contiguous output shaped like `self` and
// fills it through _cumprod_out_cuda.
Tensor _cumprod_cuda(const Tensor& self, int64_t dim) {
  Tensor out = at::empty_like(self, MemoryFormat::Contiguous);
  _cumprod_out_cuda(out, self, dim);
  return out;
}
}} // namespace at::native