diff --git a/src/ppl/nn/auxtools/to_graphviz.h b/src/ppl/nn/auxtools/to_graphviz.h
index a1712522c..5833a460c 100644
--- a/src/ppl/nn/auxtools/to_graphviz.h
+++ b/src/ppl/nn/auxtools/to_graphviz.h
@@ -27,7 +27,7 @@ namespace ppl { namespace nn { namespace utils {
 
 static string GenNodeIdStr(const ir::Node* node) {
     auto& type = node->GetType();
-    return node->GetName() + "[" + type.domain + ":" + type.name + ":" + ToString(type.version) + "]";
+    return node->GetName() + string("[") + type.domain + ":" + type.name + ":" + ToString(type.version) + "]";
 }
 
 static string ToGraphviz(const ir::GraphTopo* topo) {
diff --git a/src/ppl/nn/engines/arm/optimizer/ops/onnx/conv_op.cc b/src/ppl/nn/engines/arm/optimizer/ops/onnx/conv_op.cc
index bcaf8caca..5144ec2e6 100644
--- a/src/ppl/nn/engines/arm/optimizer/ops/onnx/conv_op.cc
+++ b/src/ppl/nn/engines/arm/optimizer/ops/onnx/conv_op.cc
@@ -162,7 +162,7 @@ ppl::common::RetCode ConvOp::SelectAlgorithm(const InputOutputInfo& info, const
     conv2d_param_->mgr = conv2d_algo_selector::fast_gen_algo(
         *info.GetInput<TensorImpl>(0)->GetShape(), options.engine_options->forward_precision,
-        options.engine_options->dynamic_tuning_level, options.engine_options->winograd_level, 
+        options.engine_options->dynamic_tuning_level, options.engine_options->winograd_level,
         options.device->GetISA(), conv2d_param_->param, options.device->GetAllocator());
 
     if (conv2d_param_->mgr == nullptr) {
@@ -183,7 +183,7 @@ ppl::common::RetCode ConvOp::SelectAlgorithm(const InputOutputInfo& info, const
         // Note: If the filter is reused, generate a new input edge to hold cvt_filter which may differ due to different algo/sched_param.
         ppl::nn::TensorBufferInfo * new_filter = &options.info->constants[node->GetInput(1)];
         if (new_filter && new_filter->IsBufferOwner() && new_filter->GetBufferPtr<void>()) {
-            auto edge = options.graph_topo->AddEdge(node->GetName() + "_Input_Cvt_Filter").first;
+            auto edge = options.graph_topo->AddEdge(node->GetName() + std::string("_Input_Cvt_Filter")).first;
             edge->AddConsumer(node->GetId());
 
             TensorImpl* tensor = new TensorImpl(edge, TENSORTYPE_RESERVED);
@@ -210,7 +210,7 @@ ppl::common::RetCode ConvOp::SelectAlgorithm(const InputOutputInfo& info, const
         if (normal_cvt_weights_ret != ppl::common::RC_SUCCESS) {
             return normal_cvt_weights_ret;
         }
-        
+
         if (new_bias && new_bias->IsBufferOwner() && new_bias->GetBufferPtr()) {
             bias_data = nullptr;
         } else if (bias_data && new_bias) {
@@ -350,13 +350,13 @@ ppl::common::RetCode ConvOp::SerializeData(const ::ppl::nn::pmx::SerializationCo
                                                              mgr->algo_info().data_type,
                                                              mgr->algo_info().isa,
                                                              conv_builder.CreateVector<int64_t>(algo_sp));
-    auto fb_param_info = ppl::nn::pmx::arm::CreateConvParamInfo(conv_builder, 
-                                                                conv2d_param_->param.num_output, 
-                                                                conv2d_param_->param.channels, 
-                                                                conv2d_param_->mgr->get_param().pad_type, 
+    auto fb_param_info = ppl::nn::pmx::arm::CreateConvParamInfo(conv_builder,
+                                                                conv2d_param_->param.num_output,
+                                                                conv2d_param_->param.channels,
+                                                                conv2d_param_->mgr->get_param().pad_type,
                                                                 conv2d_param_->mgr->get_param().fuse_flag,
                                                                 (!mgr->is_zero_bias()) ? 1 : 0);
-    
+
     auto fb_conv_data = ppl::nn::pmx::arm::CreateConvData(conv_builder, fb_algo_info, fb_param_info);
     auto fb_op_data = ppl::nn::pmx::arm::CreateOpData(conv_builder, ppl::nn::pmx::arm::PrivateDataType_ConvData, fb_conv_data.Union());
     ppl::nn::pmx::arm::FinishOpDataBuffer(conv_builder, fb_op_data);
@@ -415,11 +415,11 @@ ppl::common::RetCode ConvOp::DeserializeData(const ::ppl::nn::pmx::Deserializati
                                     .isa = algo_info->isa(),
                                     .data_type = algo_info->dtype()
                                 });
-    
+
     std::vector<int64_t> sp;
     ppl::nn::pmx::utils::Fbvec2Stdvec(algo_info->sched_param(), &sp);
     mgr->set_schedule_param(sp);
-    
+
     const auto & shapes = *ctx.shapes;
     const auto & constants = *ctx.constants;
diff --git a/src/ppl/nn/engines/arm/optimizer/opt_graph.cc b/src/ppl/nn/engines/arm/optimizer/opt_graph.cc
index 69bd7083f..0f9ddf363 100644
--- a/src/ppl/nn/engines/arm/optimizer/opt_graph.cc
+++ b/src/ppl/nn/engines/arm/optimizer/opt_graph.cc
@@ -150,11 +150,11 @@ RetCode OptGraph::AddReorderOp(const OptKernelOptions& options, const edgeid_t&
 
     std::string reorder_node_name = "";
     if (reorder_type == REORDER_INPUT) {
-        reorder_node_name = "ReorderInput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderInput_") + edge->GetName() + "_of_" + node->GetName();
     } else if (reorder_type == REORDER_OUTPUT) {
-        reorder_node_name = "ReorderOutput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderOutput_") + edge->GetName() + "_of_" + node->GetName();
     } else if (reorder_type == REORDER_EXTRA_INPUT) {
-        reorder_node_name = "ReorderExtraInput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderExtraInput_") + edge->GetName() + "_of_" + node->GetName();
     }
 
     auto node_ret_pair = graph_->topo->AddNode(reorder_node_name);
diff --git a/src/ppl/nn/engines/arm/optimizer/rules/fuse_channel_shuffle.cc b/src/ppl/nn/engines/arm/optimizer/rules/fuse_channel_shuffle.cc
index 6ab72d425..f60f54087 100644
--- a/src/ppl/nn/engines/arm/optimizer/rules/fuse_channel_shuffle.cc
+++ b/src/ppl/nn/engines/arm/optimizer/rules/fuse_channel_shuffle.cc
@@ -264,10 +264,10 @@ bool FuseChannelShuffleRule::Apply(const OptKernelOptions& options) {
         /******************** do optimize ***********************/
         /** 1. create & register fused op **/
         std::string channel_shuffle_node_name = "ChannelShuffle_" +
-            (fuse_concat ? (reshape1_prev_node->GetName() + "_") : "") + reshape1_node->GetName() + "_" +
+            (fuse_concat ? (reshape1_prev_node->GetName() + std::string("_")) : "") + reshape1_node->GetName() + "_" +
             trans_node->GetName() + "_" + reshape2_node->GetName() +
-            (fuse_split ? ("_" + reshape2_next_nodes[0]->GetName()) : "") +
-            (fuse_slice ? ("_" + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
+            (fuse_split ? (std::string("_") + reshape2_next_nodes[0]->GetName()) : "") +
+            (fuse_slice ? (std::string("_") + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
         auto node_ret_pair = graph_topo->AddNode(channel_shuffle_node_name);
         if (!node_ret_pair.second) {
             LOG(ERROR) << "node[" << channel_shuffle_node_name << "] already exists.";
diff --git a/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_depthwise_int8.cc b/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_depthwise_int8.cc
index 443ac45de..64bc4a77c 100644
--- a/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_depthwise_int8.cc
+++ b/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_depthwise_int8.cc
@@ -195,7 +195,7 @@ RetCode DepthwiseDirectInt8::ModifyParam(ir::Node* node, OptKernelOptions& optio
         quant_constat_info.SetBuffer(buffer, options.device, true);
     }
 
-    auto ret_pair = topo->AddEdge("Quant_" + node->GetName());
+    auto ret_pair = topo->AddEdge(string("Quant_") + node->GetName());
     auto quant_edge = ret_pair.first;
     auto quant_edge_id = quant_edge->GetId();
     node->AddInput(quant_edge_id);
diff --git a/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_imma.cc b/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_imma.cc
index 28979b75e..0a1e8a9f1 100644
--- a/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_imma.cc
+++ b/src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_imma.cc
@@ -224,7 +224,7 @@ RetCode TuringIMMAImpgemm::ModifyParam(ir::Node* node, OptKernelOptions& options
         quant_constat_info.SetBuffer(buffer, options.device, true);
     }
 
-    auto ret_pair = topo->AddEdge("Quant_" + node->GetName());
+    auto ret_pair = topo->AddEdge(string("Quant_") + node->GetName());
     auto quant_edge = ret_pair.first;
     auto quant_edge_id = quant_edge->GetId();
     node->AddInput(quant_edge_id);
diff --git a/src/ppl/nn/engines/cuda/optimizer/algos/algo_gemm.cc b/src/ppl/nn/engines/cuda/optimizer/algos/algo_gemm.cc
index 22f6d707d..14f346442 100644
--- a/src/ppl/nn/engines/cuda/optimizer/algos/algo_gemm.cc
+++ b/src/ppl/nn/engines/cuda/optimizer/algos/algo_gemm.cc
@@ -253,7 +253,7 @@ RetCode GemmAlgorithm::ModifyParam(ir::Node* node, OptKernelOptions& options) {
         quant_constat_info.SetBuffer(buffer, options.device, true);
     }
 
-    auto ret_pair = topo->AddEdge("Quant_" + node->GetName());
+    auto ret_pair = topo->AddEdge(string("Quant_") + node->GetName());
     auto quant_edge = ret_pair.first;
     auto quant_edge_id = quant_edge->GetId();
     node->AddInput(quant_edge_id);
diff --git a/src/ppl/nn/engines/cuda/optimizer/ops/pmx/bridge_op.cc b/src/ppl/nn/engines/cuda/optimizer/ops/pmx/bridge_op.cc
index cbaa62ff2..39b32a0fa 100644
--- a/src/ppl/nn/engines/cuda/optimizer/ops/pmx/bridge_op.cc
+++ b/src/ppl/nn/engines/cuda/optimizer/ops/pmx/bridge_op.cc
@@ -70,7 +70,7 @@ KernelImpl* BridgeOp::CreateKernelImpl() const {
 
 RetCode BridgeOp::AddInternalBridgeNode(ir::Node* node, ir::Node* new_node, ir::Edge* edge, ir::Graph* graph) {
     auto topo = graph->topo.get();
-    auto ret_pair = topo->AddEdge("Bridge_Edge_" + edge->GetName() + "_" + node->GetName());
+    auto ret_pair = topo->AddEdge(string("Bridge_Edge_") + edge->GetName() + "_" + node->GetName());
     auto new_edge = ret_pair.first;
 
     edge->DelConsumer(node->GetId());
@@ -90,7 +90,7 @@ RetCode BridgeOp::AddInternalBridgeNode(ir::Node* node, ir::Node* new_node, ir::
 
 RetCode BridgeOp::AddFinalBridgeNode(ir::Node* node, ir::Node* new_node, ir::Edge* edge, ir::Graph* graph) {
     auto topo = graph->topo.get();
-    auto ret_pair = topo->AddEdge("Bridge_Final_Edge_" + edge->GetName() + "_" + node->GetName());
+    auto ret_pair = topo->AddEdge(string("Bridge_Final_Edge_") + edge->GetName() + "_" + node->GetName());
     auto new_edge = ret_pair.first;
 
     edge->SetProducer(new_node->GetId());
diff --git a/src/ppl/nn/engines/cuda/optimizer/opt_graph.cc b/src/ppl/nn/engines/cuda/optimizer/opt_graph.cc
index fc7a1c430..6fe2211e9 100644
--- a/src/ppl/nn/engines/cuda/optimizer/opt_graph.cc
+++ b/src/ppl/nn/engines/cuda/optimizer/opt_graph.cc
@@ -241,12 +241,12 @@ RetCode OptGraph::AddBridgeKernels(const utils::SharedResource& resource) {
                 continue;
             }
             auto edge = topo->GetEdge(edge_id);
-            if (edge->GetName().find("Bridge_Edge") != string::npos) {
+            if (string(edge->GetName()).find("Bridge_Edge") != string::npos) {
                 continue;
            }
 
             auto creator = OptKernelCreatorManager::GetInstance()->Find("pmx", "Bridge", 1);
-            auto ret_pair = topo->AddNode("Bridge_Node_" + node->GetName() + "_" + edge->GetName());
+            auto ret_pair = topo->AddNode(string("Bridge_Node_") + node->GetName() + "_" + edge->GetName());
             if (!ret_pair.second) {
                 LOG(ERROR) << "create a new node for [" << edge->GetName() << "] failed.";
                 return RC_OUT_OF_MEMORY;
@@ -280,7 +280,7 @@ RetCode OptGraph::AddBridgeKernels(const utils::SharedResource& resource) {
                 edge->CalcConsumerCount() == 0) { // it is an finel node for the graph
                 auto creator = OptKernelCreatorManager::GetInstance()->Find("pmx", "Bridge", 1);
 
-                auto ret_pair = topo->AddNode("Bridge_Final_" + node->GetName() + "_" + edge->GetName());
+                auto ret_pair = topo->AddNode(string("Bridge_Final_") + node->GetName() + "_" + edge->GetName());
                 if (!ret_pair.second) {
                     LOG(ERROR) << "create a new node for [" << edge->GetName() << "] failed.";
                     return RC_OUT_OF_MEMORY;
diff --git a/src/ppl/nn/engines/riscv/optimizer/rules/fuse_channel_shuffle.cc b/src/ppl/nn/engines/riscv/optimizer/rules/fuse_channel_shuffle.cc
index b46e7515f..ca1ec2aaa 100644
--- a/src/ppl/nn/engines/riscv/optimizer/rules/fuse_channel_shuffle.cc
+++ b/src/ppl/nn/engines/riscv/optimizer/rules/fuse_channel_shuffle.cc
@@ -273,10 +273,10 @@ bool FuseChannelShuffle(const OptKernelOptions& options) {
         /******************** do optimize ***********************/
         /** 1. create & register fused op **/
         std::string channel_shuffle_node_name = "ChannelShuffle_" +
-            (fuse_concat ? (reshape1_prev_node->GetName() + "_") : "") + reshape1_node->GetName() + "_" +
+            (fuse_concat ? (reshape1_prev_node->GetName() + std::string("_")) : "") + reshape1_node->GetName() + "_" +
             trans_node->GetName() + "_" + reshape2_node->GetName() +
-            (fuse_split ? ("_" + reshape2_next_nodes[0]->GetName()) : "") +
-            (fuse_slice ? ("_" + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
+            (fuse_split ? (std::string("_") + reshape2_next_nodes[0]->GetName()) : "") +
+            (fuse_slice ? (std::string("_") + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
         auto node_ret_pair = graph_topo->AddNode(channel_shuffle_node_name);
         if (!node_ret_pair.second) {
             LOG(ERROR) << "node[" << channel_shuffle_node_name << "] already exists.";
diff --git a/src/ppl/nn/engines/riscv/optimizer/rules/layout_optimize.cc b/src/ppl/nn/engines/riscv/optimizer/rules/layout_optimize.cc
index 7a3a248be..f922f3a2b 100644
--- a/src/ppl/nn/engines/riscv/optimizer/rules/layout_optimize.cc
+++ b/src/ppl/nn/engines/riscv/optimizer/rules/layout_optimize.cc
@@ -39,11 +39,11 @@ static ppl::common::RetCode AddReorderOp(const OptKernelOptions& options, const
 
     std::string reorder_node_name = "";
     if (reorder_type == REORDER_INPUT) {
-        reorder_node_name = "ReorderInput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderInput_") + edge->GetName() + "_of_" + node->GetName();
     } else if (reorder_type == REORDER_OUTPUT) {
-        reorder_node_name = "ReorderOutput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderOutput_") + edge->GetName() + "_of_" + node->GetName();
     } else if (reorder_type == REORDER_EXTRA_INPUT) {
-        reorder_node_name = "ReorderExtraInput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderExtraInput_") + edge->GetName() + "_of_" + node->GetName();
     }
 
     auto node_ret_pair = graph_topo->AddNode(reorder_node_name);
diff --git a/src/ppl/nn/engines/x86/optimizer/rules/fuse_channel_shuffle.cc b/src/ppl/nn/engines/x86/optimizer/rules/fuse_channel_shuffle.cc
index aade18be0..37e15e8c3 100644
--- a/src/ppl/nn/engines/x86/optimizer/rules/fuse_channel_shuffle.cc
+++ b/src/ppl/nn/engines/x86/optimizer/rules/fuse_channel_shuffle.cc
@@ -264,10 +264,10 @@ bool FuseChannelShuffle(const OptKernelOptions& options) {
         /******************** do optimize ***********************/
         /** 1. create & register fused op **/
         std::string channel_shuffle_node_name = "ChannelShuffle_" +
-            (fuse_concat ? (reshape1_prev_node->GetName() + "_") : "") + reshape1_node->GetName() + "_" +
+            (fuse_concat ? (reshape1_prev_node->GetName() + std::string("_")) : "") + reshape1_node->GetName() + "_" +
             trans_node->GetName() + "_" + reshape2_node->GetName() +
-            (fuse_split ? ("_" + reshape2_next_nodes[0]->GetName()) : "") +
-            (fuse_slice ? ("_" + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
+            (fuse_split ? (std::string("_") + reshape2_next_nodes[0]->GetName()) : "") +
+            (fuse_slice ? (std::string("_") + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
         auto node_ret_pair = graph_topo->AddNode(channel_shuffle_node_name);
         if (!node_ret_pair.second) {
             LOG(ERROR) << "node[" << channel_shuffle_node_name << "] already exists.";
diff --git a/src/ppl/nn/engines/x86/optimizer/rules/fuse_conv_depthwise.cc b/src/ppl/nn/engines/x86/optimizer/rules/fuse_conv_depthwise.cc
index b64f31943..8512b91ff 100644
--- a/src/ppl/nn/engines/x86/optimizer/rules/fuse_conv_depthwise.cc
+++ b/src/ppl/nn/engines/x86/optimizer/rules/fuse_conv_depthwise.cc
@@ -54,7 +54,7 @@ bool FuseConvDepthwise(const OptKernelOptions &options) {
             }
 
             const std::string pd_conv2d_node_name =
-                "PostDepthwiseConv_" + conv_node->GetName() + "_" + next_node->GetName();
+                std::string("PostDepthwiseConv_") + conv_node->GetName() + "_" + next_node->GetName();
             const ir::Node::Type type("pmx", "PostDepthwiseConv", 1);
 
             // add node to graph topo
@@ -139,4 +139,3 @@ bool FuseConvDepthwise(const OptKernelOptions &options) {
 }
 
 }}} // namespace ppl::nn::x86
-
diff --git a/src/ppl/nn/engines/x86/optimizer/rules/fuse_swish.cc b/src/ppl/nn/engines/x86/optimizer/rules/fuse_swish.cc
index 9f3dfa172..825e59fa7 100644
--- a/src/ppl/nn/engines/x86/optimizer/rules/fuse_swish.cc
+++ b/src/ppl/nn/engines/x86/optimizer/rules/fuse_swish.cc
@@ -71,7 +71,7 @@ bool FuseSwish(const OptKernelOptions &options) {
             // sigmoid_input_edge(last_mul_outer_input_edge) ----> sigmoid_node ----> sigmoid_output_edge(last_mul_inner_input_edge) ----> last_mul_node ----> last_mul_output_edge
             //                             |-------------------------------------------------------------------------->|
             const std::string swish_node_name =
-                "Fused_Swish_" + sigmoid_node->GetName() + "_" + last_mul_node->GetName();
+                std::string("Fused_Swish_") + sigmoid_node->GetName() + "_" + last_mul_node->GetName();
             const ir::Node::Type type("pmx", "Swish", 1);
 
             // add node to graph topo
@@ -128,4 +128,3 @@ bool FuseSwish(const OptKernelOptions &options) {
 }
 
 }}} // namespace ppl::nn::x86
-
diff --git a/src/ppl/nn/engines/x86/optimizer/rules/layout_optimize.cc b/src/ppl/nn/engines/x86/optimizer/rules/layout_optimize.cc
index 07b33e51a..67de1b663 100644
--- a/src/ppl/nn/engines/x86/optimizer/rules/layout_optimize.cc
+++ b/src/ppl/nn/engines/x86/optimizer/rules/layout_optimize.cc
@@ -41,11 +41,11 @@ static ppl::common::RetCode AddReorderOp(
 
     std::string reorder_node_name = "";
     if (reorder_type == REORDER_INPUT) {
-        reorder_node_name = "ReorderInput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderInput_") + edge->GetName() + "_of_" + node->GetName();
     } else if (reorder_type == REORDER_OUTPUT) {
-        reorder_node_name = "ReorderOutput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderOutput_") + edge->GetName() + "_of_" + node->GetName();
     } else if (reorder_type == REORDER_EXTRA_INPUT) {
-        reorder_node_name = "ReorderExtraInput_" + edge->GetName() + "_of_" + node->GetName();
+        reorder_node_name = std::string("ReorderExtraInput_") + edge->GetName() + "_of_" + node->GetName();
     }
 
     auto node_ret_pair = graph_topo->AddNode(reorder_node_name);
diff --git a/src/ppl/nn/ir/edge.h b/src/ppl/nn/ir/edge.h
index 3cadb576c..53c73ce32 100644
--- a/src/ppl/nn/ir/edge.h
+++ b/src/ppl/nn/ir/edge.h
@@ -34,8 +34,7 @@ class Edge {
     /** @brief get the id of this edge */
    virtual edgeid_t GetId() const = 0;
 
-    virtual void SetName(const std::string&) = 0;
-    virtual const std::string& GetName() const = 0;
+    virtual const char* GetName() const = 0;
 
    /**
       @brief get producer node id
diff --git a/src/ppl/nn/ir/full_graph_topo.cc b/src/ppl/nn/ir/full_graph_topo.cc
index 143c73af0..454ac21d7 100644
--- a/src/ppl/nn/ir/full_graph_topo.cc
+++ b/src/ppl/nn/ir/full_graph_topo.cc
@@ -23,34 +23,13 @@ using namespace ppl::common;
 
 namespace ppl { namespace nn { namespace ir {
 
-static Node* FindNode(const string& name, const unique_ptr<Node>* nodes, uint32_t count) {
-    for (uint32_t i = 0; i < count; ++i) {
-        auto node = nodes[i].get();
-        if (node && node->GetName() == name) {
-            return node;
-        }
-    }
-    return nullptr;
-}
-
-static Edge* FindEdge(const string& name, const unique_ptr<Edge>* edges, uint32_t count) {
-    for (uint32_t i = 0; i < count; ++i) {
-        auto edge = edges[i].get();
-        if (edge && edge->GetName() == name) {
-            return edge;
-        }
-    }
-    return nullptr;
-}
-
 pair<Node*, bool> FullGraphTopo::AddNode(const string& name) {
-    auto node = FindNode(name, nodes_.data(), nodes_.size());
-    if (node) {
-        return make_pair(node, false);
+    auto ret_pair = name2nid_.insert(make_pair(name, GetCurrentNodeIdBound()));
+    if (!ret_pair.second) {
+        return make_pair(nodes_[ret_pair.first->second].get(), false);
     }
 
-    node = new Node(nodes_.size());
-    node->SetName(name);
+    auto node = new Node(GetCurrentNodeIdBound(), ret_pair.first->first.c_str());
     nodes_.emplace_back(unique_ptr<Node>(node));
     return make_pair(node, true);
 }
@@ -64,22 +43,20 @@ Node* FullGraphTopo::GetNode(nodeid_t nid) const {
 
 void FullGraphTopo::DelNode(nodeid_t nid) {
     if (nid < nodes_.size() && nodes_[nid]) {
+        name2nid_.erase(nodes_[nid]->GetName());
         nodes_[nid].reset();
     }
 }
 
 class FullGraphEdge final : public Edge {
 public:
-    FullGraphEdge(edgeid_t id) : id_(id), producer_(INVALID_NODEID) {}
+    FullGraphEdge(edgeid_t id, const char* name) : id_(id), name_(name), producer_(INVALID_NODEID) {}
 
     edgeid_t GetId() const override {
         return id_;
     }
 
-    void SetName(const std::string& name) override {
-        name_ = name;
-    }
-
-    const std::string& GetName() const override {
+    const char* GetName() const override {
         return name_;
     }
@@ -116,7 +93,7 @@ class FullGraphEdge final : public Edge {
 private:
     const edgeid_t id_;
-    std::string name_;
+    const char* name_; // pointer to GraphTopo::name2eid_[idx]::first
     nodeid_t producer_;
     std::vector<nodeid_t> consumers_;
 };
@@ -126,13 +103,12 @@
 
 pair<Edge*, bool> FullGraphTopo::AddEdge(const string& name) {
-    auto edge = FindEdge(name, edges_.data(), edges_.size());
-    if (edge) {
-        return make_pair(edge, false);
+    auto ret_pair = name2eid_.insert(make_pair(name, GetCurrentEdgeIdBound()));
+    if (!ret_pair.second) {
+        return make_pair(edges_[ret_pair.first->second].get(), false);
     }
 
-    edge = new FullGraphEdge(GetCurrentEdgeIdBound());
-    edge->SetName(name);
+    auto edge = new FullGraphEdge(GetCurrentEdgeIdBound(), ret_pair.first->first.c_str());
     edges_.emplace_back(unique_ptr<Edge>(edge));
     return make_pair(edge, true);
 }
@@ -158,6 +134,7 @@ void FullGraphTopo::DelEdge(edgeid_t eid) {
     utils::VectorRemoveAllIf(outputs_, p);
     utils::VectorRemoveAllIf(constants_, p);
 
+    name2eid_.erase(edges_[eid]->GetName());
     edges_[eid].reset();
 }
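The FullGraphTopo rewrite above replaces the linear FindNode/FindEdge scans with name2nid_/name2eid_ lookups. It relies on two properties of std::unordered_map: insert() reports whether the key already existed, and the stored key stays in place for the lifetime of the entry, so a Node or Edge can keep a raw const char* into it; that is also why DelNode()/DelEdge() drop the map entry when the object goes away. A condensed sketch of the same insert-then-construct pattern, using generic names rather than the library's classes:

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    struct Item {
        Item(uint32_t id, const char* name) : id(id), name(name) {}
        uint32_t id;
        const char* name; // points into the key owned by Registry::name2id_
    };

    class Registry {
    public:
        // Mirrors AddNode()/AddEdge(): returns the item plus whether it was newly created.
        std::pair<Item*, bool> Add(const std::string& name) {
            auto ret = name2id_.insert(std::make_pair(name, (uint32_t)items_.size()));
            if (!ret.second) { // name already registered: O(1) on average instead of a full scan
                return std::make_pair(items_[ret.first->second].get(), false);
            }
            auto item = new Item((uint32_t)items_.size(), ret.first->first.c_str());
            items_.emplace_back(item);
            return std::make_pair(item, true);
        }

        Item* Find(const std::string& name) const {
            auto it = name2id_.find(name);
            return (it == name2id_.end()) ? nullptr : items_[it->second].get();
        }

        void Del(const std::string& name) {
            auto it = name2id_.find(name);
            if (it != name2id_.end()) {
                items_[it->second].reset(); // the id stays reserved, as in FullGraphTopo
                name2id_.erase(it);         // erase the key once nothing points at it
            }
        }

    private:
        std::vector<std::unique_ptr<Item>> items_;
        std::unordered_map<std::string, uint32_t> name2id_;
    };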
diff --git a/src/ppl/nn/ir/graph_topo.cc b/src/ppl/nn/ir/graph_topo.cc
index 248d48e4d..c5041d705 100644
--- a/src/ppl/nn/ir/graph_topo.cc
+++ b/src/ppl/nn/ir/graph_topo.cc
@@ -24,18 +24,24 @@ using namespace ppl::common;
 
 namespace ppl { namespace nn { namespace ir {
 
-static Node* FindNode(const ir::GraphTopo* topo, const string& name) {
-    for (auto it = topo->CreateNodeIter(); it->IsValid(); it->Forward()) {
-        auto node = it->Get();
-        if (node->GetName() == name) {
-            return node;
-        }
+Node* GraphTopo::GetNode(const string& name) const {
+    auto ref = name2nid_.find(name);
+    if (ref == name2nid_.end()) {
+        return nullptr;
     }
-    return nullptr;
+
+    return GetNode(ref->second);
 }
 
-Node* GraphTopo::GetNode(const string& name) const {
-    return FindNode(this, name);
+bool GraphTopo::RenameNode(Node* node, const string& new_name) {
+    auto ret_pair = name2nid_.insert(make_pair(new_name, node->GetId()));
+    if (!ret_pair.second) {
+        return false;
+    }
+
+    name2nid_.erase(node->GetName());
+    node->name_ = ret_pair.first->first.c_str();
+    return true;
 }
 
 static edgeid_t FindEdgeId(const string& name, const vector<edgeid_t>& edge_ids, const GraphTopo* topo) {
@@ -61,13 +67,11 @@ static uint32_t FindEdgeIdIdx(const string& name, const vector<edgeid_t>& edge_i
 }
 
 Edge* GraphTopo::GetEdge(const std::string& name) const {
-    for (auto it = CreateEdgeIter(); it->IsValid(); it->Forward()) {
-        auto edge = it->Get();
-        if (edge->GetName() == name) {
-            return edge;
-        }
+    auto ref = name2eid_.find(name);
+    if (ref == name2eid_.end()) {
+        return nullptr;
     }
-    return nullptr;
+    return GetEdge(ref->second);
 }
 
 edgeid_t GraphTopo::GetInput(const string& name) const {
diff --git a/src/ppl/nn/ir/graph_topo.h b/src/ppl/nn/ir/graph_topo.h
index 8e32bb55f..b0bd36461 100644
--- a/src/ppl/nn/ir/graph_topo.h
+++ b/src/ppl/nn/ir/graph_topo.h
@@ -26,6 +26,7 @@
 #include <set>
 #include <memory>
 #include <functional>
+#include <unordered_map>
 
 namespace ppl { namespace nn { namespace ir {
 
@@ -87,6 +88,8 @@ class GraphTopo {
 
     Node* GetNode(const std::string& name) const;
 
+    bool RenameNode(Node*, const std::string& new_name);
+
     // ----- //
 
     /**
@@ -207,6 +210,12 @@ class GraphTopo {
     /** output edge ids */
     std::vector<edgeid_t> outputs_;
 
+    /** node name => node id */
+    std::unordered_map<std::string, nodeid_t> name2nid_;
+
+    /** edge name => edge id */
+    std::unordered_map<std::string, edgeid_t> name2eid_;
+
 private:
     GraphTopo(const GraphTopo&) = delete;
     GraphTopo& operator=(const GraphTopo&) = delete;
diff --git a/src/ppl/nn/ir/node.h b/src/ppl/nn/ir/node.h
index c731007c9..d88c0edc6 100644
--- a/src/ppl/nn/ir/node.h
+++ b/src/ppl/nn/ir/node.h
@@ -45,16 +45,13 @@ class Node final {
     };
 
 public:
-    Node(nodeid_t id) : id_(id) {}
+    Node(nodeid_t id, const char* name) : id_(id), name_(name) {}
 
     nodeid_t GetId() const {
         return id_;
     }
 
-    void SetName(const std::string& name) {
-        name_ = name;
-    }
-    const std::string& GetName() const {
+    const char* GetName() const {
         return name_;
     }
 
@@ -166,8 +163,10 @@ class Node final {
     uint32_t ReplaceExtraInput(edgeid_t old_value, edgeid_t new_value);
 
 private:
+    friend class GraphTopo; // for GraphTopo::RenameNode()
+
     const nodeid_t id_;
-    std::string name_;
+    const char* name_; // pointer to GraphTopo::name2nid_[idx]::first
     Type type_;
 
     std::vector<edgeid_t> inputs_;
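A note on the lifetime behind the new const char* name_ members in Node and FullGraphEdge: they alias the std::string key held in GraphTopo::name2nid_/name2eid_, and std::unordered_map guarantees that rehashing moves buckets but never the elements themselves, so the cached pointer stays valid until the entry is erased. That is also why renaming now goes through GraphTopo::RenameNode() (hence the friend declaration above) instead of the removed SetName(), as the fuse_shape_optimizer.cc hunk later in this patch shows. A small self-contained check of the stability assumption:

    #include <cassert>
    #include <string>
    #include <unordered_map>

    int main() {
        std::unordered_map<std::string, int> name2id;
        auto it = name2id.insert({"a_rather_long_node_name_0", 0}).first;
        const char* cached = it->first.c_str(); // what Node::name_ effectively stores

        // Trigger plenty of rehashing; keys are never reallocated by the map.
        for (int i = 1; i < 10000; ++i) {
            name2id.insert({"node_" + std::to_string(i), i});
        }
        assert(cached == name2id.find("a_rather_long_node_name_0")->first.c_str());
        return 0;
    }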
string& name) override { - orig_edge_->SetName(name); - } - - const string& GetName() const override { + const char* GetName() const override { return orig_edge_->GetName(); } diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_clip_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_clip_param.cc index a3cf01a42..6d4ae6c12 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_clip_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_clip_param.cc @@ -37,7 +37,7 @@ RetCode ParseClipParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraA utils::GetNodeAttr(pb_node, "min", &min_value, numeric_limits<float>::lowest()); utils::GetNodeAttr(pb_node, "max", &max_value, numeric_limits<float>::max()); - auto new_edge_name = node->GetName() + "_clip_min_" + ToString(topo->GetCurrentEdgeIdBound()); + auto new_edge_name = string(node->GetName()) + "_clip_min_" + ToString(topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::AddScalarInitializer(topo, data, new_edge_name, min_value, DATATYPE_FLOAT32); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; @@ -45,7 +45,7 @@ RetCode ParseClipParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraA } node->AddInput(edge->GetId()); - new_edge_name = node->GetName() + "_clip_max_" + ToString(topo->GetCurrentEdgeIdBound()); + new_edge_name = string(node->GetName()) + "_clip_max_" + ToString(topo->GetCurrentEdgeIdBound()); edge = ppl::nn::utils::AddScalarInitializer(topo, data, new_edge_name, max_value, DATATYPE_FLOAT32); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_pad_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_pad_param.cc index 7c682dd3d..fc421891f 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_pad_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_pad_param.cc @@ -47,7 +47,7 @@ RetCode ParsePadParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraAr vector<int64_t> pads; utils::GetNodeAttr(pb_node, "pads", &pads); - auto new_edge_name = node->GetName() + "_pad_pads_" + ToString(args.topo->GetCurrentEdgeIdBound()); + auto new_edge_name = node->GetName() + string("_pad_pads_") + ToString(args.topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::Add1DInitializer(args.topo, args.data, new_edge_name, pads, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; @@ -58,7 +58,7 @@ RetCode ParsePadParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraAr float value; utils::GetNodeAttr(pb_node, "value", &value, 0.0); - new_edge_name = node->GetName() + "_pad_value_" + ToString(args.topo->GetCurrentEdgeIdBound()); + new_edge_name = node->GetName() + string("_pad_value_") + ToString(args.topo->GetCurrentEdgeIdBound()); edge = ppl::nn::utils::AddScalarInitializer(args.topo, args.data, new_edge_name, value, DATATYPE_FLOAT32); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_slice_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_slice_param.cc index 03572c00b..c498684dc 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_slice_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_slice_param.cc @@ -33,7 +33,7 @@ RetCode ParseSliceParam(const ::onnx::NodeProto& pb_node, const ParamParserExtra vector<int64_t> starts; utils::GetNodeAttr(pb_node, "starts", &starts); - auto new_edge_name = node->GetName() + "_slice_starts_" + 
ToString(args.topo->GetCurrentEdgeIdBound()); + auto new_edge_name = node->GetName() + string("_slice_starts_") + ToString(args.topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::Add1DInitializer(args.topo, args.data, new_edge_name, starts, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; @@ -44,7 +44,7 @@ RetCode ParseSliceParam(const ::onnx::NodeProto& pb_node, const ParamParserExtra vector<int64_t> ends; utils::GetNodeAttr(pb_node, "ends", &ends); - new_edge_name = node->GetName() + "_slice_ends_" + ToString(args.topo->GetCurrentEdgeIdBound()); + new_edge_name = node->GetName() + string("_slice_ends_") + ToString(args.topo->GetCurrentEdgeIdBound()); edge = ppl::nn::utils::Add1DInitializer(args.topo, args.data, new_edge_name, ends, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; @@ -55,7 +55,7 @@ RetCode ParseSliceParam(const ::onnx::NodeProto& pb_node, const ParamParserExtra vector<int64_t> axes; utils::GetNodeAttr(pb_node, "axes", &axes); - new_edge_name = node->GetName() + "_slice_axes_" + ToString(args.topo->GetCurrentEdgeIdBound()); + new_edge_name = node->GetName() + string("_slice_axes_") + ToString(args.topo->GetCurrentEdgeIdBound()); edge = ppl::nn::utils::Add1DInitializer(args.topo, args.data, new_edge_name, axes, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_split_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_split_param.cc index 751b09568..e3b928571 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_split_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_split_param.cc @@ -25,10 +25,11 @@ using namespace ppl::common; namespace ppl { namespace nn { namespace onnx { -RetCode ParseSplitParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraArgs& args, ir::Node* node, ir::Attr* arg) { +RetCode ParseSplitParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraArgs& args, ir::Node* node, + ir::Attr* arg) { auto param = static_cast<SplitParam*>(arg); utils::GetNodeAttr(pb_node, "axis", ¶m->axis, 0); - + auto& node_type = node->GetType(); if (node_type.version < 13) { @@ -38,7 +39,7 @@ RetCode ParseSplitParam(const ::onnx::NodeProto& pb_node, const ParamParserExtra std::vector<int64_t> split_point; utils::GetNodeAttr(pb_node, "split", &split_point); - auto new_edge_name = node->GetName() + "_split_point_" + ToString(topo->GetCurrentEdgeIdBound()); + auto new_edge_name = node->GetName() + string("_split_point_") + ToString(topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::Add1DInitializer(topo, data, new_edge_name, split_point, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; @@ -49,7 +50,7 @@ RetCode ParseSplitParam(const ::onnx::NodeProto& pb_node, const ParamParserExtra node_type.version = 13; } - + return RC_SUCCESS; } diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_squeeze_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_squeeze_param.cc index 07170cf48..e932f95d6 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_squeeze_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_squeeze_param.cc @@ -36,7 +36,7 @@ RetCode ParseSqueezeParam(const ::onnx::NodeProto& pb_node, const ParamParserExt std::vector<int64_t> axes; utils::GetNodeAttr(pb_node, "axes", &axes); - auto new_edge_name = node->GetName() + "_axes_" + ToString(topo->GetCurrentEdgeIdBound()); + auto new_edge_name 
= node->GetName() + string("_axes_") + ToString(topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::Add1DInitializer(topo, data, new_edge_name, axes, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_topk_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_topk_param.cc index 84e9b0ee1..888e4f49e 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_topk_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_topk_param.cc @@ -36,7 +36,7 @@ RetCode ParseTopKParam(const ::onnx::NodeProto& pb_node, const ParamParserExtraA if (node_type.version < 10) { int64_t k; utils::GetNodeAttr(pb_node, "k", &k, -1); - auto new_edge_name = node->GetName() + "_topk_k_" + ToString(args.topo->GetCurrentEdgeIdBound()); + auto new_edge_name = node->GetName() + string("_topk_k_") + ToString(args.topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::AddScalarInitializer(args.topo, args.data, new_edge_name, k, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; diff --git a/src/ppl/nn/models/onnx/parsers/onnx/parse_unsqueeze_param.cc b/src/ppl/nn/models/onnx/parsers/onnx/parse_unsqueeze_param.cc index f6bcf74a5..27051b232 100644 --- a/src/ppl/nn/models/onnx/parsers/onnx/parse_unsqueeze_param.cc +++ b/src/ppl/nn/models/onnx/parsers/onnx/parse_unsqueeze_param.cc @@ -36,7 +36,7 @@ RetCode ParseUnsqueezeParam(const ::onnx::NodeProto& pb_node, const ParamParserE std::vector<int64_t> axes; utils::GetNodeAttr(pb_node, "axes", &axes); - auto new_edge_name = node->GetName() + "_axes_" + ToString(topo->GetCurrentEdgeIdBound()); + auto new_edge_name = node->GetName() + string("_axes_") + ToString(topo->GetCurrentEdgeIdBound()); auto edge = ppl::nn::utils::Add1DInitializer(topo, data, new_edge_name, axes, DATATYPE_INT64); if (!edge) { LOG(ERROR) << "add initializer[" << new_edge_name << "] failed."; diff --git a/src/ppl/nn/models/pmx/serializer.cc b/src/ppl/nn/models/pmx/serializer.cc index 03b49d427..e39c238f8 100644 --- a/src/ppl/nn/models/pmx/serializer.cc +++ b/src/ppl/nn/models/pmx/serializer.cc @@ -75,7 +75,7 @@ static RetCode CreateFbEdges(FlatBufferBuilder* builder, const SerializationCont vector<Offset<pmx::Edge>> edges(seq2eid.size()); for (uint32_t i = 0; i < seq2eid.size(); ++i) { auto edge = topo->GetEdge(seq2eid[i]); - edges[i] = pmx::CreateEdgeDirect(*builder, edge->GetName().c_str()); + edges[i] = pmx::CreateEdgeDirect(*builder, edge->GetName()); } *fb_edges = builder->CreateVector<Offset<pmx::Edge>>(edges); @@ -109,7 +109,7 @@ static RetCode CreateFbNodes(FlatBufferBuilder* builder, const SerializationCont extra_inputs[i] = eid2seq[node->GetExtraInput(i)]; } - nodes[i] = pmx::CreateNodeDirect(*builder, node->GetName().c_str(), fb_type, &inputs, &outputs, &extra_inputs); + nodes[i] = pmx::CreateNodeDirect(*builder, node->GetName(), fb_type, &inputs, &outputs, &extra_inputs); } *fb_nodes = builder->CreateVector<Offset<pmx::Node>>(nodes); diff --git a/src/ppl/nn/optimizers/fuse_bn_optimizer.cc b/src/ppl/nn/optimizers/fuse_bn_optimizer.cc index cdfdd42cf..9a9e589dc 100644 --- a/src/ppl/nn/optimizers/fuse_bn_optimizer.cc +++ b/src/ppl/nn/optimizers/fuse_bn_optimizer.cc @@ -130,7 +130,7 @@ static bool FuseConvBatchNormalization(ir::Graph* graph) { if (conv_bias_edge) { conv_bias_ptr = (float*)constants[conv_bias_edge->GetId()].data.GetData(); } else { // if conv node has no bias, add bias tensor - auto add_bias_edge_name = conv_node->GetName() + "_bias"; 
+ auto add_bias_edge_name = conv_node->GetName() + string("_bias"); auto edge_ret_pair = graph->topo->AddEdge(add_bias_edge_name); if (!edge_ret_pair.second) { LOG(ERROR) << "edge[" << add_bias_edge_name << "] already exists."; @@ -287,7 +287,7 @@ static bool FuseConvTransposeBatchNormalization(ir::Graph* graph) { if (convtranspose_bias_edge) { convtranspose_bias_ptr = (float*)constants[convtranspose_bias_edge->GetId()].data.GetData(); } else { // if convtranspose node has no bias, add bias tensor - auto add_bias_edge_name = convtranspose_node->GetName() + "_bias"; + auto add_bias_edge_name = convtranspose_node->GetName() + string("_bias"); auto edge_ret_pair = graph->topo->AddEdge(add_bias_edge_name); if (!edge_ret_pair.second) { LOG(ERROR) << "edge[" << add_bias_edge_name << "] already exists."; diff --git a/src/ppl/nn/optimizers/fuse_constant_optimizer.cc b/src/ppl/nn/optimizers/fuse_constant_optimizer.cc index 710ad816e..43afac2aa 100644 --- a/src/ppl/nn/optimizers/fuse_constant_optimizer.cc +++ b/src/ppl/nn/optimizers/fuse_constant_optimizer.cc @@ -290,7 +290,7 @@ static RetCode FuseConvAdd(ir::Graph* graph) { // fuse conv & add if (!conv_bias_edge) { // if conv node has no bias, add bias tensor - auto add_bias_edge_name = conv_node->GetName() + "_bias"; + auto add_bias_edge_name = conv_node->GetName() + string("_bias"); auto edge_ret_pair = graph->topo->AddEdge(add_bias_edge_name); if (!edge_ret_pair.second) { LOG(ERROR) << "edge[" << add_bias_edge_name << "] already exists."; diff --git a/src/ppl/nn/optimizers/fuse_shape_optimizer.cc b/src/ppl/nn/optimizers/fuse_shape_optimizer.cc index ddf9ae7c6..3b7a6eb2a 100644 --- a/src/ppl/nn/optimizers/fuse_shape_optimizer.cc +++ b/src/ppl/nn/optimizers/fuse_shape_optimizer.cc @@ -275,8 +275,14 @@ RetCode FuseShapeOptimizer::Optimize(ir::Graph* graph) const { if (!node || node->GetType().domain != "" || node->GetType().name != "Shape") { continue; } + + const string new_name(node->GetName() + string("_Fused")); + bool ok = graph->topo->RenameNode(node, new_name); + if (!ok) { + LOG(ERROR) << "rename node [" << node->GetName() << "] to new name [" << new_name << "] failed."; + return RC_EXISTS; + } node->SetType(ir::Node::Type{"pmx", "Shape", 1}); - node->SetName(node->GetName() + "_Fused"); ShapeOperationParam shape_param; ShapeMatrix temp_matrix; diff --git a/src/ppl/nn/optimizers/utils.cc b/src/ppl/nn/optimizers/utils.cc index 211a90ca7..d9c6d0373 100644 --- a/src/ppl/nn/optimizers/utils.cc +++ b/src/ppl/nn/optimizers/utils.cc @@ -173,7 +173,7 @@ static RetCode GenConverterNodes(const vector<pair<EngineImpl*, vector<nodeid_t> auto eid = x->first; auto edge = topo->GetEdge(eid); - const string output_edge_name("converted_output_of_" + edge->GetName() + "_" + + const string output_edge_name("converted_output_of_" + string(edge->GetName()) + "_" + ToString(topo->GetCurrentEdgeIdBound())); auto ret_pair = topo->AddEdge(output_edge_name); if (!ret_pair.second) { @@ -319,7 +319,7 @@ static RetCode CopyConstantsForDevices(const vector<pair<EngineImpl*, vector<nod // create copies for other engines for (auto it = engine_node_groups.begin(); it != engine_node_groups.end(); ++it) { auto ret_pair = - topo->AddEdge("__copy_of_" + edge->GetName() + "_" + ToString(topo->GetCurrentEdgeIdBound())); + topo->AddEdge("__copy_of_" + string(edge->GetName()) + "_" + ToString(topo->GetCurrentEdgeIdBound())); auto new_edge = ret_pair.first; auto new_edge_id = new_edge->GetId(); diff --git a/src/ppl/nn/runtime/kernel_impl.h b/src/ppl/nn/runtime/kernel_impl.h index 
diff --git a/src/ppl/nn/runtime/kernel_impl.h b/src/ppl/nn/runtime/kernel_impl.h
index cf3c2a7e4..f5e15c12b 100644
--- a/src/ppl/nn/runtime/kernel_impl.h
+++ b/src/ppl/nn/runtime/kernel_impl.h
@@ -42,7 +42,7 @@ class KernelImpl {
     }
 
     /** @brief get kernel's name */
-    const std::string& GetName() const {
+    const char* GetName() const {
         return node_->GetName();
     }
diff --git a/src/ppl/nn/runtime/runtime_aux_info.cc b/src/ppl/nn/runtime/runtime_aux_info.cc
index 8bb8247fa..c1d33b9f5 100644
--- a/src/ppl/nn/runtime/runtime_aux_info.cc
+++ b/src/ppl/nn/runtime/runtime_aux_info.cc
@@ -24,13 +24,6 @@ using namespace ppl::common;
 
 namespace ppl { namespace nn {
 
-static void InitName2Nodeid(const ir::GraphTopo* topo, map<string, nodeid_t>* name2nodeid) {
-    for (auto it = topo->CreateNodeIter(); it->IsValid(); it->Forward()) {
-        auto node = it->Get();
-        name2nodeid->insert(make_pair(node->GetName(), node->GetId()));
-    }
-}
-
 RetCode RuntimeAuxInfo::Init(const ir::GraphTopo* topo, const set<edgeid_t>& reserved_edgeids) {
     utils::DfsDeeperFirst(topo, [this](nodeid_t nid) -> void {
         this->sorted_nodes.push_back(nid);
@@ -42,8 +35,6 @@ RetCode RuntimeAuxInfo::Init(const ir::GraphTopo* topo, const set<edgeid_t>& res
         return status;
     }
 
-    InitName2Nodeid(topo, &name2nodeid);
-
     return RC_SUCCESS;
 }
diff --git a/src/ppl/nn/runtime/runtime_aux_info.h b/src/ppl/nn/runtime/runtime_aux_info.h
index b63821fbd..2f87f66be 100644
--- a/src/ppl/nn/runtime/runtime_aux_info.h
+++ b/src/ppl/nn/runtime/runtime_aux_info.h
@@ -22,7 +22,6 @@
 #include "ppl/nn/common/types.h"
 #include "ppl/nn/ir/graph_topo.h"
 #include <set>
-#include <map>
 #include <vector>
 
 namespace ppl { namespace nn {
@@ -39,9 +38,6 @@ struct RuntimeAuxInfo final {
 
     /** an `EdgeObject` can be released right after the last consumer finish executing in `sorted_nodes` */
     std::vector<nodeid_t> edge_last_consumer;
-
-    /** node name => id mapping */
-    std::map<std::string, nodeid_t> name2nodeid;
 };
 
 }} // namespace ppl::nn
diff --git a/src/ppl/nn/runtime/runtime_impl.cc b/src/ppl/nn/runtime/runtime_impl.cc
index 58ecb20a7..b3997e2d9 100644
--- a/src/ppl/nn/runtime/runtime_impl.cc
+++ b/src/ppl/nn/runtime/runtime_impl.cc
@@ -79,7 +79,6 @@ static RetCode GenGraphKernels(const RuntimeGraphInfo& info, vector<unique_ptr<E
 }
 
 static RetCode GenGraphInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
-                              const map<string, nodeid_t>& name2nodeid,
                               const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
                               map<string, TensorImpl>* reserved_tensors) {
     for (uint32_t i = 0; i < topo->GetInputCount(); ++i) {
@@ -96,14 +95,8 @@ static RetCode GenGraphInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo&
                 continue;
             }
 
-            auto nid_ref = name2nodeid.find(consumer->GetName());
-            if (nid_ref == name2nodeid.end()) {
-                LOG(ERROR) << "cannot find consumer[" << consumer->GetName() << "] of [" << edge->GetName() << "]";
-                return RC_NOT_FOUND;
-            }
-
             // Consumers of an input are in the same engine. This is guranteed by optimizer.
-            auto kernel = nodeid2kernel[nid_ref->second].get();
+            auto kernel = nodeid2kernel[consumer->GetId()].get();
             tensor->SetDevice(kernel->GetEngineContext()->GetDevice());
             break;
         }
@@ -120,7 +113,6 @@ static RetCode GenGraphInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo&
 }
 
 static RetCode GenGraphExtraInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
-                                   const map<string, nodeid_t>& name2nodeid,
                                    const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
                                    map<string, TensorImpl>* reserved_tensors) {
     for (uint32_t i = 0; i < topo->GetExtraInputCount(); ++i) {
@@ -137,14 +129,8 @@ static RetCode GenGraphExtraInputs(const ir::GraphTopo* topo, const RuntimeGraph
                 continue;
             }
 
-            auto nid_ref = name2nodeid.find(consumer->GetName());
-            if (nid_ref == name2nodeid.end()) {
-                LOG(ERROR) << "cannot find consumer[" << consumer->GetName() << "] of [" << edge->GetName() << "]";
-                return RC_NOT_FOUND;
-            }
-
             // Consumers of an input are in the same engine. This is guranteed by optimizer.
-            auto kernel = nodeid2kernel[nid_ref->second].get();
+            auto kernel = nodeid2kernel[consumer->GetId()].get();
             tensor->SetDevice(kernel->GetEngineContext()->GetDevice());
             break;
         }
@@ -160,7 +146,7 @@ static RetCode GenGraphExtraInputs(const ir::GraphTopo* topo, const RuntimeGraph
 }
 
 RetCode GenGraphOutputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
-                        const map<string, nodeid_t>& name2nodeid, const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
+                        const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
                         map<string, TensorImpl>* reserved_tensors) {
     for (uint32_t i = 0; i < topo->GetOutputCount(); ++i) {
         auto eid = topo->GetOutput(i);
@@ -172,15 +158,7 @@ RetCode GenGraphOutputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
         if (ret_pair.second) {
             auto producer_id = edge->GetProducer();
             if (producer_id != INVALID_NODEID) {
-                auto producer = topo->GetNode(producer_id);
-
-                auto nid_ref = name2nodeid.find(producer->GetName());
-                if (nid_ref == name2nodeid.end()) {
-                    LOG(ERROR) << "cannot find producer[" << producer->GetName() << "] of [" << edge->GetName() << "]";
-                    return RC_NOT_FOUND;
-                }
-
-                auto kernel = nodeid2kernel[nid_ref->second].get();
+                auto kernel = nodeid2kernel[producer_id].get();
                 tensor->SetDevice(kernel->GetEngineContext()->GetDevice());
             }
 
@@ -260,19 +238,19 @@ static RetCode InitGraphResources(const ir::GraphTopo* topo, const RuntimeGraphI
         return status;
     }
 
-    status = GenGraphInputs(topo, info, aux_info.name2nodeid, *nodeid2kernel, reserved_tensors);
+    status = GenGraphInputs(topo, info, *nodeid2kernel, reserved_tensors);
     if (status != RC_SUCCESS) {
         LOG(ERROR) << "GenGraphInputs failed: " << GetRetCodeStr(status);
         return status;
     }
 
-    status = GenGraphExtraInputs(topo, info, aux_info.name2nodeid, *nodeid2kernel, reserved_tensors);
+    status = GenGraphExtraInputs(topo, info, *nodeid2kernel, reserved_tensors);
     if (status != RC_SUCCESS) {
         LOG(ERROR) << "GenGraphExtraInputs failed: " << GetRetCodeStr(status);
         return status;
     }
 
-    status = GenGraphOutputs(topo, info, aux_info.name2nodeid, *nodeid2kernel, reserved_tensors);
+    status = GenGraphOutputs(topo, info, *nodeid2kernel, reserved_tensors);
     if (status != RC_SUCCESS) {
         LOG(ERROR) << "GenGraphOutputs failed: " << GetRetCodeStr(status);
         return status;
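The runtime_impl.cc hunks above drop the name-based lookup entirely: nodeid2kernel is positioned by node id (note the existing nodeid2kernel[consumer->GetId()] indexing), so the id already carried by the producer or consumer is enough, and RuntimeAuxInfo no longer needs its name2nodeid map. A reduced sketch of the lookup after the change, with simplified stand-in types:

    #include <cstdint>
    #include <memory>
    #include <vector>

    using nodeid_t = uint32_t;
    struct Kernel { /* engine-specific state */ };

    // nodeid2kernel[i] holds the kernel built for node id i (entries may be empty).
    static Kernel* KernelOf(const std::vector<std::unique_ptr<Kernel>>& nodeid2kernel, nodeid_t nid) {
        return (nid < nodeid2kernel.size()) ? nodeid2kernel[nid].get() : nullptr;
    }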
diff --git a/src/ppl/nn/runtime/tensor_impl.h b/src/ppl/nn/runtime/tensor_impl.h
index 901b1b67c..5ab9442a0 100644
--- a/src/ppl/nn/runtime/tensor_impl.h
+++ b/src/ppl/nn/runtime/tensor_impl.h
@@ -47,7 +47,7 @@ class TensorImpl final : public EdgeObject, public Tensor {
     }
 
     const char* GetName() const override {
-        return GetEdge()->GetName().c_str();
+        return GetEdge()->GetName();
     }
 
     DeviceContext* GetDeviceContext() const override {
diff --git a/tests/common/input_output_info_test.cc b/tests/common/input_output_info_test.cc
index 691781226..bac811f02 100644
--- a/tests/common/input_output_info_test.cc
+++ b/tests/common/input_output_info_test.cc
@@ -52,7 +52,7 @@ TEST_F(InputOutputInfoTest, misc) {
     auto topo = builder_.GetGraph()->topo.get();
 
     auto node = topo->GetNode(0);
-    EXPECT_EQ("a", node->GetName());
+    EXPECT_EQ(string("a"), node->GetName());
 
     InputOutputInfo info;
     info.SetNode(node);
diff --git a/tests/ir/full_graph_topo_test.cc b/tests/ir/full_graph_topo_test.cc
index addbfb859..ea6d6ca9e 100644
--- a/tests/ir/full_graph_topo_test.cc
+++ b/tests/ir/full_graph_topo_test.cc
@@ -61,7 +61,7 @@ TEST_F(FullGraphTopoTest, full_grapg_topo_GetNodeById_Test) {
     auto topo = graph_builder_.GetGraph()->topo.get();
     auto res_node = topo->GetNode(0);
     cout << res_node->GetName() << endl;
-    EXPECT_EQ(res_node->GetName(), "c");
+    EXPECT_EQ(res_node->GetName(), string("c"));
 }
 
 TEST_F(FullGraphTopoTest, full_graph_topo_DelNodeById_Test) {
@@ -98,7 +98,7 @@ TEST_F(FullGraphTopoTest, full_graph_topo_GetMaxEdgeId_Test) {
 TEST_F(FullGraphTopoTest, full_graph_topo_GetEdgeById_Test) {
     auto topo = graph_builder_.GetGraph()->topo.get();
     auto res_edge = topo->GetEdge(0);
-    EXPECT_EQ(res_edge->GetName(), "output_of_b");
+    EXPECT_EQ(res_edge->GetName(), string("output_of_b"));
 }
 
 TEST_F(FullGraphTopoTest, full_graph_topo_DelEdgeById_Test) {
@@ -118,14 +118,6 @@ TEST_F(FullGraphTopoTest, full_edge_GetId_Test) {
     EXPECT_EQ(0, edge_id);
 }
 
-TEST_F(FullGraphTopoTest, full_edge_SetNameAndGetName_Test) {
-    auto topo = graph_builder_.GetGraph()->topo.get();
-    auto edge = topo->GetEdge(0);
-    const string edge_name = "tmp";
-    edge->SetName(edge_name);
-    EXPECT_EQ(edge->GetName(), edge_name);
-}
-
 TEST_F(FullGraphTopoTest, full_edge_SetProducerAndGetProducer_Test) {
     auto topo = graph_builder_.GetGraph()->topo.get();
     auto edge = topo->GetEdge(0);
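About the EXPECT_EQ changes in these tests: once GetName() returns const char*, a bare EXPECT_EQ("a", node->GetName()) would compare pointers rather than characters (googletest only does value comparison for C strings through EXPECT_STREQ), so one operand is wrapped in std::string to force a value comparison. Either spelling below should be equivalent for these assertions:

    #include <gtest/gtest.h>
    #include <string>

    TEST(NameCompareSketch, ValueNotPointer) {
        const char* name = "a";            // stand-in for node->GetName()
        EXPECT_EQ(std::string("a"), name); // value comparison via std::string::operator==
        EXPECT_STREQ("a", name);           // googletest's dedicated C-string assertion
    }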
diff --git a/tests/ir/node_test.cc b/tests/ir/node_test.cc
index 54066f1f1..3db4a948e 100644
--- a/tests/ir/node_test.cc
+++ b/tests/ir/node_test.cc
@@ -23,19 +23,18 @@ using namespace ppl::nn;
 
 TEST(NodeTest, NodeTest_GetId_Test) {
     const nodeid_t nodeid = 1;
-    ir::Node node(nodeid);
+    ir::Node node(nodeid, "dummy");
     EXPECT_EQ(node.GetId(), nodeid);
 }
 
 TEST(NodeTest, NodeTest_SetNameAndGetName_Test) {
     const string node_name = "tmp";
-    ir::Node node(1);
-    node.SetName(node_name);
+    ir::Node node(1, node_name.c_str());
     EXPECT_EQ(node_name, node.GetName());
 }
 
 TEST(NodeTest, NodeTest_SetTypeAndGetType_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.SetType(ir::Node::Type("domain", "test", 1));
     const ir::Node::Type& type = node.GetType();
     EXPECT_EQ("domain", type.domain);
@@ -43,59 +42,59 @@ TEST(NodeTest, NodeTest_SetTypeAndGetType_Test) {
 }
 
 TEST(NodeTest, NodeTest_AddInputAndGetInput_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     const edgeid_t expected_edgeid = 2;
     node.AddInput(expected_edgeid);
     EXPECT_EQ(expected_edgeid, node.GetInput(0));
 }
 
 TEST(NodeTest, NodeTest_GetInputCount_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     const edgeid_t expected_edgeid = 2;
     node.AddInput(expected_edgeid);
     EXPECT_EQ(1, node.GetInputCount());
 }
 
 TEST(NodeTest, NodeTest_ReplaceInput_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddInput(2);
     EXPECT_EQ(1, node.ReplaceInput(2, 4));
     EXPECT_EQ(4, node.GetInput(0));
 }
 
 TEST(NodeTest, NodeTest_AddOutputAndGetOutput_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddOutput(2);
     EXPECT_EQ(2, node.GetOutput(0));
 }
 
 TEST(NodeTest, NodeTest_GerOutputCount_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddOutput(2);
     EXPECT_EQ(1, node.GetOutputCount());
 }
 
 TEST(NodeTest, NodeTest_ReplaceOutput__Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddOutput(2);
     EXPECT_EQ(1, node.ReplaceOutput(2, 4));
     EXPECT_EQ(4, node.GetOutput(0));
 }
 
 TEST(NodeTest, NodeTest_AddExtralInputAndGetExtraInput_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddExtraInput(2);
     EXPECT_EQ(2, node.GetExtraInput(0));
 }
 
 TEST(NodeTest, NodeTest_GetExtraInputCount_Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddExtraInput(2);
     EXPECT_EQ(1, node.GetExtraInputCount());
 }
 
 TEST(NodeTest, NodeTest_ReplaceExtraInput__Test) {
-    ir::Node node(1);
+    ir::Node node(1, "tmp");
     node.AddExtraInput(2);
     EXPECT_EQ(1, node.ReplaceExtraInput(2, 4));
     EXPECT_EQ(4, node.GetExtraInput(0));
diff --git a/tests/runtime/kernel_exec_context_test.cc b/tests/runtime/kernel_exec_context_test.cc
index 39ed28dc6..49322fd7e 100644
--- a/tests/runtime/kernel_exec_context_test.cc
+++ b/tests/runtime/kernel_exec_context_test.cc
@@ -41,7 +41,7 @@ TEST_F(KernelExecContextTest, misc) {
     auto topo = builder_.GetGraph()->topo.get();
 
     auto node = topo->GetNode(0);
-    EXPECT_EQ("a", node->GetName());
+    EXPECT_EQ(string("a"), node->GetName());
 
     KernelExecContext ctx;
     ctx.SetNode(node);
diff --git a/tests/runtime/kernel_impl_test.cc b/tests/runtime/kernel_impl_test.cc
index dd0481bbc..f0997fad0 100644
--- a/tests/runtime/kernel_impl_test.cc
+++ b/tests/runtime/kernel_impl_test.cc
@@ -42,11 +42,11 @@ class KernelImplTest : public testing::Test {
 TEST_F(KernelImplTest, misc) {
     auto topo = builder_.GetGraph()->topo.get();
     auto node = topo->GetNode(0);
-    EXPECT_EQ("a", node->GetName());
+    EXPECT_EQ(string("a"), node->GetName());
 
     TestKernel kernels(node);
     EXPECT_EQ(node, kernels.GetNode());
-    EXPECT_EQ("a", kernels.GetName());
+    EXPECT_EQ(string("a"), kernels.GetName());
     EXPECT_EQ(ir::Node::Type("test", "op1", 1), kernels.GetType());
 
     TmpEngineContext ctx;