[CUDA] Fix SplitV gridDim.y > 65535 case bug
doxutx committed Aug 28, 2024 · commit d7f6f79 (1 parent: c828a4b)

1 changed file, 36 additions and 1 deletion: source/tnn/device/cuda/acc/cuda_splitv_layer_acc.cu
@@ -19,6 +19,8 @@ namespace TNN_NS {

DECLARE_CUDA_ACC(SplitV, LAYER_SPLITV);

// CUDA caps gridDim.y (and gridDim.z) at 65535; only gridDim.x may be larger.
constexpr int SPLITV_GRID_DIM_Y_MAX = 65535;

template<int THREAD_PER_BLOCK, int ELE_PER_THREAD>
__global__ void splitv_separate_kernel(
const float * __restrict__ src, float * dst,
@@ -44,6 +46,32 @@ __global__ void splitv_separate_kernel(

}

// Handles the case where the logical grid height exceeds the hardware limit
// (real_grid_dim_y > SPLITV_GRID_DIM_Y_MAX): each block strides over the
// remaining rows in steps of GRID_DIM_Y_MAX.
template<int THREAD_PER_BLOCK, int ELE_PER_THREAD>
__global__ void splitv_separate_ylarge_kernel(
        const float * __restrict__ src, float * dst,
        const int inner_size, const int in_stride,
        const int split_start, const int split_end,
        const int real_grid_dim_y, const int GRID_DIM_Y_MAX)
{
    const int block_offset = blockIdx.x * THREAD_PER_BLOCK * ELE_PER_THREAD;
    const int split_size = split_end - split_start;
    const int size = split_size * inner_size;

    for (int block_idx_y = blockIdx.y; block_idx_y < real_grid_dim_y; block_idx_y += GRID_DIM_Y_MAX) {
        const float* src_offsetted = src + (blockIdx.z * real_grid_dim_y + block_idx_y) * in_stride;
        float* dst_offsetted = dst + (blockIdx.z * real_grid_dim_y + block_idx_y) * size;

        #pragma unroll
        for (int i = 0; i < ELE_PER_THREAD; i++) {
            int index = block_offset + i * THREAD_PER_BLOCK + threadIdx.x;
            if (index < size) {
                int input_index = index + split_start * inner_size;
                dst_offsetted[index] = __ldg(src_offsetted + input_index);
            }
        }
    }
}
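
For reference, CUDA limits gridDim.y to 65535, which is why the kernel above strides over the logical rows instead of assuming one block per row. Below is a minimal, self-contained sketch of the same clamp-and-stride pattern; everything in it (copy_rows_kernel, kGridDimYMax, the sizes) is illustrative and not part of this commit.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kGridDimYMax = 65535;  // hardware limit on gridDim.y

// Each block covers logical rows blockIdx.y, blockIdx.y + gridDim.y, ...
// just as splitv_separate_ylarge_kernel does with GRID_DIM_Y_MAX.
__global__ void copy_rows_kernel(const float *src, float *dst,
                                 int cols, int real_rows) {
    for (int row = blockIdx.y; row < real_rows; row += gridDim.y) {
        for (int col = threadIdx.x; col < cols; col += blockDim.x) {
            dst[row * cols + col] = src[row * cols + col];
        }
    }
}

int main() {
    const int rows = 100000;          // > 65535, so the clamp below kicks in
    const int cols = 256;
    const size_t count = size_t(rows) * cols;
    float *src = nullptr, *dst = nullptr;
    cudaMallocManaged(&src, count * sizeof(float));
    cudaMallocManaged(&dst, count * sizeof(float));
    for (size_t i = 0; i < count; i++) src[i] = float(i);

    dim3 grid(1, rows > kGridDimYMax ? kGridDimYMax : rows, 1);
    copy_rows_kernel<<<grid, 128>>>(src, dst, cols, rows);
    cudaDeviceSynchronize();

    // Row 70000 lies past the 65535 boundary; only the stride loop reaches it.
    printf("dst[70000][0] = %.1f (expect %.1f)\n",
           double(dst[size_t(70000) * cols]), double(src[size_t(70000) * cols]));
    cudaFree(src);
    cudaFree(dst);
    return 0;
}

Build with nvcc and run: the spot check reads a row beyond blockIdx.y's range, which a one-block-per-row launch could not have written.
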

Status CudaSplitVLayerAcc::Init(Context *context, LayerParam *param, LayerResource *resource,
const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) {
CudaLayerAcc::Init(context, param, resource, inputs, outputs);
@@ -88,8 +116,15 @@ Status CudaSplitVLayerAcc::Forward(const std::vector<Blob *> &inputs, const std::vector<Blob *> &outputs) {
griddim.z = DimsVectorUtils::Count(dims, 0, min(axis, 1));

float* output_data = static_cast<float*>(output_blob->GetHandle().base);
if (griddim.y <= SPLITV_GRID_DIM_Y_MAX) {
    splitv_separate_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>>
        (input_data, output_data, inner_size, in_stride, split_begin, split_end);
} else {
    // The logical grid height exceeds the hardware limit: clamp the launch
    // grid and pass the real height so the kernel can stride over the rest.
    int real_grid_dim_y = griddim.y;
    griddim.y = SPLITV_GRID_DIM_Y_MAX;
    splitv_separate_ylarge_kernel<THREAD_PER_BLOCK, ELE_PER_THREAD><<<griddim, THREAD_PER_BLOCK, 0, context_->GetStream()>>>
        (input_data, output_data, inner_size, in_stride, split_begin, split_end, real_grid_dim_y, SPLITV_GRID_DIM_Y_MAX);
}
split_begin = split_end;
}
}
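
To make the dispatch arithmetic concrete, here is a hedged sketch of the clamp in isolation; ClampGridY and the example numbers are illustrative, not from the commit.

#include <cuda_runtime.h>

constexpr int SPLITV_GRID_DIM_Y_MAX = 65535;

// Hypothetical helper mirroring the else-branch above: clamp the launch grid
// and return the pre-clamp height that the kernel needs for its stride loop.
int ClampGridY(dim3 &griddim) {
    int real_grid_dim_y = int(griddim.y);
    if (real_grid_dim_y > SPLITV_GRID_DIM_Y_MAX) {
        griddim.y = SPLITV_GRID_DIM_Y_MAX;
    }
    return real_grid_dim_y;
}
// Example: real_grid_dim_y = 150000 clamps to griddim.y = 65535. Since
// 150000 = 2 * 65535 + 18930, blocks with blockIdx.y < 18930 run three
// stride iterations and the remaining blocks run two.
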