diff --git a/src/simplify_reshapes.cpp b/src/simplify_reshapes.cpp index 5c94d4a2652..c642c9ee31a 100644 --- a/src/simplify_reshapes.cpp +++ b/src/simplify_reshapes.cpp @@ -267,7 +267,8 @@ struct find_concat_multibroadcasts { auto matcher() const { - return match::name("concat")(match::all_of[match::inputs()](match::name("multibroadcast"))); + return match::name("concat")( + match::all_of[match::inputs()](match::name("multibroadcast", "broadcast"))); } void apply(module& m, const match::matcher_result& mr) const @@ -287,32 +288,46 @@ struct find_concat_multibroadcasts return; } + // Skip if the broadcasts are different + auto broadcast = concat_inputs.front()->get_operator(); + auto broadcast_value = broadcast.to_value(); + if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), [&](instruction_ref b) { + if(b->name() != broadcast.name()) + return false; + if(broadcast.name() == "broadcast") + return b->get_operator().to_value()["axis"] == broadcast_value["axis"]; + return true; + })) + { + return; + } + // Get the inputs of multibroadcast ops. Will be used as inputs to new concat op - std::vector mb_inputs(concat_inputs.size()); - std::transform(concat_inputs.begin(), concat_inputs.end(), mb_inputs.begin(), [](auto i) { + std::vector inputs(concat_inputs.size()); + std::transform(concat_inputs.begin(), concat_inputs.end(), inputs.begin(), [](auto i) { return i->inputs().front(); }); - // Check that the inputs into the multibroadcasts have the same rank - const auto& first_shape = mb_inputs.front()->get_shape(); - if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto mb_in) { - return mb_in->get_shape().ndim() == first_shape.ndim(); + // Check that the inputs into the broadcasts have the same rank + const auto& first_shape = inputs.front()->get_shape(); + if(not std::all_of(inputs.begin() + 1, inputs.end(), [&](auto input) { + return input->get_shape().ndim() == first_shape.ndim(); })) { return; } // Reduce axis by number of leading broadcasted dimensions - if(mb_inputs.front()->get_shape().lens().size() < concat_out_lens.size()) + if(inputs.front()->get_shape().lens().size() < concat_out_lens.size()) { concat_op.axis -= std::count(front_mb_strides.begin(), front_mb_strides.begin() + concat_op.axis, 0); } - // Inputs to multibroadcasts should have the same dimensions except for the axis to + // Inputs to broadcasts should have the same dimensions except for the axis to // concatenate over - const auto& front_in_lens = mb_inputs.front()->get_shape().lens(); - if(not std::all_of(mb_inputs.begin() + 1, mb_inputs.end(), [&](auto input_to_mb) { + const auto& front_in_lens = inputs.front()->get_shape().lens(); + if(not std::all_of(inputs.begin() + 1, inputs.end(), [&](auto input_to_mb) { const auto& lens = input_to_mb->get_shape().lens(); return std::equal( lens.begin(), lens.begin() + concat_op.axis, front_in_lens.begin()) and @@ -324,10 +339,9 @@ struct find_concat_multibroadcasts return; } - auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, mb_inputs); - m.replace_instruction(concat_ins, - migraphx::make_op("multibroadcast", {{"out_lens", concat_out_lens}}), - new_concat_ins); + auto new_concat_ins = m.insert_instruction(concat_ins, concat_op, inputs); + broadcast.from_value({{"out_lens", concat_ins->get_shape().lens()}}); + m.replace_instruction(concat_ins, broadcast, new_concat_ins); } }; diff --git a/test/simplify_reshapes_test.cpp b/test/simplify_reshapes_test.cpp index ad1b281198f..5e2b0bf1fc8 100644 --- a/test/simplify_reshapes_test.cpp +++ b/test/simplify_reshapes_test.cpp @@ -1074,6 +1074,33 @@ TEST_CASE(concat_multibroadcasts9) EXPECT(m == m_original); } +TEST_CASE(concat_broadcast1) +{ + auto s = migraphx::shape{migraphx::shape::float_type, {1024, 1024}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s); + auto y = m1.add_parameter("y", s); + auto xb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {8, 1024, 1024}}}), x); + auto yb = m1.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {8, 1024, 1024}}}), y); + auto concat = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), xb, yb); + m1.add_return({concat}); + } + migraphx::module m2; + { + auto x = m2.add_parameter("x", s); + auto y = m2.add_parameter("y", s); + auto concat = m2.add_instruction(migraphx::make_op("concat", {{"axis", 1}}), x, y); + auto b = m2.add_instruction( + migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {8, 1024, 2048}}}), concat); + m2.add_return({b}); + } + run_pass(m1); + EXPECT(m1 == m2); +} + TEST_CASE(concat_transpose1) { migraphx::module m;