Skip to content

Commit

Permalink
[fix] fix conv
Browse files Browse the repository at this point in the history
  • Loading branch information
Alcanderian committed Apr 3, 2024
1 parent a5c0ad8 commit 1c97cf2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
12 changes: 11 additions & 1 deletion src/ppl/nn/engines/x86/kernels/onnx/conv1d_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,23 @@ ppl::common::RetCode Conv1dKernel::DoExecute(KernelExecContext* ctx) {
}

#ifdef DUMP_CONV
std::string conv_dump_name = GetName();
if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::SUM) {
conv_dump_name += "-sum";
}
if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU) {
conv_dump_name += "-relu";
}
if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU6) {
conv_dump_name += "-relu6";
}
fprintf(stderr, CASE_STRING_FMT() "\n", cur_executor->conv_param()->group, X_shape.GetDim(0),
cur_executor->conv_param()->channels, X_shape.GetDim(2), X_shape.GetDim(3),
cur_executor->conv_param()->num_output, Y_shape.GetDim(2), Y_shape.GetDim(3),
cur_executor->conv_param()->kernel_h, cur_executor->conv_param()->kernel_w,
cur_executor->conv_param()->stride_h, cur_executor->conv_param()->stride_w,
cur_executor->conv_param()->pad_h, cur_executor->conv_param()->pad_w,
cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, GetName().c_str());
cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, conv_dump_name.c_str());
#endif

PPLNN_X86_REALLOC_TENSOR_BUFFER(Y);
Expand Down
12 changes: 11 additions & 1 deletion src/ppl/nn/engines/x86/kernels/onnx/conv2d_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,23 @@ ppl::common::RetCode Conv2dKernel::DoExecute(KernelExecContext* ctx) {
}

#ifdef DUMP_CONV
std::string conv_dump_name = GetName();
if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::SUM) {
conv_dump_name += "-sum";
}
if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU) {
conv_dump_name += "-relu";
}
if (cur_executor->conv_param()->fuse_flag & kernel::x86::conv_fuse_flag::RELU6) {
conv_dump_name += "-relu6";
}
fprintf(stderr, CASE_STRING_FMT() "\n", cur_executor->conv_param()->group, X->GetShape()->GetDim(0),
cur_executor->conv_param()->channels, X->GetShape()->GetDim(2), X->GetShape()->GetDim(3),
cur_executor->conv_param()->num_output, Y->GetShape()->GetDim(2), Y->GetShape()->GetDim(3),
cur_executor->conv_param()->kernel_h, cur_executor->conv_param()->kernel_w,
cur_executor->conv_param()->stride_h, cur_executor->conv_param()->stride_w,
cur_executor->conv_param()->pad_h, cur_executor->conv_param()->pad_w,
cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, GetName().c_str());
cur_executor->conv_param()->dilation_h - 1, cur_executor->conv_param()->dilation_w - 1, conv_dump_name.c_str());
#endif

PPLNN_X86_REALLOC_TENSOR_BUFFER(Y);
Expand Down
8 changes: 4 additions & 4 deletions src/ppl/nn/engines/x86/kernels/onnx/conv_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,17 @@ ppl::common::RetCode ConvKernel::DoExecute(KernelExecContext* ctx) {
if (kernel_dims == 1) {
return kernel::x86::conv1d_ndarray_fp32(
GetISA(), X->GetShape(), sum_src_shape, Y->GetShape(),
X->GetBufferPtr<float>(), W->GetBufferPtr<float>(),
sum_src_data, b_data, param_->param->group, channels, num_output,
X->GetBufferPtr<float>(), sum_src_data, W->GetBufferPtr<float>(),
b_data, param_->param->group, channels, num_output,
param_->param->kernel_shape[0], param_->param->strides[0],
param_->param->pads[0], param_->param->dilations[0],
param_->fuse_flag, tmp_buffer, Y->GetBufferPtr<float>());
}
if (kernel_dims == 2) {
return kernel::x86::conv2d_ndarray_fp32(
GetISA(), X->GetShape(), sum_src_shape, Y->GetShape(),
X->GetBufferPtr<float>(), W->GetBufferPtr<float>(),
sum_src_data, b_data, param_->param->group, channels, num_output,
X->GetBufferPtr<float>(), sum_src_data, W->GetBufferPtr<float>(),
b_data, param_->param->group, channels, num_output,
param_->param->kernel_shape[0], param_->param->kernel_shape[1],
param_->param->strides[0], param_->param->strides[1],
param_->param->pads[0], param_->param->pads[1],
Expand Down
4 changes: 2 additions & 2 deletions src/ppl/nn/engines/x86/optimizer/ops/onnx/conv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ RetCode ConvOp::DoInit(const OptKernelOptions& options) {
infer_dims_func_ = [this](InputOutputInfo* info) -> RetCode {
auto x = info->GetInput<TensorImpl>(0)->GetShape();
auto w = info->GetInput<TensorImpl>(1)->GetShape();
if (x->GetDim(1) != w->GetDim(1)) {
if (x->GetDim(1) != w->GetDim(1) * param_->group) {
LOG(ERROR) << "input tensor's channels(" << x->GetDim(1) <<
") and weight's channels(" << w->GetDim(1) << ") not match.";
") and weight's channels*group(" << w->GetDim(1) << "*"<< param_->group << ") not match.";
return ppl::common::RC_INVALID_VALUE;
}
return onnx::ReshapeConv(info, param_.get());
Expand Down

0 comments on commit 1c97cf2

Please sign in to comment.