From c01c27e1f60296c15577f58063ecd4bea3b31378 Mon Sep 17 00:00:00 2001 From: YongHyun An Date: Mon, 5 Feb 2024 16:47:16 +0900 Subject: [PATCH] [onert] Add check for topological order validity This commit introduces the following functions - assertValidTopologicalOrder - assertValidBackwardTopologicalOrder which will catch any potential bugs on forward or backward topological ordering used in onert. ONE-DCO-1.0-Signed-off-by: YongHyun An --- .../core/include/ir/train/TrainableGraph.h | 5 ++ .../onert/core/src/ir/train/TrainableGraph.cc | 67 ++++++++++++++++++- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h index 3c8d4a0f630..eb0fab78c31 100644 --- a/runtime/onert/core/include/ir/train/TrainableGraph.h +++ b/runtime/onert/core/include/ir/train/TrainableGraph.h @@ -127,6 +127,11 @@ class TrainableGraph : public IGraph public: const ITrainableOperation &operation(OperationIndex index) const; +private: + void validateTopologicalOrder(std::vector order, bool is_forward) const; + void validateForwardTopologicalOrder(const std::vector &order) const; + void validateBackwardTopologicalOrder(const std::vector &order) const; + public: std::vector topolSortOperations() const; std::vector btopolSortOperations() const; diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc index f26a20a84bf..1cf3d771010 100644 --- a/runtime/onert/core/src/ir/train/TrainableGraph.cc +++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc @@ -19,6 +19,7 @@ #include "util/Set.h" #include +#include #include namespace onert @@ -125,9 +126,72 @@ const ITrainableOperation &TrainableGraph::operation(OperationIndex index) const return dynamic_cast(_graph.operations().at(index)); } +void TrainableGraph::validateTopologicalOrder(std::vector order, + bool is_forward) const +{ + if (!is_forward) + std::reverse(order.begin(), order.end()); + + const std::string order_type = is_forward ? "forward" : "backward"; + + auto compare = [](const ir::OperationIndex &i, const ir::OperationIndex &j) -> bool { + return i.value() < j.value(); + }; + + std::map position(compare); + for (uint32_t p = 0; p < order.size(); ++p) + { + auto index = order[p]; + // TODO: replace this with `std::map::contains` after C++20 + if (position.find(index) != position.end()) + throw std::runtime_error{"Invalid " + order_type + " topological order: duplicate node @" + + std::to_string(index.value())}; + + position[index] = p; + } + + operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) { + if (position.count(index) == 0) + return; + + uint32_t p = position[index]; + + for (const auto &output : op.getOutputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED) + { + const auto &operand = operands().at(output); + for (const auto &use : operand.getUses()) + { + if (position.count(use) == 0) + continue; + + uint32_t q = position[use]; + if (p > q) + throw std::runtime_error{ + "Invalid " + order_type + " topological order: inversion between @" + + std::to_string(index.value()) + " and @" + std::to_string(use.value())}; + } + } + }); +} + +void TrainableGraph::validateForwardTopologicalOrder( + const std::vector &order) const +{ + validateTopologicalOrder(order, true); +} + +void TrainableGraph::validateBackwardTopologicalOrder( + const std::vector &order) const +{ + validateTopologicalOrder(order, false); +} + std::vector TrainableGraph::topolSortOperations() const { - return _graph.topolSortOperations(); + auto ret = _graph.topolSortOperations(); + validateForwardTopologicalOrder(ret); + + return ret; } std::vector TrainableGraph::btopolSortOperations() const @@ -162,6 +226,7 @@ std::vector TrainableGraph::btopolSortOperations() const }; dfs(loss_idx, operations().at(loss_idx)); + validateBackwardTopologicalOrder(ret); return ret; }