warpscan: try float accumulation for bf16 and float16 #4

Open · wants to merge 1 commit into main
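
This PR threads a separate accumulator type `acc_t` through the scan kernel: gates and tokens keep their storage dtype (`weight_t`, i.e. bf16/fp16), while the intermediate scan state (per-thread registers, warp-shuffle operands, and the cross-warp shared-memory carries) is held in `float`. Results are rounded back to the storage dtype on write. In outline, the pattern is the following; this is a minimal single-thread sketch, not the kernel from the diff, and the names here are illustrative:

#include <cuda_bf16.h>

// Sketch of the recurrence y[i] = g[i] * y[i-1] + x[i] with bf16 storage
// and fp32 accumulation. The real kernel computes this hierarchically
// (registers -> warp shuffles -> shared memory); launch this one <<<1, 1>>>.
__global__ void scan_sketch(const __nv_bfloat16* gates,
                            const __nv_bfloat16* tokens,
                            __nv_bfloat16* out, int n) {
    float acc = 0.0f;                          // accumulator in fp32, not bf16
    for (int i = 0; i < n; ++i) {
        float g = __bfloat162float(gates[i]);  // widen on load
        float x = __bfloat162float(tokens[i]);
        acc = g * acc + x;                     // accumulate in fp32
        out[i] = __float2bfloat16(acc);        // round once, on store
    }
}
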
accelerated_scan/warp.cuh (62 changes: 31 additions & 31 deletions)
@@ -4,7 +4,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

-template <typename weight_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
+template <typename weight_t, typename acc_t, int kNStepsPerThread, int kNThreadsPerWarp, int kNWarpsPerBlock, int kNChunksPerSequence>
__global__ void scan(
const weight_t* gates,
const weight_t* tokens,
@@ -13,9 +13,9 @@ __global__ void scan(
const int dim_stride,
const bool reverse
) {
-__shared__ weight_t warpLastGate[kNWarpsPerBlock];
-__shared__ weight_t warpLastToken[kNWarpsPerBlock];
-__shared__ weight_t chunkAccGate, chunkAccToken;
+__shared__ acc_t warpLastGate[kNWarpsPerBlock];
+__shared__ acc_t warpLastToken[kNWarpsPerBlock];
+__shared__ acc_t chunkAccGate, chunkAccToken;

const int seqoffset = blockIdx.x * batch_stride + blockIdx.y * dim_stride;
const int warpId = threadIdx.x / kNThreadsPerWarp;
@@ -24,16 +24,16 @@
constexpr int kBlockLast = kNWarpsPerBlock - 1;
constexpr int kWarpLast = kNThreadsPerWarp - 1;
constexpr int kThreadLast = kNStepsPerThread - 1;
-const weight_t kEmptyGate = 1.0;
-const weight_t kEmptyToken = 0.0;
+const acc_t kEmptyGate = 1.0;
+const acc_t kEmptyToken = 0.0;

//
// Read from global memory.
// Scan sequentially in thread registers (level 0).
//

-weight_t accGate[kNStepsPerThread];
-weight_t accToken[kNStepsPerThread];
+acc_t accGate[kNStepsPerThread];
+acc_t accToken[kNStepsPerThread];

for (int chunk = 0; chunk < kNChunksPerSequence; chunk++) {
const int offset = seqoffset + (reverse ? kNChunksPerSequence - 1 - chunk : chunk) * chunklen;
@@ -45,8 +45,8 @@
#pragma unroll
for (int i = 0; i < kNStepsPerThread; ++i) {
const int chunkOffset = reverse ? chunklen - 1 - (threadIdx.x * kNStepsPerThread + i) : (threadIdx.x * kNStepsPerThread + i);
-weight_t gate = gates[offset + chunkOffset];
-weight_t token = tokens[offset + chunkOffset];
+acc_t gate = gates[offset + chunkOffset];
+acc_t token = tokens[offset + chunkOffset];
if (i == 0) {
if (chunk == 0) {
accGate[0] = threadIdx.x == 0 ? kEmptyGate : gate;
@@ -73,8 +73,8 @@

#pragma unroll
for (int delta = 1; delta < kNThreadsPerWarp; delta *= 2) {
-weight_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
-weight_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);
+acc_t prev_gate = __shfl_up_sync(0xffffffff, accGate[kThreadLast], delta);
+acc_t prev_token = __shfl_up_sync(0xffffffff, accToken[kThreadLast], delta);

if (laneId >= delta) {
#pragma unroll
@@ -103,14 +103,14 @@
//

if (warpId == 0) {
-weight_t warpAccGate, warpAccToken;
+acc_t warpAccGate, warpAccToken;
warpAccGate = (laneId < kNWarpsPerBlock) ? warpLastGate[laneId] : kEmptyGate;
warpAccToken = (laneId < kNWarpsPerBlock) ? warpLastToken[laneId] : kEmptyToken;

#pragma unroll
for (int delta = 1; delta < warpSize; delta *= 2) {
-weight_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
-weight_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);
+acc_t prev_gate = __shfl_up_sync(0xffffffff, warpAccGate, delta);
+acc_t prev_token = __shfl_up_sync(0xffffffff, warpAccToken, delta);

if (laneId >= delta) {
warpAccToken = prev_token * warpAccGate + warpAccToken;
@@ -148,7 +148,7 @@ __global__ void scan(
}
}

-template <typename weight_t, typename torch_weight_t>
+template <typename weight_t, typename torch_weight_t, typename acc_t>
void
warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &out, const bool reverse) {
const auto strides = tokens.strides();
@@ -171,7 +171,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 1;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -180,7 +180,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 1;
constexpr int kNChunksPerSequence = 1;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -189,7 +189,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 4;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -198,7 +198,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 8;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -207,7 +207,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 16;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -216,7 +216,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 16;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -225,7 +225,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 32;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -234,7 +234,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 32;
int kNThreads = seqlen / kNStepsPerThread;
constexpr int kNChunksPerSequence = 1;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -243,7 +243,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 2;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -252,7 +252,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 4;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -261,7 +261,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 8;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -270,7 +270,7 @@ warpscan(const at::Tensor &gates, const at::Tensor &tokens, const at::Tensor &ou
constexpr int kNWarpsPerBlock = 32;
constexpr int kNChunksPerSequence = 16;
int kNThreads = seqlen / kNStepsPerThread / kNChunksPerSequence;
-scan<weight_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
+scan<weight_t, acc_t, kNStepsPerThread, kNThreadsPerWarp, kNWarpsPerBlock, kNChunksPerSequence><<<grid, kNThreads, kNWarpsPerBlock * sizeof(weight_t) * 2, stream>>>(
reinterpret_cast<weight_t*>(gates.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(tokens.data_ptr<torch_weight_t>()), reinterpret_cast<weight_t*>(out.data_ptr<torch_weight_t>()),
batch_stride, dim_stride, reverse
);
@@ -288,13 +288,13 @@ warpscan_forward(const at::Tensor &gates, const at::Tensor &tokens, const at::Te

if (tokens.scalar_type() == at::ScalarType::BFloat16) {
TORCH_CHECK(gates.scalar_type() == at::ScalarType::BFloat16);
-warpscan<__nv_bfloat16, at::BFloat16>(gates, tokens, out, reverse);
+warpscan<__nv_bfloat16, at::BFloat16, float>(gates, tokens, out, reverse);
} else if (tokens.scalar_type() == at::ScalarType::Half) {
TORCH_CHECK(gates.scalar_type() == at::ScalarType::Half);
-warpscan<__half, at::Half>(gates, tokens, out, reverse);
+warpscan<__half, at::Half, float>(gates, tokens, out, reverse);
} else if (tokens.scalar_type() == at::ScalarType::Float) {
TORCH_CHECK(gates.scalar_type() == at::ScalarType::Float);
-warpscan<float, float>(gates, tokens, out, reverse);
+warpscan<float, float, float>(gates, tokens, out, reverse);
} else {
TORCH_CHECK(false && "Unsupported tensor dtype: expecting bfloat16, float16 or float32");
}
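
For intuition on why the accumulator dtype matters, here is a small self-contained host-side demonstration (not part of the diff; the `to_bf16` helper emulates bf16 round-to-nearest-even in software as a stand-in for hardware bf16 storage):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Round a float to the nearest bfloat16 value and return it as a float.
// Standard round-to-nearest-even bit trick; NaN handling omitted.
static float to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits += 0x7fffu + ((bits >> 16) & 1u);  // round on the dropped 16 bits
    bits &= 0xffff0000u;                    // keep sign, exponent, 7-bit mantissa
    std::memcpy(&x, &bits, sizeof(x));
    return x;
}

int main() {
    // Running sum (gate = 1.0, the kernel's kEmptyGate) of 4096 tokens of 1e-3.
    float acc_f32 = 0.0f, acc_bf16 = 0.0f;
    for (int i = 0; i < 4096; ++i) {
        acc_f32 = 1.0f * acc_f32 + 1e-3f;
        acc_bf16 = to_bf16(1.0f * acc_bf16 + to_bf16(1e-3f));  // state rounded every step
    }
    // fp32 reaches ~4.096; the bf16 accumulator stalls near 1.0, where a
    // 1e-3 increment falls below half a bf16 ulp and rounds away.
    std::printf("fp32 accumulation: %f\n", acc_f32);
    std::printf("bf16 accumulation: %f\n", acc_bf16);
}

This is the failure mode the PR targets: with `weight_t` accumulation, long sequences of small tokens stop contributing once the running state grows, while a `float` accumulator only rounds on the final store.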