Skip to content

Commit

Permalink
[onert] Add check for topological order validity (#12602)
Browse files Browse the repository at this point in the history
This commit introduces the following functions

- validateTopologicalOrder
- validateBackwardTopologicalOrder

which will catch any potential bugs on forward or backward topological ordering used in onert.

ONE-DCO-1.0-Signed-off-by: YongHyun An <[email protected]>
  • Loading branch information
Aeren1564 authored Feb 7, 2024
1 parent 75e06dd commit 1ef4072
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
5 changes: 5 additions & 0 deletions runtime/onert/core/include/ir/train/TrainableGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,11 @@ class TrainableGraph : public IGraph
public:
const ITrainableOperation &operation(OperationIndex index) const;

private:
void validateTopologicalOrder(std::vector<ir::OperationIndex> order, bool is_forward) const;
void validateForwardTopologicalOrder(const std::vector<ir::OperationIndex> &order) const;
void validateBackwardTopologicalOrder(const std::vector<ir::OperationIndex> &order) const;

public:
std::vector<ir::OperationIndex> topolSortOperations() const;
std::vector<ir::OperationIndex> btopolSortOperations() const;
Expand Down
67 changes: 66 additions & 1 deletion runtime/onert/core/src/ir/train/TrainableGraph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "util/Set.h"

#include <algorithm>
#include <map>
#include <misc/polymorphic_downcast.h>

namespace onert
Expand Down Expand Up @@ -125,9 +126,72 @@ const ITrainableOperation &TrainableGraph::operation(OperationIndex index) const
return dynamic_cast<const ITrainableOperation &>(_graph.operations().at(index));
}

void TrainableGraph::validateTopologicalOrder(std::vector<ir::OperationIndex> order,
bool is_forward) const
{
if (!is_forward)
std::reverse(order.begin(), order.end());

const std::string order_type = is_forward ? "forward" : "backward";

auto compare = [](const ir::OperationIndex &i, const ir::OperationIndex &j) -> bool {
return i.value() < j.value();
};

std::map<ir::OperationIndex, uint32_t, decltype(compare)> position(compare);
for (uint32_t p = 0; p < order.size(); ++p)
{
auto index = order[p];
// TODO: replace this with `std::map::contains` after C++20
if (position.find(index) != position.end())
throw std::runtime_error{"Invalid " + order_type + " topological order: duplicate node @" +
std::to_string(index.value())};

position[index] = p;
}

operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
if (position.count(index) == 0)
return;

uint32_t p = position[index];

for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
{
const auto &operand = operands().at(output);
for (const auto &use : operand.getUses())
{
if (position.count(use) == 0)
continue;

uint32_t q = position[use];
if (p > q)
throw std::runtime_error{
"Invalid " + order_type + " topological order: inversion between @" +
std::to_string(index.value()) + " and @" + std::to_string(use.value())};
}
}
});
}

void TrainableGraph::validateForwardTopologicalOrder(
const std::vector<ir::OperationIndex> &order) const
{
validateTopologicalOrder(order, true);
}

void TrainableGraph::validateBackwardTopologicalOrder(
const std::vector<ir::OperationIndex> &order) const
{
validateTopologicalOrder(order, false);
}

std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const
{
return _graph.topolSortOperations();
auto ret = _graph.topolSortOperations();
validateForwardTopologicalOrder(ret);

return ret;
}

std::vector<ir::OperationIndex> TrainableGraph::btopolSortOperations() const
Expand Down Expand Up @@ -162,6 +226,7 @@ std::vector<ir::OperationIndex> TrainableGraph::btopolSortOperations() const
};

dfs(loss_idx, operations().at(loss_idx));
validateBackwardTopologicalOrder(ret);

return ret;
}
Expand Down

0 comments on commit 1ef4072

Please sign in to comment.