diff --git a/src/ppl/nn/engines/llm_cuda/kernels/opmx/rms_norm_kernel.cc b/src/ppl/nn/engines/llm_cuda/kernels/opmx/rms_norm_kernel.cc index df65f2723..6c4e3fdb8 100644 --- a/src/ppl/nn/engines/llm_cuda/kernels/opmx/rms_norm_kernel.cc +++ b/src/ppl/nn/engines/llm_cuda/kernels/opmx/rms_norm_kernel.cc @@ -37,11 +37,11 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) { PPLNN_LLM_CUDA_DEBUG_TRACE("Input [weight]:\n"); PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(weight); - void *skip_in_data = nullptr; + void *skip_in_ptr = nullptr; if (skip_in) { PPLNN_LLM_CUDA_DEBUG_TRACE("Input [skip_in]:\n"); PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(skip_in); - skip_in_data = skip_in->GetBufferPtr(); + skip_in_ptr = skip_in->GetBufferPtr(); } PPLNN_LLM_CUDA_DEBUG_TRACE("eps: %f\n", param_->eps); @@ -52,11 +52,6 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) { auto input_shape = input->GetShape(); - if (param_->skip_term == false) { - LOG(ERROR) << "currently only support skip_term == true."; - return ppl::common::RC_UNSUPPORTED; - } - if (param_->axis != -1 && param_->axis != input_shape->GetDim(input_shape->GetDimCount() - 1)) { LOG(ERROR) << "currently only support axis == -1 or input's last dim."; return ppl::common::RC_UNSUPPORTED; @@ -74,6 +69,7 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) { PPLNN_LLM_CUDA_DEBUG_TRACE("Output [output]:\n"); PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(output); + void *skip_out_ptr = nullptr; if (skip_out) { if (can_trans_skip_in) { skip_out->TransferBufferFrom(skip_in); @@ -82,6 +78,7 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) { } PPLNN_LLM_CUDA_DEBUG_TRACE("Output [skip_out]:\n"); PPLNN_LLM_CUDA_TENSOR_PRINT_DEBUG_MSG(skip_out); + skip_out_ptr = skip_out->GetBufferPtr(); } if (param_->skip_term && !skip_out) { @@ -94,17 +91,22 @@ ppl::common::RetCode RMSNormKernel::DoExecute(KernelExecContext* ctx) { return ppl::common::RC_UNSUPPORTED; } - return ppl::kernel::llm::cuda::pmx::rms_norm( + const int64_t dim_count = input_shape->GetDimCount(); + const int64_t real_axis = param_->axis > 0 ? param_->axis : (param_->axis + dim_count); + + const int64_t batch = input_shape->CalcElementsToDimensionIncludingPadding(real_axis); + const int64_t norm_dim = input_shape->CalcElementsFromDimensionIncludingPadding(real_axis); + + return ppl::kernel::llm::cuda::pmx::rms_norm_fp16( GetStream(), - input_shape, input_data, weight->GetBufferPtr(), - skip_in_data, - param_->axis, + skip_in_ptr, param_->eps, - param_->skip_term, - output->GetBufferPtr(), - skip_out->GetBufferPtr() + batch, + norm_dim, + skip_out_ptr, + output->GetBufferPtr() ); }