Skip to content

Commit

Permalink
[fix] add onnx::Dropout v12 support
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Jan 15, 2024
1 parent a152d82 commit 00efe89
Showing 1 changed file with 85 additions and 16 deletions.
101 changes: 85 additions & 16 deletions src/ppl/nn/optimizers/skip_dropout_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,107 @@
// under the License.

#include "ppl/nn/optimizers/skip_dropout_optimizer.h"
#include "ppl/nn/common/logger.h"

using namespace ppl::common;

namespace ppl { namespace nn {

inline bool IsGraphOutput(const ir::Graph* graph, edgeid_t edge_id) {
for (uint32_t i = 0; i < graph->topo->GetOutputCount(); i++) {
if (graph->topo->GetOutput(i) == edge_id) {
static void DeleteNodeAndOutput(const ir::Node* node, ir::GraphTopo* topo) {
auto in_eid = node->GetInput(0);
auto out_eid = node->GetOutput(0);
auto in_edge = topo->GetEdge(in_eid);
auto out_edge = topo->GetEdge(out_eid);

for (uint32_t i = 0; i < node->GetInputCount(); ++i) {
auto edge = topo->GetEdge(node->GetInput(i));
edge->DelConsumer(node->GetId());
}

for (auto it = out_edge->CreateConsumerIter(); it.IsValid(); it.Forward()) {
auto nid = it.Get();
auto next = topo->GetNode(nid);
auto nr = next->ReplaceInput(out_eid, in_eid);
if (nr > 0) {
in_edge->AddConsumer(nid);
}
}

topo->DelEdge(out_eid);
topo->DelEdge(node->GetOutput(1));
topo->DelNode(node->GetId());
}

static void DeleteNodeAndInput(const ir::Node* node, ir::GraphTopo* topo) {
auto in_eid = node->GetInput(0);
auto out_eid = node->GetOutput(0);
auto in_edge = topo->GetEdge(in_eid);
auto out_edge = topo->GetEdge(out_eid);

for (uint32_t i = 1; i < node->GetInputCount(); ++i) {
auto edge = topo->GetEdge(node->GetInput(i));
edge->DelConsumer(node->GetId());
}

auto prev = topo->GetNode(in_edge->GetProducer());
auto nr = prev->ReplaceOutput(in_eid, out_eid);
if (nr > 0) {
out_edge->SetProducer(prev->GetId());
}

topo->DelEdge(in_eid);
topo->DelEdge(node->GetOutput(1));
topo->DelNode(node->GetId());
}

static bool IsGraphInput(const ir::GraphTopo* topo, edgeid_t edge_id) {
for (uint32_t i = 0; i < topo->GetInputCount(); i++) {
if (topo->GetInput(i) == edge_id) {
return true;
}
}
return false;
}

static bool IsGraphOutput(const ir::GraphTopo* topo, edgeid_t edge_id) {
for (uint32_t i = 0; i < topo->GetOutputCount(); i++) {
if (topo->GetOutput(i) == edge_id) {
return true;
}
}
return false;
}

RetCode SkipDropoutOptimizer::Optimize(ir::Graph* graph) const {
for (auto it = graph->topo->CreateNodeIter(); it->IsValid(); it->Forward()) {
auto topo = graph->topo.get();
for (auto it = topo->CreateNodeIter(); it->IsValid(); it->Forward()) {
auto node = it->Get();
if (node->GetType().name == "Dropout") {
auto input_edge = graph->topo->GetEdge(node->GetInput(0));
if (input_edge->CalcConsumerCount() != 1 || IsGraphOutput(graph, input_edge->GetId())) {
auto& type = node->GetType();
if (type.name == "Dropout") {
if (node->GetOutputCount() > 2) {
LOG(ERROR) << "unsupported Dropout version [" << type.version << "]";
return RC_UNSUPPORTED;
}

bool input_is_graph_input = IsGraphInput(topo, node->GetInput(0));
bool output_is_graph_output = IsGraphOutput(topo, node->GetOutput(0));

if (input_is_graph_input && output_is_graph_output) {
continue;
}
auto output_edge = graph->topo->GetEdge(node->GetOutput(0));
auto node_pre = graph->topo->GetNode(input_edge->GetProducer());
node_pre->ReplaceOutput(input_edge->GetId(), output_edge->GetId());
output_edge->SetProducer(node_pre->GetId());

graph->topo->DelEdge(input_edge->GetId());
if (node->GetOutputCount() >= 2) {
graph->topo->DelEdge(node->GetOutput(1));

if (node->GetOutputCount() == 2) {
auto mask_edge = topo->GetEdge(node->GetOutput(1));
if (mask_edge->CalcConsumerCount() > 0) {
continue;
}
}

if (output_is_graph_output) {
DeleteNodeAndInput(node, topo);
} else {
DeleteNodeAndOutput(node, topo);
}
graph->topo->DelNode(node->GetId());
}
}

Expand Down

0 comments on commit 00efe89

Please sign in to comment.