fix moe wna16 kernel
Signed-off-by: Jinzhen Lin <[email protected]>
jinzhen-lin committed Feb 15, 2025
1 parent 31dc339 commit 2c1497b
Showing 3 changed files with 25 additions and 20 deletions.
csrc/moe/moe_wna16.cu (22 additions, 18 deletions)
@@ -52,8 +52,8 @@ __global__ void moe_wna16_gemm_kernel(
 
   int32_t num_valid_tokens = 0;
   extern __shared__ uint16_t block_input_tmp[];
-  scalar_t* block_input = reinterpret_cast<scalar_t*>(&block_input_tmp);
-  scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(&block_input);
+  scalar_t* block_input = reinterpret_cast<scalar_t*>(block_input_tmp);
+  scalar_t2* block_input_half2 = reinterpret_cast<scalar_t2*>(block_input);
 
   // load BLOCK_SIZE_M * BLOCK_SIZE_K into shared memory
   for (int m = 0; m < BLOCK_SIZE_M; m++) {
@@ -101,7 +101,7 @@ __global__ void moe_wna16_gemm_kernel(
   // weight would be loaded in loop
   uint32_t expert_qweight_tmp[4];
   float4* expert_qweight_tmp_float4 =
-      reinterpret_cast<float4*>(&expert_qweight_tmp);
+      reinterpret_cast<float4*>(expert_qweight_tmp);
 
   // load all required scales one time
   scalar_t expert_scales_groups[GROUPS];
@@ -111,48 +111,52 @@
     *expert_scales_groups = expert_scales[scales_offset_tmp];
   } else if constexpr (GROUPS == 2) {
     float* expert_scales_groups_tmp =
-        reinterpret_cast<float*>(&expert_scales_groups);
+        reinterpret_cast<float*>(expert_scales_groups);
     *expert_scales_groups_tmp =
-        reinterpret_cast<float*>(&expert_scales)[scales_offset_tmp];
+        reinterpret_cast<const float*>(expert_scales)[scales_offset_tmp];
   } else if constexpr (GROUPS == 4) {
     float2* expert_scales_groups_tmp =
-        reinterpret_cast<float2*>(&expert_scales_groups);
+        reinterpret_cast<float2*>(expert_scales_groups);
     *expert_scales_groups_tmp =
-        reinterpret_cast<float2*>(&expert_scales)[scales_offset_tmp];
+        reinterpret_cast<const float2*>(expert_scales)[scales_offset_tmp];
   } else if constexpr (GROUPS == 8) {
     float4* expert_scales_groups_tmp =
-        reinterpret_cast<float4*>(&expert_scales_groups);
+        reinterpret_cast<float4*>(expert_scales_groups);
     *expert_scales_groups_tmp =
-        reinterpret_cast<float4*>(&expert_scales)[scales_offset_tmp];
+        reinterpret_cast<const float4*>(expert_scales)[scales_offset_tmp];
   }
 
   // load all required qzeros one time
   uint8_t expert_qzeros_groups[GROUPS];
   if (!has_zp) {
-    qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
+    if constexpr (bit == 4) {
+      qzero_f2 = Dtype::num2num2(Dtype::int2num(8));
+    } else {
+      qzero_f2 = Dtype::num2num2(Dtype::int2num(128));
+    }
   } else {
     int qzeros_offset_tmp = (offset_n / 2) * (size_k / group_size / GROUPS) +
                             offset_k / group_size / GROUPS;
     if constexpr (GROUPS == 1) {
       uint8_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint8_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint8_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<uint16_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint8_t*>(expert_qzeros)[qzeros_offset_tmp];
     } else if constexpr (GROUPS == 2) {
       uint16_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint16_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint16_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<uint16_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint16_t*>(expert_qzeros)[qzeros_offset_tmp];
     } else if constexpr (GROUPS == 4) {
       uint32_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint32_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint32_t*>(expert_qzeros_groups);
       *expert_qzeros_groups_tmp =
-          reinterpret_cast<uint32_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint32_t*>(expert_qzeros)[qzeros_offset_tmp];
    } else if constexpr (GROUPS == 8) {
      uint64_t* expert_qzeros_groups_tmp =
-          reinterpret_cast<uint64_t*>(&expert_qzeros_groups);
+          reinterpret_cast<uint64_t*>(expert_qzeros_groups);
      *expert_qzeros_groups_tmp =
-          reinterpret_cast<uint64_t*>(&expert_qzeros)[qzeros_offset_tmp];
+          reinterpret_cast<const uint64_t*>(expert_qzeros)[qzeros_offset_tmp];
    }
  }
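Note on the recurring cast change in this file: block_input is declared as a scalar_t* pointer, so the old reinterpret_cast<scalar_t2*>(&block_input) reinterpreted the address of that local pointer variable rather than the shared-memory buffer it points to; the same pattern applies to &expert_scales, &expert_qzeros and &res below. A minimal standalone sketch of the difference, using invented names in a separate demo kernel (this is not the vLLM kernel's code):

#include <cstdint>
#include <cuda_fp16.h>

// Illustration only: casting the *value* of a pointer views the buffer it
// points to through a new type; casting the *address* of the pointer variable
// views the pointer object itself, which is different memory.
__global__ void cast_demo_kernel() {
  extern __shared__ uint16_t smem_raw[];  // dynamic shared-memory buffer
  half* smem_as_half = reinterpret_cast<half*>(smem_raw);         // ok: same buffer
  half2* smem_as_half2 = reinterpret_cast<half2*>(smem_as_half);  // ok: same buffer
  // Bug pattern removed by this commit (do not do this):
  //   half2* wrong = reinterpret_cast<half2*>(&smem_as_half);
  // &smem_as_half is the address of the local pointer, not of the shared buffer.
  smem_as_half2[threadIdx.x] = __float2half2_rn(0.0f);
}

For array variables such as block_input_tmp and expert_scales_groups, taking the address yields the same location, so those spots only change the type of the cast; the behavioral fix is where the operand is a pointer.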

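The new branch on bit in the !has_zp path selects the standard midpoint zero point for unsigned b-bit weights, 2^(b-1): 8 for 4-bit and 128 for 8-bit, where the old code always used 8. A compile-time restatement of that choice, with a hypothetical helper name that is not part of the kernel:

// Hypothetical helper mirroring the zero-point constants chosen above.
template <int bit>
constexpr int default_zero_point() {
  static_assert(bit == 4 || bit == 8, "moe_wna16 handles 4-bit and 8-bit weights");
  return 1 << (bit - 1);  // midpoint of the unsigned quantization range
}
static_assert(default_zero_point<4>() == 8, "4-bit default zero point");
static_assert(default_zero_point<8>() == 128, "8-bit default zero point");

Together with the fused_moe.py change below, which passes the actual bit width instead of a hardcoded 4, this lets the 8-bit weight path reach the 128 branch.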
csrc/moe/moe_wna16_utils.h (1 addition, 1 deletion)
@@ -191,7 +191,7 @@ __device__ inline void dequant<nv_bfloat162, 8>(int q, nv_bfloat162* res) {
   fp32_intermediates[2] -= 8388608.f;
   fp32_intermediates[3] -= 8388608.f;
 
-  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&res);
+  uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(res);
   bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
                                    fp32_intermediates_casted[1], 0x7632);
   bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
vllm/model_executor/layers/fused_moe/fused_moe.py (2 additions, 1 deletion)
@@ -694,12 +694,13 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
                 block_size_m=config["BLOCK_SIZE_M"]))
 
         if use_moe_wna16_cuda:
+            bit = 4 if use_int4_w4a16 else 8
             ops.moe_wna16_gemm(A, C, B, B_scale, B_zp,
                                topk_weights if mul_routed_weight else None,
                                sorted_token_ids, expert_ids,
                                num_tokens_post_padded, top_k,
                                config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"],
-                               config["BLOCK_SIZE_K"], 4)
+                               config["BLOCK_SIZE_K"], bit)
             return
 
         fused_moe_kernel_gptq_awq[grid](
