diff --git a/accelerated_scan/warp.cuh b/accelerated_scan/warp.cuh
index 467b2ed..0818582 100644
--- a/accelerated_scan/warp.cuh
+++ b/accelerated_scan/warp.cuh
@@ -4,7 +4,7 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
-template <typename weight_t, int kNStepsPerThread, int kNWarpsPerBlock, int kNChunksPerSequence>
+template <typename weight_t, typename acc_t, int kNStepsPerThread, int kNWarpsPerBlock, int kNChunksPerSequence>
 __global__ void scan(
     const weight_t* gates,
     const weight_t* tokens,
@@ -13,9 +13,9 @@ __global__ void scan(
     const int dim_stride,
     const bool reverse
 ) {
-    __shared__ weight_t warpLastGate[kNWarpsPerBlock];
-    __shared__ weight_t warpLastToken[kNWarpsPerBlock];
-    __shared__ weight_t chunkAccGate, chunkAccToken;
+    __shared__ acc_t warpLastGate[kNWarpsPerBlock];
+    __shared__ acc_t warpLastToken[kNWarpsPerBlock];
+    __shared__ acc_t chunkAccGate, chunkAccToken;
 
     const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
     const int warpId = threadIdx.x / kNThreadsPerWarp;
@@ -24,16 +24,16 @@ __global__ void scan(
     constexpr int kBlockLast = kNWarpsPerBlock - 1;
     constexpr int kWarpLast = kNThreadsPerWarp - 1;
     constexpr int kThreadLast = kNStepsPerThread - 1;
-    const weight_t kEmptyGate = 1.0;
-    const weight_t kEmptyToken = 0.0;
+    const acc_t kEmptyGate = 1.0;
+    const acc_t kEmptyToken = 0.0;
 
     //
     // Read from global memory.
     // Scan sequentially in thread registers (level 0).
     //
 
-    weight_t accGate[kNStepsPerThread];
-    weight_t accToken[kNStepsPerThread];
+    acc_t accGate[kNStepsPerThread];
+    acc_t accToken[kNStepsPerThread];
 
     for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
         const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
@@ -45,8 +45,8 @@ __global__ void scan(
         #pragma unroll
         for (int i = 0; i < kNStepsPerThread; ++i) {
             const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-            weight_t gate = gates[offset + chunkOffset];
-            weight_t token = tokens[offset + chunkOffset];
+            acc_t gate = gates[offset + chunkOffset];
+            acc_t token = tokens[offset + chunkOffset];
             if (i == 0) {
                 if (chunk == 0) {
                     accGate[0] = threadIdx.x == 0 ? kEmptyGate : gate;
@@ -73,8 +73,8 @@ __global__ void scan(
 
         #pragma unroll
         for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-            weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-            weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+            acc_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
+            acc_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
 
             if (laneId >= delta) {
                 #pragma unroll
@@ -103,14 +103,14 @@ __global__ void scan(
         //
 
         if (warpId == 0) {
-            weight_t warpAccGate, warpAccToken;
+            acc_t warpAccGate, warpAccToken;
             warpAccGate = (laneId < kNWarpsPerBlock) ? warpLastGate[laneId] : kEmptyGate;
             warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;
 
             #pragma unroll
             for (int delta = 1; delta < warpSize; delta *= 2) {
-                weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
-                weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
+                acc_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
+                acc_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
 
                 if (laneId >= delta) {
                     warpAccToken = prev_token * warpAccGate + warpAccToken;
@@ -148,7 +148,7 @@ __global__ void scan(
     }
 }
 
-template <typename weight_t, typename torch_weight_t>
+template <typename weight_t, typename torch_weight_t, typename acc_t>
 void
 warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
     const auto strides = tokens.strides();
@@ -171,7 +171,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 1;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
         );
@@ -180,7 +180,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 1;
         constexpr int kNChunksPerSequence = 1;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -189,7 +189,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 4;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -198,7 +198,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 8;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -207,7 +207,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -216,7 +216,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 16;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -225,7 +225,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -234,7 +234,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 32;
         int kNThreads = seqlen / kNStepsPerThread;
         constexpr int kNChunksPerSequence = 1;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -243,7 +243,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 2;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -252,7 +252,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 4;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -261,7 +261,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 8;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -270,7 +270,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
         constexpr int kNWarpsPerBlock = 32;
         constexpr int kNChunksPerSequence = 16;
         int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-        scan<weight_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
+        scan<weight_t, acc_t, kNStepsPerThread, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads>>>(
             reinterpret_cast<weight_t *>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t *>(out.data_ptr<torch_weight_t>()),
             batch_stride, dim_stride, reverse
        );
@@ -288,13 +288,13 @@ warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Te
 
     if (tokens.scalar_type() == at::ScalarType::BFloat16) {
         TORCH_CHECK(gates.scalar_type() == at::ScalarType::BFloat16);
-        warpscan<__nv_bfloat16, at::BFloat16>(gates, tokens, out, reverse);
+        warpscan<__nv_bfloat16, at::BFloat16, float>(gates, tokens, out, reverse);
     } else if (tokens.scalar_type() == at::ScalarType::Half) {
         TORCH_CHECK(gates.scalar_type() == at::ScalarType::Half);
-        warpscan<__half, at::Half>(gates, tokens, out, reverse);
+        warpscan<__half, at::Half, float>(gates, tokens, out, reverse);
     } else if (tokens.scalar_type() == at::ScalarType::Float) {
         TORCH_CHECK(gates.scalar_type() == at::ScalarType::Float);
-        warpscan<float, float>(gates, tokens, out, reverse);
+        warpscan<float, float, float>(gates, tokens, out, reverse);
     } else {
         TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32");
     }
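In summary, the patch keeps gates, tokens, and outputs in the storage type weight_t (bfloat16, half, or float32) but carries all intermediate scan state — the shared-memory staging buffers, the per-thread register accumulators, and the values exchanged through warp shuffles — in a separate accumulation type acc_t, which warpscan_forward fixes to float for every supported dtype. The sketch below is a minimal illustration of that mixed-precision accumulation pattern applied to a plain sequential first-order recurrence; it is not code from this repository, and the kernel and variable names are invented for the example.

// Minimal sketch (not part of the patch): inputs and outputs stay in the
// storage type weight_t (e.g. __half), while the running recurrence
//     x[t] = gate[t] * x[t-1] + token[t]
// is accumulated in acc_t (float) to limit rounding error over long sequences.
#include <cuda_fp16.h>

template <typename weight_t, typename acc_t>
__global__ void sequential_scan_sketch(const weight_t* gates,
                                       const weight_t* tokens,
                                       weight_t* out,
                                       int seqlen) {
    // One thread per (batch, dim) row; blockIdx.x selects the row.
    const int row = blockIdx.x;
    const weight_t* g = gates + row * seqlen;
    const weight_t* x = tokens + row * seqlen;
    weight_t* y = out + row * seqlen;

    acc_t acc = acc_t(0);                           // state kept in high precision
    for (int t = 0; t < seqlen; ++t) {
        const acc_t gate  = static_cast<acc_t>(g[t]);
        const acc_t token = static_cast<acc_t>(x[t]);
        acc = gate * acc + token;                   // first-order linear recurrence
        y[t] = static_cast<weight_t>(acc);          // cast back to storage type
    }
}

// Example instantiation mirroring warpscan_forward's choice of float
// accumulation for half-precision tensors (hypothetical launch):
//   sequential_scan_sketch<__half, float><<<batch * dim, 1>>>(gates, tokens, out, seqlen);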