Skip to content

Commit

Permalink
[opt] accerarate finding nodes/edges by name
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Jan 4, 2024
1 parent 63fd536 commit 256c641
Show file tree
Hide file tree
Showing 15 changed files with 77 additions and 132 deletions.
1 change: 0 additions & 1 deletion src/ppl/nn/ir/edge.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class Edge {
/** @brief get the id of this edge */
virtual edgeid_t GetId() const = 0;

virtual void SetName(const std::string&) = 0;
virtual const std::string& GetName() const = 0;

/**
Expand Down
49 changes: 13 additions & 36 deletions src/ppl/nn/ir/full_graph_topo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,13 @@ using namespace ppl::common;

namespace ppl { namespace nn { namespace ir {

static Node* FindNode(const string& name, const unique_ptr<Node>* nodes, uint32_t count) {
for (uint32_t i = 0; i < count; ++i) {
auto node = nodes[i].get();
if (node && node->GetName() == name) {
return node;
}
}
return nullptr;
}

static Edge* FindEdge(const string& name, const unique_ptr<Edge>* edges, uint32_t count) {
for (uint32_t i = 0; i < count; ++i) {
auto edge = edges[i].get();
if (edge && edge->GetName() == name) {
return edge;
}
}
return nullptr;
}

pair<Node*, bool> FullGraphTopo::AddNode(const string& name) {
auto node = FindNode(name, nodes_.data(), nodes_.size());
if (node) {
return make_pair(node, false);
auto ret_pair = name2nid_.insert(make_pair(name, GetCurrentNodeIdBound()));
if (!ret_pair.second) {
return make_pair(nodes_[ret_pair.first->second].get(), false);
}

node = new Node(nodes_.size());
node->SetName(name);
auto node = new Node(GetCurrentNodeIdBound(), name);
nodes_.emplace_back(unique_ptr<Node>(node));
return make_pair(node, true);
}
Expand All @@ -64,22 +43,20 @@ Node* FullGraphTopo::GetNode(nodeid_t nid) const {

void FullGraphTopo::DelNode(nodeid_t nid) {
if (nid < nodes_.size() && nodes_[nid]) {
name2nid_.erase(nodes_[nid]->GetName());
nodes_[nid].reset();
}
}

class FullGraphEdge final : public Edge {
public:
FullGraphEdge(edgeid_t id) : id_(id), producer_(INVALID_NODEID) {}
FullGraphEdge(edgeid_t id, const string& name) : id_(id), name_(name), producer_(INVALID_NODEID) {}

edgeid_t GetId() const override {
return id_;
}

void SetName(const std::string& name) override {
name_ = name;
}
const std::string& GetName() const override {
const string& GetName() const override {
return name_;
}

Expand Down Expand Up @@ -116,7 +93,7 @@ class FullGraphEdge final : public Edge {

private:
const edgeid_t id_;
std::string name_;
const string name_;
nodeid_t producer_;
std::vector<nodeid_t> consumers_;

Expand All @@ -126,13 +103,12 @@ class FullGraphEdge final : public Edge {
};

pair<Edge*, bool> FullGraphTopo::AddEdge(const string& name) {
auto edge = FindEdge(name, edges_.data(), edges_.size());
if (edge) {
return make_pair(edge, false);
auto ret_pair = name2eid_.insert(make_pair(name, GetCurrentEdgeIdBound()));
if (!ret_pair.second) {
return make_pair(edges_[ret_pair.first->second].get(), false);
}

edge = new FullGraphEdge(GetCurrentEdgeIdBound());
edge->SetName(name);
auto edge = new FullGraphEdge(GetCurrentEdgeIdBound(), name);
edges_.emplace_back(unique_ptr<Edge>(edge));
return make_pair(edge, true);
}
Expand All @@ -158,6 +134,7 @@ void FullGraphTopo::DelEdge(edgeid_t eid) {
utils::VectorRemoveAllIf(outputs_, p);
utils::VectorRemoveAllIf(constants_, p);

name2eid_.erase(edges_[eid]->GetName());
edges_[eid].reset();
}

Expand Down
34 changes: 19 additions & 15 deletions src/ppl/nn/ir/graph_topo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,24 @@ using namespace ppl::common;

namespace ppl { namespace nn { namespace ir {

static Node* FindNode(const ir::GraphTopo* topo, const string& name) {
for (auto it = topo->CreateNodeIter(); it->IsValid(); it->Forward()) {
auto node = it->Get();
if (node->GetName() == name) {
return node;
}
Node* GraphTopo::GetNode(const string& name) const {
auto ref = name2nid_.find(name);
if (ref == name2nid_.end()) {
return nullptr;
}
return nullptr;

return GetNode(ref->second);
}

Node* GraphTopo::GetNode(const string& name) const {
return FindNode(this, name);
bool GraphTopo::RenameNode(Node* node, const string& new_name) {
auto ret_pair = name2nid_.insert(make_pair(new_name, node->GetId()));
if (!ret_pair.second) {
return false;
}

name2nid_.erase(node->GetName());
node->name_ = new_name;
return true;
}

static edgeid_t FindEdgeId(const string& name, const vector<edgeid_t>& edge_ids, const GraphTopo* topo) {
Expand All @@ -61,13 +67,11 @@ static uint32_t FindEdgeIdIdx(const string& name, const vector<edgeid_t>& edge_i
}

Edge* GraphTopo::GetEdge(const std::string& name) const {
for (auto it = CreateEdgeIter(); it->IsValid(); it->Forward()) {
auto edge = it->Get();
if (edge->GetName() == name) {
return edge;
}
auto ref = name2eid_.find(name);
if (ref == name2eid_.end()) {
return nullptr;
}
return nullptr;
return GetEdge(ref->second);
}

edgeid_t GraphTopo::GetInput(const string& name) const {
Expand Down
9 changes: 9 additions & 0 deletions src/ppl/nn/ir/graph_topo.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <set>
#include <memory>
#include <functional>
#include <unordered_map>

namespace ppl { namespace nn { namespace ir {

Expand Down Expand Up @@ -87,6 +88,8 @@ class GraphTopo {

Node* GetNode(const std::string& name) const;

bool RenameNode(Node*, const std::string& new_name);

// ----- //

/**
Expand Down Expand Up @@ -207,6 +210,12 @@ class GraphTopo {
/** output edge ids */
std::vector<edgeid_t> outputs_;

/** node name => node id */
std::unordered_map<std::string, nodeid_t> name2nid_;

/** edge name => edge id */
std::unordered_map<std::string, edgeid_t> name2eid_;

private:
GraphTopo(const GraphTopo&) = delete;
GraphTopo& operator=(const GraphTopo&) = delete;
Expand Down
7 changes: 3 additions & 4 deletions src/ppl/nn/ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,12 @@ class Node final {
};

public:
Node(nodeid_t id) : id_(id) {}
Node(nodeid_t id, const std::string& name) : id_(id), name_(name) {}

nodeid_t GetId() const {
return id_;
}

void SetName(const std::string& name) {
name_ = name;
}
const std::string& GetName() const {
return name_;
}
Expand Down Expand Up @@ -166,6 +163,8 @@ class Node final {
uint32_t ReplaceExtraInput(edgeid_t old_value, edgeid_t new_value);

private:
friend class GraphTopo; // for GraphTopo::RenameNode()

const nodeid_t id_;
std::string name_;
Type type_;
Expand Down
7 changes: 1 addition & 6 deletions src/ppl/nn/ir/partial_graph_topo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,12 @@ namespace ppl { namespace nn { namespace ir {

class PartialGraphEdge final : public Edge {
public:
PartialGraphEdge(Edge* orig_edge, const vector<Node*>* node_ptrs)
: orig_edge_(orig_edge), node_ptrs_(node_ptrs) {}
PartialGraphEdge(Edge* orig_edge, const vector<Node*>* node_ptrs) : orig_edge_(orig_edge), node_ptrs_(node_ptrs) {}

edgeid_t GetId() const override {
return orig_edge_->GetId();
}

void SetName(const string& name) override {
orig_edge_->SetName(name);
}

const string& GetName() const override {
return orig_edge_->GetName();
}
Expand Down
8 changes: 7 additions & 1 deletion src/ppl/nn/optimizers/fuse_shape_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,14 @@ RetCode FuseShapeOptimizer::Optimize(ir::Graph* graph) const {
if (!node || node->GetType().domain != "" || node->GetType().name != "Shape") {
continue;
}

const string new_name(node->GetName() + "_Fused");
bool ok = graph->topo->RenameNode(node, new_name);
if (!ok) {
LOG(ERROR) << "rename node [" << node->GetName() << "] to new name [" << new_name << "] failed.";
return RC_EXISTS;
}
node->SetType(ir::Node::Type{"pmx", "Shape", 1});
node->SetName(node->GetName() + "_Fused");

ShapeOperationParam shape_param;
ShapeMatrix temp_matrix;
Expand Down
9 changes: 0 additions & 9 deletions src/ppl/nn/runtime/runtime_aux_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@ using namespace ppl::common;

namespace ppl { namespace nn {

static void InitName2Nodeid(const ir::GraphTopo* topo, map<string, nodeid_t>* name2nodeid) {
for (auto it = topo->CreateNodeIter(); it->IsValid(); it->Forward()) {
auto node = it->Get();
name2nodeid->insert(make_pair(node->GetName(), node->GetId()));
}
}

RetCode RuntimeAuxInfo::Init(const ir::GraphTopo* topo, const set<edgeid_t>& reserved_edgeids) {
utils::DfsDeeperFirst(topo, [this](nodeid_t nid) -> void {
this->sorted_nodes.push_back(nid);
Expand All @@ -42,8 +35,6 @@ RetCode RuntimeAuxInfo::Init(const ir::GraphTopo* topo, const set<edgeid_t>& res
return status;
}

InitName2Nodeid(topo, &name2nodeid);

return RC_SUCCESS;
}

Expand Down
4 changes: 0 additions & 4 deletions src/ppl/nn/runtime/runtime_aux_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include "ppl/nn/common/types.h"
#include "ppl/nn/ir/graph_topo.h"
#include <set>
#include <map>
#include <vector>

namespace ppl { namespace nn {
Expand All @@ -39,9 +38,6 @@ struct RuntimeAuxInfo final {

/** an `EdgeObject` can be released right after the last consumer finish executing in `sorted_nodes` */
std::vector<nodeid_t> edge_last_consumer;

/** node name => id mapping */
std::map<std::string, nodeid_t> name2nodeid;
};

}} // namespace ppl::nn
Expand Down
36 changes: 7 additions & 29 deletions src/ppl/nn/runtime/runtime_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ static RetCode GenGraphKernels(const RuntimeGraphInfo& info, vector<unique_ptr<E
}

static RetCode GenGraphInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
const map<string, nodeid_t>& name2nodeid,
const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
map<string, TensorImpl>* reserved_tensors) {
for (uint32_t i = 0; i < topo->GetInputCount(); ++i) {
Expand All @@ -96,14 +95,8 @@ static RetCode GenGraphInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo&
continue;
}

auto nid_ref = name2nodeid.find(consumer->GetName());
if (nid_ref == name2nodeid.end()) {
LOG(ERROR) << "cannot find consumer[" << consumer->GetName() << "] of [" << edge->GetName() << "]";
return RC_NOT_FOUND;
}

// Consumers of an input are in the same engine. This is guranteed by optimizer.
auto kernel = nodeid2kernel[nid_ref->second].get();
auto kernel = nodeid2kernel[consumer->GetId()].get();
tensor->SetDevice(kernel->GetEngineContext()->GetDevice());
break;
}
Expand All @@ -120,7 +113,6 @@ static RetCode GenGraphInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo&
}

static RetCode GenGraphExtraInputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
const map<string, nodeid_t>& name2nodeid,
const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
map<string, TensorImpl>* reserved_tensors) {
for (uint32_t i = 0; i < topo->GetExtraInputCount(); ++i) {
Expand All @@ -137,14 +129,8 @@ static RetCode GenGraphExtraInputs(const ir::GraphTopo* topo, const RuntimeGraph
continue;
}

auto nid_ref = name2nodeid.find(consumer->GetName());
if (nid_ref == name2nodeid.end()) {
LOG(ERROR) << "cannot find consumer[" << consumer->GetName() << "] of [" << edge->GetName() << "]";
return RC_NOT_FOUND;
}

// Consumers of an input are in the same engine. This is guranteed by optimizer.
auto kernel = nodeid2kernel[nid_ref->second].get();
auto kernel = nodeid2kernel[consumer->GetId()].get();
tensor->SetDevice(kernel->GetEngineContext()->GetDevice());
break;
}
Expand All @@ -160,7 +146,7 @@ static RetCode GenGraphExtraInputs(const ir::GraphTopo* topo, const RuntimeGraph
}

RetCode GenGraphOutputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
const map<string, nodeid_t>& name2nodeid, const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
const vector<unique_ptr<KernelImpl>>& nodeid2kernel,
map<string, TensorImpl>* reserved_tensors) {
for (uint32_t i = 0; i < topo->GetOutputCount(); ++i) {
auto eid = topo->GetOutput(i);
Expand All @@ -172,15 +158,7 @@ RetCode GenGraphOutputs(const ir::GraphTopo* topo, const RuntimeGraphInfo& info,
if (ret_pair.second) {
auto producer_id = edge->GetProducer();
if (producer_id != INVALID_NODEID) {
auto producer = topo->GetNode(producer_id);

auto nid_ref = name2nodeid.find(producer->GetName());
if (nid_ref == name2nodeid.end()) {
LOG(ERROR) << "cannot find producer[" << producer->GetName() << "] of [" << edge->GetName() << "]";
return RC_NOT_FOUND;
}

auto kernel = nodeid2kernel[nid_ref->second].get();
auto kernel = nodeid2kernel[producer_id].get();
tensor->SetDevice(kernel->GetEngineContext()->GetDevice());
}

Expand Down Expand Up @@ -260,19 +238,19 @@ static RetCode InitGraphResources(const ir::GraphTopo* topo, const RuntimeGraphI
return status;
}

status = GenGraphInputs(topo, info, aux_info.name2nodeid, *nodeid2kernel, reserved_tensors);
status = GenGraphInputs(topo, info, *nodeid2kernel, reserved_tensors);
if (status != RC_SUCCESS) {
LOG(ERROR) << "GenGraphInputs failed: " << GetRetCodeStr(status);
return status;
}

status = GenGraphExtraInputs(topo, info, aux_info.name2nodeid, *nodeid2kernel, reserved_tensors);
status = GenGraphExtraInputs(topo, info, *nodeid2kernel, reserved_tensors);
if (status != RC_SUCCESS) {
LOG(ERROR) << "GenGraphExtraInputs failed: " << GetRetCodeStr(status);
return status;
}

status = GenGraphOutputs(topo, info, aux_info.name2nodeid, *nodeid2kernel, reserved_tensors);
status = GenGraphOutputs(topo, info, *nodeid2kernel, reserved_tensors);
if (status != RC_SUCCESS) {
LOG(ERROR) << "GenGraphOutputs failed: " << GetRetCodeStr(status);
return status;
Expand Down
2 changes: 1 addition & 1 deletion tests/common/input_output_info_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TEST_F(InputOutputInfoTest, misc) {
auto topo = builder_.GetGraph()->topo.get();

auto node = topo->GetNode(0);
EXPECT_EQ("a", node->GetName());
EXPECT_EQ(string("a"), node->GetName());

InputOutputInfo info;
info.SetNode(node);
Expand Down
Loading

0 comments on commit 256c641

Please sign in to comment.