From c32a7c7c0c688ed81d2f4ad701a09d0edd095ffe Mon Sep 17 00:00:00 2001 From: shaochangxu <85155497+shaochangxu@users.noreply.github.com> Date: Sat, 11 Jan 2025 13:49:39 +0800 Subject: [PATCH] [Bugfix] fused_experts_impl wrong compute type for float32 (#11921) Signed-off-by: shaochangxu.scx Co-authored-by: shaochangxu.scx --- vllm/model_executor/layers/fused_moe/fused_moe.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1bb6bc753d37c..3ea6217d7c0ef 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -701,8 +701,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, device=hidden_states.device, dtype=hidden_states.dtype) - compute_type = (tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16) + if hidden_states.dtype == torch.bfloat16: + compute_type = tl.bfloat16 + elif hidden_states.dtype == torch.float16: + compute_type = tl.float16 + elif hidden_states.dtype == torch.float32: + compute_type = tl.float32 + else: + raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") if inplace: out_hidden_states = hidden_states