diff --git a/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass.cc b/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass.cc index 7b621a98491dc8..9088924c437e20 100644 --- a/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass.cc +++ b/paddle/fluid/framework/ir/onednn/cpu_bfloat16_pass.cc @@ -28,7 +28,7 @@ class Quanter { void AddQuantOps() { if (IsNotPermittedOpType()) return; - std::vector linked_xputs; + std::unordered_map linked_xputs; for (const auto& logical_xput : op_xputs) { std::vector quant_xput_names; @@ -39,7 +39,12 @@ class Quanter { const auto& physical_xputs_names = logical_xput.second; for (const auto& physical_xput_name : physical_xputs_names) { - if (IsAlreadyLinked(linked_xputs, physical_xput_name)) continue; + // In case the input is repetitively used, where the input should be + // still added. + if (IsAlreadyLinked(linked_xputs, physical_xput_name)) { + quant_xput_names.emplace_back(linked_xputs[physical_xput_name]); + continue; + } VarDesc quant_x_desc( patterns::PDNodeName(get_op_type(), get_op_edge())); @@ -52,7 +57,7 @@ class Quanter { auto physical_xput_node = xputs_map[physical_xput_name]; link_nodes(physical_xput_node, quant_op, quant_x_node); counter++; - linked_xputs.push_back(physical_xput_name); + linked_xputs[physical_xput_name] = xput_name; } set_edge(logical_xput_name, quant_xput_names); @@ -87,10 +92,10 @@ class Quanter { virtual void set_edge(const std::string& logical_xput_name, const std::vector& quant_xput_names) = 0; - bool IsAlreadyLinked(const std::vector& node_names, - const std::string& node_name) const { - return std::find(node_names.begin(), node_names.end(), node_name) != - node_names.end(); + bool IsAlreadyLinked( + const std::unordered_map& node_names_map, + const std::string& node_name) const { + return node_names_map.find(node_name) != node_names_map.end(); } virtual ir::Node* create_quant_op(const std::string& input_name,