Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ConvInteger: fix parsing for x_zero_point and w_zero_point (#3763) #3781

Open
wants to merge 1 commit into
base: release/rocm-rel-6.3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -141,17 +141,14 @@
return all_zeros;
}

static auto
static migraphx::operation
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
if(qparam->get_shape().scalar())
if(qparam->get_shape().elements() == 1)
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});

Check warning on line 151 in src/onnx/parse_convolution.cpp

View check run for this annotation

Codecov / codecov/patch

src/onnx/parse_convolution.cpp#L151

Added line #L151 was not covered by tests
}

static instruction_ref handle_quant_bias(const operation& op,
Expand All @@ -162,27 +159,37 @@
const instruction_ref& w_zp,
onnx_parser::node_info& info)
{
// to handle the bias, apply the following transformation:
// conv(x-x_zp,w-w_zp) = conv(x,w) - conv(x_zp,w) - conv(x,w_zp) + conv(x_zp,w_zp)
instruction_ref ret = input;

// multibroadcast (or broadcast) zero points according to spec
// x_zp should be a scalar or literal with one element
// w_zp can be either a single element or a 1d tensor with size out_channels
migraphx::operation x_zp_bc =
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}});
migraphx::operation w_zp_bc = qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0);

if(not is_symmetric_zero_point(x_zp))
{
auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto out_zp_1 = info.add_instruction(op, x_zp_mb, weights);
ret = info.add_common_op("sub", ret, out_zp_1);
}

if(not is_symmetric_zero_point(w_zp))
{
auto out_zp_2 = info.add_common_op(op.name(), x, w_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);
auto out_zp_2 = info.add_instruction(op, x, w_zp_mb);
ret = info.add_common_op("sub", ret, out_zp_2);
}

if(not(is_symmetric_zero_point(x_zp)) and not(is_symmetric_zero_point(w_zp)))
{
auto x_zp_bc =
info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp);
auto w_zp_bc = info.add_instruction(
qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);

auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc);
auto out_zp_3 = info.add_instruction(op, x_zp_mb, w_zp_mb);

ret = info.add_common_op("add", ret, out_zp_3);
}
Expand Down
4 changes: 3 additions & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -744,6 +744,8 @@ struct find_inner_broadcast
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
if(ins->get_operator().name() == "layout")
return;
const auto& broadcasts = ins->inputs();
if(broadcasts.empty())
return;
Expand Down
6 changes: 3 additions & 3 deletions test/onnx/convinteger_bias_test.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
strides@@ convinteger_bias_testZ
0





 Z
1





Z
Expand All @@ -23,7 +23,7 @@
b
3





B
34 changes: 34 additions & 0 deletions test/onnx/convinteger_dual_bias_simple_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
 !convinteger_dual_bias_simple_test:à
B
0
1
2
34" ConvInteger*
dilations@@ *
strides@@ !convinteger_dual_bias_simple_testZ
0




Z
1




Z
2


Z
3


b
4




B
20 changes: 11 additions & 9 deletions test/onnx/convinteger_dual_bias_test.onnx
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@ B
strides@@ convinteger_dual_bias_testZ
0





Z



Z
1





Z

Z
2


Expand All @@ -28,7 +30,7 @@ B
b
4





B

B
23 changes: 20 additions & 3 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,10 +1655,10 @@ def convinteger_no_bias_uint8_test():

@onnx_test()
def convinteger_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5])
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 5, 5])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [1, 2, 28, 28])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [2, 4, 28, 28])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2'],
Expand All @@ -1671,6 +1671,23 @@ def convinteger_bias_test():

@onnx_test()
def convinteger_dual_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 10, 10])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 3, 3])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
w = helper.make_tensor_value_info('3', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('4', TensorProto.INT32, [2, 4, 8, 8])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2', '3'],
outputs=['4'],
dilations=[1, 1],
strides=[1, 1])

return ([node], [x, y, z, w], [out])


@onnx_test()
def convinteger_dual_bias_simple_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 5, 5])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 2, 2])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
Expand Down
14 changes: 5 additions & 9 deletions test/onnx/parse/convinteger_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,24 +28,20 @@ TEST_CASE(convinteger_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 32, 32}});
auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 5, 5}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});

mm->add_literal(migraphx::literal{migraphx::shape{data->get_shape().type(), {1}, {0}}, {0}});
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weights);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weights->get_shape().lens()}}),
data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant2 =
mm->add_instruction(migraphx::make_op("quant_convolution"), bcast_data_bias, weights);

auto bcast_quant2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant2);

mm->add_instruction(migraphx::make_op("sub"), quant, bcast_quant2);
mm->add_instruction(migraphx::make_op("sub"), quant, quant2);

auto prog = optimize_onnx("convinteger_bias_test.onnx");
EXPECT(p == prog);
Expand Down
32 changes: 14 additions & 18 deletions test/onnx/parse/convinteger_dual_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -28,41 +28,37 @@ TEST_CASE(convinteger_dual_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 2, 2}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 10, 10}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 3, 3}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});
auto weight_bias = mm->add_parameter("3", {migraphx::shape::int8_type, {1}, {1}});

auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weight);

auto mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant_db_w =
auto quant_mb_w =
mm->add_instruction(migraphx::make_op("quant_convolution"), mbcast_data_bias, weight);

auto quant_mb_w = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_db_w);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_mb_w);

auto mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), weight_bias);
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);

auto quant_d_wb =
auto quant_md_wb =
mm->add_instruction(migraphx::make_op("quant_convolution"), data, mbcast_weight_bias);

auto quant_md_wb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_d_wb);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_md_wb);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
auto bcast_weight_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", weight->get_shape().lens()}}), weight_bias);
mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);
auto bias_quant = mm->add_instruction(
migraphx::make_op("quant_convolution"), bcast_data_bias, bcast_weight_bias);
migraphx::make_op("quant_convolution"), mbcast_data_bias, mbcast_weight_bias);

mm->add_instruction(migraphx::make_op("add"), quant, bias_quant);

Expand Down
Loading
Loading