[opt] accelerate finding nodes/edges by name
ouonline committed Jan 3, 2024
1 parent 63fd536 commit c706e55
Showing 43 changed files with 144 additions and 200 deletions.
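The visible hunks share one theme: looking up a node or edge by name used to be a linear scan over all objects, and now goes through name-to-id maps (name2nid_/name2eid_), with GetName() returning a const char* that points into the map key instead of a const std::string&. That signature change is why so many call sites below gain an explicit std::string(...) wrapper. A minimal, self-contained sketch of the before/after lookup; the types and the map member here are stand-ins, since the GraphTopo declaration itself is not among the hunks shown on this page:

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in node type for illustration only.
struct Node {
    explicit Node(std::string n) : name(std::move(n)) {}
    std::string name;
};

// Old approach: every AddNode/AddEdge scanned all existing objects,
// so building a graph with n nodes cost O(n^2) name comparisons.
Node* FindByScan(const std::string& name,
                 const std::vector<std::unique_ptr<Node>>& nodes) {
    for (const auto& n : nodes) {
        if (n && n->name == name) {
            return n.get();
        }
    }
    return nullptr;
}

// New approach: a single map lookup per call (O(1) average for a hash map).
Node* FindByMap(const std::string& name,
                const std::unordered_map<std::string, Node*>& name2node) {
    auto it = name2node.find(name);
    return (it == name2node.end()) ? nullptr : it->second;
}

Both std::map and std::unordered_map keep their elements at stable addresses across later insertions, which is what makes it safe for a node or edge to hold a raw pointer into its own map key (the full_graph_topo.cc hunks at the bottom of this page rely on exactly that).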
2 changes: 1 addition & 1 deletion src/ppl/nn/auxtools/to_graphviz.h
@@ -27,7 +27,7 @@ namespace ppl { namespace nn { namespace utils {

static string GenNodeIdStr(const ir::Node* node) {
auto& type = node->GetType();
return node->GetName() + "[" + type.domain + ":" + type.name + ":" + ToString(type.version) + "]";
return node->GetName() + string("[") + type.domain + ":" + type.name + ":" + ToString(type.version) + "]";
}

static string ToGraphviz(const ir::GraphTopo* topo) {
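The change above is forced by the new GetName() signature: once GetName() returns const char* (the ir::Edge header further down shows this explicitly, and the ir::Node call sites change the same way), an expression such as node->GetName() + "[" tries to add two pointers, which does not compile; promoting one operand to std::string restores the normal operator+ chain. A tiny standalone illustration, with a made-up helper name:

#include <string>

// `name` and `type_name` stand in for the const char* values now returned by GetName().
std::string MakeIdStr(const char* name, const char* type_name) {
    // return name + "[" + type_name + "]";               // ill-formed: const char* + const char*
    return name + std::string("[") + type_name + "]";     // OK: one operand is a std::string
}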
20 changes: 10 additions & 10 deletions src/ppl/nn/engines/arm/optimizer/ops/onnx/conv_op.cc
@@ -162,7 +162,7 @@ ppl::common::RetCode ConvOp::SelectAlgorithm(const InputOutputInfo& info, const

conv2d_param_->mgr = conv2d_algo_selector::fast_gen_algo(
*info.GetInput<TensorImpl>(0)->GetShape(), options.engine_options->forward_precision,
options.engine_options->dynamic_tuning_level, options.engine_options->winograd_level,
options.engine_options->dynamic_tuning_level, options.engine_options->winograd_level,
options.device->GetISA(), conv2d_param_->param, options.device->GetAllocator());

if (conv2d_param_->mgr == nullptr) {
@@ -183,7 +183,7 @@ ppl::common::RetCode ConvOp::SelectAlgorithm(const InputOutputInfo& info, const
// Note: If the filter is reused, generate a new input edge to hold cvt_filter which may differ due to different algo/sched_param.
ppl::nn::TensorBufferInfo * new_filter = &options.info->constants[node->GetInput(1)];
if (new_filter && new_filter->IsBufferOwner() && new_filter->GetBufferPtr<void>()) {
auto edge = options.graph_topo->AddEdge(node->GetName() + "_Input_Cvt_Filter").first;
auto edge = options.graph_topo->AddEdge(node->GetName() + std::string("_Input_Cvt_Filter")).first;
edge->AddConsumer(node->GetId());

TensorImpl* tensor = new TensorImpl(edge, TENSORTYPE_RESERVED);
@@ -210,7 +210,7 @@ ppl::common::RetCode ConvOp::SelectAlgorithm(const InputOutputInfo& info, const
if (normal_cvt_weights_ret != ppl::common::RC_SUCCESS) {
return normal_cvt_weights_ret;
}

if (new_bias && new_bias->IsBufferOwner() && new_bias->GetBufferPtr()) {
bias_data = nullptr;
} else if (bias_data && new_bias) {
@@ -350,13 +350,13 @@ ppl::common::RetCode ConvOp::SerializeData(const ::ppl::nn::pmx::SerializationCo
mgr->algo_info().data_type,
mgr->algo_info().isa,
conv_builder.CreateVector<int64_t>(algo_sp));
auto fb_param_info = ppl::nn::pmx::arm::CreateConvParamInfo(conv_builder,
conv2d_param_->param.num_output,
conv2d_param_->param.channels,
conv2d_param_->mgr->get_param().pad_type,
auto fb_param_info = ppl::nn::pmx::arm::CreateConvParamInfo(conv_builder,
conv2d_param_->param.num_output,
conv2d_param_->param.channels,
conv2d_param_->mgr->get_param().pad_type,
conv2d_param_->mgr->get_param().fuse_flag,
(!mgr->is_zero_bias()) ? 1 : 0);

auto fb_conv_data = ppl::nn::pmx::arm::CreateConvData(conv_builder, fb_algo_info, fb_param_info);
auto fb_op_data = ppl::nn::pmx::arm::CreateOpData(conv_builder, ppl::nn::pmx::arm::PrivateDataType_ConvData, fb_conv_data.Union());
ppl::nn::pmx::arm::FinishOpDataBuffer(conv_builder, fb_op_data);
@@ -415,11 +415,11 @@ ppl::common::RetCode ConvOp::DeserializeData(const ::ppl::nn::pmx::Deserializati
.isa = algo_info->isa(),
.data_type = algo_info->dtype()
});

std::vector<int64_t> sp;
ppl::nn::pmx::utils::Fbvec2Stdvec(algo_info->sched_param(), &sp);
mgr->set_schedule_param(sp);

const auto & shapes = *ctx.shapes;
const auto & constants = *ctx.constants;

6 changes: 3 additions & 3 deletions src/ppl/nn/engines/arm/optimizer/opt_graph.cc
@@ -150,11 +150,11 @@ RetCode OptGraph::AddReorderOp(const OptKernelOptions& options, const edgeid_t&

std::string reorder_node_name = "";
if (reorder_type == REORDER_INPUT) {
reorder_node_name = "ReorderInput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderInput_") + edge->GetName() + "_of_" + node->GetName();
} else if (reorder_type == REORDER_OUTPUT) {
reorder_node_name = "ReorderOutput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderOutput_") + edge->GetName() + "_of_" + node->GetName();
} else if (reorder_type == REORDER_EXTRA_INPUT) {
reorder_node_name = "ReorderExtraInput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderExtraInput_") + edge->GetName() + "_of_" + node->GetName();
}

auto node_ret_pair = graph_->topo->AddNode(reorder_node_name);
@@ -264,10 +264,10 @@ bool FuseChannelShuffleRule::Apply(const OptKernelOptions& options) {
/******************** do optimize ***********************/
/** 1. create & register fused op **/
std::string channel_shuffle_node_name = "ChannelShuffle_" +
(fuse_concat ? (reshape1_prev_node->GetName() + "_") : "") + reshape1_node->GetName() + "_" +
(fuse_concat ? (reshape1_prev_node->GetName() + std::string("_")) : "") + reshape1_node->GetName() + "_" +
trans_node->GetName() + "_" + reshape2_node->GetName() +
(fuse_split ? ("_" + reshape2_next_nodes[0]->GetName()) : "") +
(fuse_slice ? ("_" + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
(fuse_split ? (std::string("_") + reshape2_next_nodes[0]->GetName()) : "") +
(fuse_slice ? (std::string("_") + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
auto node_ret_pair = graph_topo->AddNode(channel_shuffle_node_name);
if (!node_ret_pair.second) {
LOG(ERROR) << "node[" << channel_shuffle_node_name << "] already exists.";
@@ -195,7 +195,7 @@ RetCode DepthwiseDirectInt8::ModifyParam(ir::Node* node, OptKernelOptions& optio
quant_constat_info.SetBuffer(buffer, options.device, true);
}

auto ret_pair = topo->AddEdge("Quant_" + node->GetName());
auto ret_pair = topo->AddEdge(string("Quant_") + node->GetName());
auto quant_edge = ret_pair.first;
auto quant_edge_id = quant_edge->GetId();
node->AddInput(quant_edge_id);
2 changes: 1 addition & 1 deletion src/ppl/nn/engines/cuda/optimizer/algos/algo_conv_imma.cc
@@ -224,7 +224,7 @@ RetCode TuringIMMAImpgemm::ModifyParam(ir::Node* node, OptKernelOptions& options
quant_constat_info.SetBuffer(buffer, options.device, true);
}

auto ret_pair = topo->AddEdge("Quant_" + node->GetName());
auto ret_pair = topo->AddEdge(string("Quant_") + node->GetName());
auto quant_edge = ret_pair.first;
auto quant_edge_id = quant_edge->GetId();
node->AddInput(quant_edge_id);
2 changes: 1 addition & 1 deletion src/ppl/nn/engines/cuda/optimizer/algos/algo_gemm.cc
@@ -253,7 +253,7 @@ RetCode GemmAlgorithm::ModifyParam(ir::Node* node, OptKernelOptions& options) {
quant_constat_info.SetBuffer(buffer, options.device, true);
}

auto ret_pair = topo->AddEdge("Quant_" + node->GetName());
auto ret_pair = topo->AddEdge(string("Quant_") + node->GetName());
auto quant_edge = ret_pair.first;
auto quant_edge_id = quant_edge->GetId();
node->AddInput(quant_edge_id);
4 changes: 2 additions & 2 deletions src/ppl/nn/engines/cuda/optimizer/ops/pmx/bridge_op.cc
@@ -70,7 +70,7 @@ KernelImpl* BridgeOp::CreateKernelImpl() const {

RetCode BridgeOp::AddInternalBridgeNode(ir::Node* node, ir::Node* new_node, ir::Edge* edge, ir::Graph* graph) {
auto topo = graph->topo.get();
auto ret_pair = topo->AddEdge("Bridge_Edge_" + edge->GetName() + "_" + node->GetName());
auto ret_pair = topo->AddEdge(string("Bridge_Edge_") + edge->GetName() + "_" + node->GetName());
auto new_edge = ret_pair.first;

edge->DelConsumer(node->GetId());
@@ -90,7 +90,7 @@ RetCode BridgeOp::AddInternalBridgeNode(ir::Node* node, ir::Node* new_node, ir::

RetCode BridgeOp::AddFinalBridgeNode(ir::Node* node, ir::Node* new_node, ir::Edge* edge, ir::Graph* graph) {
auto topo = graph->topo.get();
auto ret_pair = topo->AddEdge("Bridge_Final_Edge_" + edge->GetName() + "_" + node->GetName());
auto ret_pair = topo->AddEdge(string("Bridge_Final_Edge_") + edge->GetName() + "_" + node->GetName());
auto new_edge = ret_pair.first;

edge->SetProducer(new_node->GetId());
6 changes: 3 additions & 3 deletions src/ppl/nn/engines/cuda/optimizer/opt_graph.cc
@@ -241,12 +241,12 @@ RetCode OptGraph::AddBridgeKernels(const utils::SharedResource& resource) {
continue;
}
auto edge = topo->GetEdge(edge_id);
if (edge->GetName().find("Bridge_Edge") != string::npos) {
if (string(edge->GetName()).find("Bridge_Edge") != string::npos) {
continue;
}

auto creator = OptKernelCreatorManager::GetInstance()->Find("pmx", "Bridge", 1);
auto ret_pair = topo->AddNode("Bridge_Node_" + node->GetName() + "_" + edge->GetName());
auto ret_pair = topo->AddNode(string("Bridge_Node_") + node->GetName() + "_" + edge->GetName());
if (!ret_pair.second) {
LOG(ERROR) << "create a new node for [" << edge->GetName() << "] failed.";
return RC_OUT_OF_MEMORY;
@@ -280,7 +280,7 @@ RetCode OptGraph::AddBridgeKernels(const utils::SharedResource& resource) {
edge->CalcConsumerCount() == 0) { // it is an finel node for the graph
auto creator = OptKernelCreatorManager::GetInstance()->Find("pmx", "Bridge", 1);

auto ret_pair = topo->AddNode("Bridge_Final_" + node->GetName() + "_" + edge->GetName());
auto ret_pair = topo->AddNode(string("Bridge_Final_") + node->GetName() + "_" + edge->GetName());
if (!ret_pair.second) {
LOG(ERROR) << "create a new node for [" << edge->GetName() << "] failed.";
return RC_OUT_OF_MEMORY;
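The Bridge_Edge check above wraps GetName() in std::string so that std::string::find keeps working on the returned const char*. An equivalent, allocation-free alternative would be strstr; a small sketch for comparison (not what the commit uses, just an illustration of the trade-off):

#include <cstring>
#include <string>

// Both checks answer "does `name` contain the substring?"; the first mirrors
// the diff, the second avoids constructing a temporary std::string.
bool ContainsBridgeEdge(const char* name) {
    bool via_string = std::string(name).find("Bridge_Edge") != std::string::npos;
    bool via_strstr = std::strstr(name, "Bridge_Edge") != nullptr;
    return via_string && via_strstr; // identical results for any non-null name
}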
@@ -273,10 +273,10 @@ bool FuseChannelShuffle(const OptKernelOptions& options) {
/******************** do optimize ***********************/
/** 1. create & register fused op **/
std::string channel_shuffle_node_name = "ChannelShuffle_" +
(fuse_concat ? (reshape1_prev_node->GetName() + "_") : "") + reshape1_node->GetName() + "_" +
(fuse_concat ? (reshape1_prev_node->GetName() + std::string("_")) : "") + reshape1_node->GetName() + "_" +
trans_node->GetName() + "_" + reshape2_node->GetName() +
(fuse_split ? ("_" + reshape2_next_nodes[0]->GetName()) : "") +
(fuse_slice ? ("_" + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
(fuse_split ? (std::string("_") + reshape2_next_nodes[0]->GetName()) : "") +
(fuse_slice ? (std::string("_") + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
auto node_ret_pair = graph_topo->AddNode(channel_shuffle_node_name);
if (!node_ret_pair.second) {
LOG(ERROR) << "node[" << channel_shuffle_node_name << "] already exists.";
6 changes: 3 additions & 3 deletions src/ppl/nn/engines/riscv/optimizer/rules/layout_optimize.cc
@@ -39,11 +39,11 @@ static ppl::common::RetCode AddReorderOp(const OptKernelOptions& options, const

std::string reorder_node_name = "";
if (reorder_type == REORDER_INPUT) {
reorder_node_name = "ReorderInput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderInput_") + edge->GetName() + "_of_" + node->GetName();
} else if (reorder_type == REORDER_OUTPUT) {
reorder_node_name = "ReorderOutput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderOutput_") + edge->GetName() + "_of_" + node->GetName();
} else if (reorder_type == REORDER_EXTRA_INPUT) {
reorder_node_name = "ReorderExtraInput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderExtraInput_") + edge->GetName() + "_of_" + node->GetName();
}

auto node_ret_pair = graph_topo->AddNode(reorder_node_name);
@@ -264,10 +264,10 @@ bool FuseChannelShuffle(const OptKernelOptions& options) {
/******************** do optimize ***********************/
/** 1. create & register fused op **/
std::string channel_shuffle_node_name = "ChannelShuffle_" +
(fuse_concat ? (reshape1_prev_node->GetName() + "_") : "") + reshape1_node->GetName() + "_" +
(fuse_concat ? (reshape1_prev_node->GetName() + std::string("_")) : "") + reshape1_node->GetName() + "_" +
trans_node->GetName() + "_" + reshape2_node->GetName() +
(fuse_split ? ("_" + reshape2_next_nodes[0]->GetName()) : "") +
(fuse_slice ? ("_" + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
(fuse_split ? (std::string("_") + reshape2_next_nodes[0]->GetName()) : "") +
(fuse_slice ? (std::string("_") + reshape2_next_nodes[0]->GetName() + "_" + reshape2_next_nodes[1]->GetName()) : "");
auto node_ret_pair = graph_topo->AddNode(channel_shuffle_node_name);
if (!node_ret_pair.second) {
LOG(ERROR) << "node[" << channel_shuffle_node_name << "] already exists.";
@@ -54,7 +54,7 @@ bool FuseConvDepthwise(const OptKernelOptions &options) {
}

const std::string pd_conv2d_node_name =
"PostDepthwiseConv_" + conv_node->GetName() + "_" + next_node->GetName();
std::string("PostDepthwiseConv_") + conv_node->GetName() + "_" + next_node->GetName();
const ir::Node::Type type("pmx", "PostDepthwiseConv", 1);

// add node to graph topo
@@ -139,4 +139,3 @@ bool FuseConvDepthwise(const OptKernelOptions &options) {
}

}}} // namespace ppl::nn::x86

3 changes: 1 addition & 2 deletions src/ppl/nn/engines/x86/optimizer/rules/fuse_swish.cc
@@ -71,7 +71,7 @@ bool FuseSwish(const OptKernelOptions &options) {
// sigmoid_input_edge(last_mul_outer_input_edge) ----> sigmoid_node ----> sigmoid_output_edge(last_mul_inner_input_edge) ----> last_mul_node ----> last_mul_output_edge
// |-------------------------------------------------------------------------->|
const std::string swish_node_name =
"Fused_Swish_" + sigmoid_node->GetName() + "_" + last_mul_node->GetName();
std::string("Fused_Swish_") + sigmoid_node->GetName() + "_" + last_mul_node->GetName();
const ir::Node::Type type("pmx", "Swish", 1);

// add node to graph topo
@@ -128,4 +128,3 @@ bool FuseSwish(const OptKernelOptions &options) {
}

}}} // namespace ppl::nn::x86

6 changes: 3 additions & 3 deletions src/ppl/nn/engines/x86/optimizer/rules/layout_optimize.cc
@@ -41,11 +41,11 @@ static ppl::common::RetCode AddReorderOp(

std::string reorder_node_name = "";
if (reorder_type == REORDER_INPUT) {
reorder_node_name = "ReorderInput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderInput_") + edge->GetName() + "_of_" + node->GetName();
} else if (reorder_type == REORDER_OUTPUT) {
reorder_node_name = "ReorderOutput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderOutput_") + edge->GetName() + "_of_" + node->GetName();
} else if (reorder_type == REORDER_EXTRA_INPUT) {
reorder_node_name = "ReorderExtraInput_" + edge->GetName() + "_of_" + node->GetName();
reorder_node_name = std::string("ReorderExtraInput_") + edge->GetName() + "_of_" + node->GetName();
}

auto node_ret_pair = graph_topo->AddNode(reorder_node_name);
3 changes: 1 addition & 2 deletions src/ppl/nn/ir/edge.h
@@ -34,8 +34,7 @@ class Edge {
/** @brief get the id of this edge */
virtual edgeid_t GetId() const = 0;

virtual void SetName(const std::string&) = 0;
virtual const std::string& GetName() const = 0;
virtual const char* GetName() const = 0;

/**
@brief get producer node id
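Two consequences of this interface change are worth keeping in mind when reading the call-site edits: the returned pointer is owned by the graph topology (the edge merely points into the name map, per the member comment in full_graph_topo.cc below), and == on two const char* compares addresses rather than characters. A hedged sketch of safe usage against a minimal stand-in for the new interface:

#include <cstring>
#include <string>

// Minimal stand-in for the post-change Edge interface (illustration only).
class Edge {
public:
    virtual ~Edge() = default;
    virtual const char* GetName() const = 0;
};

bool HasName(const Edge& edge, const std::string& expected) {
    // `edge.GetName() == "foo"` would compare pointer values, not contents.
    return std::strcmp(edge.GetName(), expected.c_str()) == 0;
}

std::string QualifiedName(const Edge& edge, const char* suffix) {
    // Copy into a std::string if the name must outlive the edge: the pointer
    // dangles once the edge is removed from its topo.
    return std::string(edge.GetName()) + suffix;
}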
49 changes: 13 additions & 36 deletions src/ppl/nn/ir/full_graph_topo.cc
@@ -23,34 +23,13 @@ using namespace ppl::common;

namespace ppl { namespace nn { namespace ir {

static Node* FindNode(const string& name, const unique_ptr<Node>* nodes, uint32_t count) {
for (uint32_t i = 0; i < count; ++i) {
auto node = nodes[i].get();
if (node && node->GetName() == name) {
return node;
}
}
return nullptr;
}

static Edge* FindEdge(const string& name, const unique_ptr<Edge>* edges, uint32_t count) {
for (uint32_t i = 0; i < count; ++i) {
auto edge = edges[i].get();
if (edge && edge->GetName() == name) {
return edge;
}
}
return nullptr;
}

pair<Node*, bool> FullGraphTopo::AddNode(const string& name) {
auto node = FindNode(name, nodes_.data(), nodes_.size());
if (node) {
return make_pair(node, false);
auto ret_pair = name2nid_.insert(make_pair(name, GetCurrentNodeIdBound()));
if (!ret_pair.second) {
return make_pair(nodes_[ret_pair.first->second].get(), false);
}

node = new Node(nodes_.size());
node->SetName(name);
auto node = new Node(GetCurrentNodeIdBound(), ret_pair.first->first.c_str());
nodes_.emplace_back(unique_ptr<Node>(node));
return make_pair(node, true);
}
@@ -64,22 +43,20 @@ Node* FullGraphTopo::GetNode(nodeid_t nid) const {

void FullGraphTopo::DelNode(nodeid_t nid) {
if (nid < nodes_.size() && nodes_[nid]) {
name2nid_.erase(nodes_[nid]->GetName());
nodes_[nid].reset();
}
}

class FullGraphEdge final : public Edge {
public:
FullGraphEdge(edgeid_t id) : id_(id), producer_(INVALID_NODEID) {}
FullGraphEdge(edgeid_t id, const char* name) : id_(id), name_(name), producer_(INVALID_NODEID) {}

edgeid_t GetId() const override {
return id_;
}

void SetName(const std::string& name) override {
name_ = name;
}
const std::string& GetName() const override {
const char* GetName() const override {
return name_;
}

@@ -116,7 +93,7 @@ class FullGraphEdge final : public Edge {

private:
const edgeid_t id_;
std::string name_;
const char* name_; // pointer to GraphTopo::name2eid_[idx]::first
nodeid_t producer_;
std::vector<nodeid_t> consumers_;

@@ -126,13 +103,12 @@ };
};

pair<Edge*, bool> FullGraphTopo::AddEdge(const string& name) {
auto edge = FindEdge(name, edges_.data(), edges_.size());
if (edge) {
return make_pair(edge, false);
auto ret_pair = name2eid_.insert(make_pair(name, GetCurrentEdgeIdBound()));
if (!ret_pair.second) {
return make_pair(edges_[ret_pair.first->second].get(), false);
}

edge = new FullGraphEdge(GetCurrentEdgeIdBound());
edge->SetName(name);
auto edge = new FullGraphEdge(GetCurrentEdgeIdBound(), ret_pair.first->first.c_str());
edges_.emplace_back(unique_ptr<Edge>(edge));
return make_pair(edge, true);
}
@@ -158,6 +134,7 @@ void FullGraphTopo::DelEdge(edgeid_t eid) {
utils::VectorRemoveAllIf(outputs_, p);
utils::VectorRemoveAllIf(constants_, p);

name2eid_.erase(edges_[eid]->GetName());
edges_[eid].reset();
}

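Deletion is the subtle half of the new scheme: DelNode/DelEdge must erase the map entry keyed by the stored name before resetting the slot, because the object's name pointer aims into that very map entry. A condensed, self-contained sketch of the add/delete pair; member names and the exact container type (std::map vs std::unordered_map) are assumptions here, since the GraphTopo header is not among the hunks shown:

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

using edgeid_t = uint32_t;

struct Edge {
    Edge(edgeid_t id, const char* name) : id(id), name(name) {}
    edgeid_t id;
    const char* name; // borrowed pointer to the key stored in name2eid_
};

class Topo {
public:
    std::pair<Edge*, bool> AddEdge(const std::string& name) {
        auto ret = name2eid_.insert({name, (edgeid_t)edges_.size()});
        if (!ret.second) {
            return {edges_[ret.first->second].get(), false}; // duplicate name
        }
        auto edge = new Edge((edgeid_t)edges_.size(), ret.first->first.c_str());
        edges_.emplace_back(edge);
        return {edge, true};
    }

    void DelEdge(edgeid_t eid) {
        if (eid >= edges_.size() || !edges_[eid]) {
            return;
        }
        // Erase by name first: after reset() the name pointer would dangle.
        // erase() builds a temporary std::string key from the const char*
        // before touching the map node, so the lookup itself is safe even
        // though the pointer targets the entry being removed.
        name2eid_.erase(edges_[eid]->name);
        edges_[eid].reset();
    }

private:
    std::vector<std::unique_ptr<Edge>> edges_;
    std::map<std::string, edgeid_t> name2eid_;
};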