From 1c97cf271101fca60f6faaf2467d671940dea674 Mon Sep 17 00:00:00 2001 From: alcanderian Date: Wed, 3 Apr 2024 21:41:51 +0800 Subject: [PATCH] [fix] fix conv --- src/ppl/nn/engines/x86/kernels/onnx/conv1d_kernel.cc | 12 +++++++++++- src/ppl/nn/engines/x86/kernels/onnx/conv2d_kernel.cc | 12 +++++++++++- src/ppl/nn/engines/x86/kernels/onnx/conv_kernel.cc | 8 ++++---- src/ppl/nn/engines/x86/optimizer/ops/onnx/conv_op.cc | 4 ++-- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/ppl/nn/engines/x86/kernels/onnx/conv1d_kernel.cc b/src/ppl/nn/engines/x86/kernels/onnx/conv1d_kernel.cc index 7ed051bc4..2928986ec 100644 --- a/src/ppl/nn/engines/x86/kernels/onnx/conv1d_kernel.cc +++ b/src/ppl/nn/engines/x86/kernels/onnx/conv1d_kernel.cc @@ -84,13 +84,23 @@ ppl::common::RetCode Conv1dKernel::DoExecute(KernelExecContext* ctx) { } #ifdef DUMP_CONV + std::string conv_dump_name = GetName(); + if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::SUM) { + conv_dump_name += "-sum"; + } + if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU) { + conv_dump_name += "-relu"; + } + if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU6) { + conv_dump_name += "-relu6"; + } fprintf(stderr, CASE_STRING_FMT() "\n", cur_executor->conv_param()->group, X_shape.GetDim(0), cur_executor->conv_param()->channels, X_shape.GetDim(2), X_shape.GetDim(3), cur_executor->conv_param()->num_output, Y_shape.GetDim(2), Y_shape.GetDim(3), cur_executor->conv_param()->kernel_h, cur_executor->conv_param()->kernel_w, cur_executor->conv_param()->stride_h, cur_executor->conv_param()->stride_w, cur_executor->conv_param()->pad_h, cur_executor->conv_param()->pad_w, - cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, GetName().c_str()); + cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, conv_dump_name.c_str()); #endif PPLNN_X86_REALLOC_TENSOR_BUFFER(Y); diff --git a/src/ppl/nn/engines/x86/kernels/onnx/conv2d_kernel.cc b/src/ppl/nn/engines/x86/kernels/onnx/conv2d_kernel.cc index 23f64de1f..d5ab99b5f 100644 --- a/src/ppl/nn/engines/x86/kernels/onnx/conv2d_kernel.cc +++ b/src/ppl/nn/engines/x86/kernels/onnx/conv2d_kernel.cc @@ -79,13 +79,23 @@ ppl::common::RetCode Conv2dKernel::DoExecute(KernelExecContext* ctx) { } #ifdef DUMP_CONV + std::string conv_dump_name = GetName(); + if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::SUM) { + conv_dump_name += "-sum"; + } + if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU) { + conv_dump_name += "-relu"; + } + if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU6) { + conv_dump_name += "-relu6"; + } fprintf(stderr, CASE_STRING_FMT() "\n", cur_executor->conv_param()->group, X->GetShape()->GetDim(0), cur_executor->conv_param()->channels, X->GetShape()->GetDim(2), X->GetShape()->GetDim(3), cur_executor->conv_param()->num_output, Y->GetShape()->GetDim(2), Y->GetShape()->GetDim(3), cur_executor->conv_param()->kernel_h, cur_executor->conv_param()->kernel_w, cur_executor->conv_param()->stride_h, cur_executor->conv_param()->stride_w, cur_executor->conv_param()->pad_h, cur_executor->conv_param()->pad_w, - cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, GetName().c_str()); + cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, conv_dump_name.c_str()); #endif PPLNN_X86_REALLOC_TENSOR_BUFFER(Y); diff --git a/src/ppl/nn/engines/x86/kernels/onnx/conv_kernel.cc b/src/ppl/nn/engines/x86/kernels/onnx/conv_kernel.cc index b761019ae..bd3a38f95 100644 --- a/src/ppl/nn/engines/x86/kernels/onnx/conv_kernel.cc +++ b/src/ppl/nn/engines/x86/kernels/onnx/conv_kernel.cc @@ -137,8 +137,8 @@ ppl::common::RetCode ConvKernel::DoExecute(KernelExecContext* ctx) { if (kernel_dims == 1) { return kernel::x86::conv1d_ndarray_fp32( GetISA(), X->GetShape(), sum_src_shape, Y->GetShape(), - X->GetBufferPtr(), W->GetBufferPtr(), - sum_src_data, b_data, param_->param->group, channels, num_output, + X->GetBufferPtr(), sum_src_data, W->GetBufferPtr(), + b_data, param_->param->group, channels, num_output, param_->param->kernel_shape[0], param_->param->strides[0], param_->param->pads[0], param_->param->dilations[0], param_->fuse_flag, tmp_buffer, Y->GetBufferPtr()); @@ -146,8 +146,8 @@ ppl::common::RetCode ConvKernel::DoExecute(KernelExecContext* ctx) { if (kernel_dims == 2) { return kernel::x86::conv2d_ndarray_fp32( GetISA(), X->GetShape(), sum_src_shape, Y->GetShape(), - X->GetBufferPtr(), W->GetBufferPtr(), - sum_src_data, b_data, param_->param->group, channels, num_output, + X->GetBufferPtr(), sum_src_data, W->GetBufferPtr(), + b_data, param_->param->group, channels, num_output, param_->param->kernel_shape[0], param_->param->kernel_shape[1], param_->param->strides[0], param_->param->strides[1], param_->param->pads[0], param_->param->pads[1], diff --git a/src/ppl/nn/engines/x86/optimizer/ops/onnx/conv_op.cc b/src/ppl/nn/engines/x86/optimizer/ops/onnx/conv_op.cc index 8bb337f32..e28ac17e7 100644 --- a/src/ppl/nn/engines/x86/optimizer/ops/onnx/conv_op.cc +++ b/src/ppl/nn/engines/x86/optimizer/ops/onnx/conv_op.cc @@ -66,9 +66,9 @@ RetCode ConvOp::DoInit(const OptKernelOptions& options) { infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode { auto x = info->GetInput(0)->GetShape(); auto w = info->GetInput(1)->GetShape(); - if (x->GetDim(1) != w->GetDim(1)) { + if (x->GetDim(1) != w->GetDim(1) * param_->group) { LOG(ERROR) << "input tensor's channels(" << x->GetDim(1) << - ") and weight's channels(" << w->GetDim(1) << ") not match."; + ") and weight's channels*group(" << w->GetDim(1) << "*"<< param_->group << ") not match."; return ppl::common::RC_INVALID_VALUE; } return onnx::ReshapeConv(info, param_.get());