diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h
index a2978d25209..5d234fa9f9b 100644
--- a/runtime/onert/core/include/ir/train/TrainableGraph.h
+++ b/runtime/onert/core/include/ir/train/TrainableGraph.h
@@ -129,7 +129,7 @@ class TrainableGraph : public IGraph
 
 public:
   std::vector<ir::OperationIndex> topolSortOperations() const;
-  // TODO Support topological sort for backwarding
+  std::vector<ir::OperationIndex> btopolSortOperations() const;
 
 private:
   Graph _graph;
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc
index 485913bec33..ea7c1813cbf 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.cc
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -749,10 +749,16 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
     ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);
 
-  // linearize
+  // linearize for forwarding
   auto order = Linear::linearize(*lowered_graph);
+  VERBOSE(ExecutorFactory) << "Linearize for forwarding order" << std::endl;
   Linear::dump(*lowered_graph, order);
 
+  // linearize for backwarding
+  auto backward_order = lowered_graph->trainable_graph().btopolSortOperations();
+  VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
+  Linear::dump(*lowered_graph, backward_order);
+
   for (auto &&pair : tbackend_contexts)
   {
     pair.second->genTensors();
@@ -885,6 +891,7 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
                                  tensor_regs,
                                  std::move(code_map),
                                  order,
+                                 backward_order,
                                  tracing_ctx,
                                  training_info.lossInfo()};
 
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.cc b/runtime/onert/core/src/exec/train/TrainableExecutor.cc
index 70a97f0c69d..666e03531f4 100644
--- a/runtime/onert/core/src/exec/train/TrainableExecutor.cc
+++ b/runtime/onert/core/src/exec/train/TrainableExecutor.cc
@@ -32,9 +32,13 @@ TrainableExecutor::TrainableExecutor(
   std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
   backend::train::TrainableBackendContexts &&backend_contexts,
   const compiler::train::TensorRegistries &tensor_regs,
-  compiler::train::TrainableCodeMap &&code_map, const std::vector<ir::OperationIndex> &order,
-  const util::TracingCtx *tracing_ctx, const ir::train::LossInfo &loss_info)
-  : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
+  compiler::train::TrainableCodeMap &&code_map,
+  const std::vector<ir::OperationIndex> &forward_order,
+  const std::vector<ir::OperationIndex> &backward_order, const util::TracingCtx *tracing_ctx,
+  const ir::train::LossInfo &loss_info)
+  : _code_map{std::move(code_map)}, _forward_order{std::move(forward_order)},
+    _backward_order{std::move(backward_order)}, _lowered_graph{std::move(lowered_graph)},
+    _backend_contexts{std::move(backend_contexts)},
     _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)},
     _mutex(), _tracing_ctx(tracing_ctx), _loss_info(loss_info)
 {
@@ -50,12 +54,6 @@ TrainableExecutor::TrainableExecutor(
   };
   build_tensor_list(_trainable_graph.getInputs(), _input_tensors);
   build_tensor_list(_trainable_graph.getOutputs(), _output_tensors);
-
-  for (auto &&index : order)
-  {
-    auto &trainable_code = code_map.at(index);
-    _code.emplace_back(std::move(trainable_code));
-  }
 }
 
 void TrainableExecutor::execute(const std::vector<backend::IPortableTensor *> &,
@@ -110,8 +108,9 @@ void TrainableExecutor::forwardImpl(bool training)
     auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
 
     _subject.notifySubgraphBegin(profiling_subg_index);
-    for (auto &&code : _code)
+    for (auto &&index : _forward_order)
     {
+      const auto &code = _code_map.at(index);
       const auto backend = code.lower_info->backend();
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
@@ -128,8 +127,9 @@ void TrainableExecutor::forwardImpl(bool training)
   }
   else
   {
-    for (auto &&code : _code)
+    for (auto &&index : _forward_order)
     {
+      const auto &code = _code_map.at(index);
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
       ruy::profiler::ScopeLabel label(code.op->name());
@@ -157,9 +157,9 @@ void TrainableExecutor::backwardImpl(uint32_t training_step)
     auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
 
     _subject.notifySubgraphBegin(profiling_subg_index);
-    for (auto it = _code.rbegin(); it != _code.rend(); ++it)
+    for (auto &&index : _backward_order)
     {
-      const auto &code = *it;
+      const auto &code = _code_map.at(index);
       const auto backend = code.lower_info->backend();
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
@@ -176,9 +176,9 @@ void TrainableExecutor::backwardImpl(uint32_t training_step)
   }
   else
   {
-    for (auto it = _code.rbegin(); it != _code.rend(); ++it)
+    for (auto &&index : _backward_order)
     {
-      const auto &code = *it;
+      const auto &code = _code_map.at(index);
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
       ruy::profiler::ScopeLabel label(code.op->name());
diff --git a/runtime/onert/core/src/exec/train/TrainableExecutor.h b/runtime/onert/core/src/exec/train/TrainableExecutor.h
index 877f81a731f..92bb0dc84fe 100644
--- a/runtime/onert/core/src/exec/train/TrainableExecutor.h
+++ b/runtime/onert/core/src/exec/train/TrainableExecutor.h
@@ -49,7 +49,8 @@ class TrainableExecutor : public IExecutor
                     backend::train::TrainableBackendContexts &&backend_contexts,
                     const compiler::train::TensorRegistries &tensor_regs,
                     compiler::train::TrainableCodeMap &&code_map,
-                    const std::vector<ir::OperationIndex> &order,
+                    const std::vector<ir::OperationIndex> &forward_order,
+                    const std::vector<ir::OperationIndex> &backward_order,
                     const util::TracingCtx *tracing_ctx, const ir::train::LossInfo &training_info);
 
 public:
@@ -90,7 +91,9 @@ class TrainableExecutor : public IExecutor
   void backwardImpl(uint32_t training_step);
 
 private:
-  std::vector<compiler::train::TrainableCodeAndInfo> _code;
+  compiler::train::TrainableCodeMap _code_map;
+  std::vector<ir::OperationIndex> _forward_order;
+  std::vector<ir::OperationIndex> _backward_order;
   ExecutionObservee _subject;
   std::shared_ptr<ExecTime> _indexed_ranks;
   std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph;
diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc
index 1805d6c47d4..1dadb02dc61 100644
--- a/runtime/onert/core/src/ir/train/TrainableGraph.cc
+++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc
@@ -16,6 +16,7 @@
 
 #include "ir/train/TrainableGraph.h"
 #include "util/Utils.h"
+#include "util/Set.h"
 
 #include <algorithm>
 #include <misc/polymorphic_downcast.h>
@@ -128,6 +129,42 @@ std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const
   return _graph.topolSortOperations();
 }
 
+std::vector<ir::OperationIndex> TrainableGraph::btopolSortOperations() const
+{
+  std::vector<ir::OperationIndex> ret;
+  util::Set<ir::OperationIndex> unvisited;
+  ir::OperationIndex loss_idx;
+  operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
+    unvisited.add(index);
+    if (op.opcode() == ir::OpCode::Loss)
+    {
+      assert(!loss_idx.valid()); // Should be only one loss
+      loss_idx = index;
+    }
+  });
+
+  std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
+    [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
+    if (!unvisited.contains(index))
+      return;
+    unvisited.remove(index);
+    ret.push_back(index);
+
+    for (const auto &input : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+    {
+      const auto &operand = operands().at(input);
+      const auto &def = operand.getDef();
+      if (!def.valid())
+        continue;
+      dfs(def, operations().at(def));
+    }
+  };
+
+  dfs(loss_idx, operations().at(loss_idx));
+
+  return ret;
+}
+
 void TrainableGraph::addLoss(const OperandIndex &loss_ind, const IOIndex &pred_ioind)
 {
   _losses.emplace(pred_ioind, loss_ind);
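
For reference, here is a minimal standalone sketch of the traversal that `btopolSortOperations` performs: a DFS starting at the loss operation that emits each operation on first visit and then recurses into the producers of its inputs. The `Op` struct, `producers` edges, and `btopolSort` function below are illustrative stand-ins, not onert types.

```cpp
// Sketch of the backward-order DFS (illustrative types, not onert's API):
// emit an op on first visit, then recurse into the ops that define its inputs.
#include <cassert>
#include <functional>
#include <unordered_set>
#include <vector>

struct Op
{
  std::vector<int> producers; // indices of the ops that define this op's inputs
};

std::vector<int> btopolSort(const std::vector<Op> &graph, int loss_index)
{
  std::vector<int> ret;
  std::unordered_set<int> visited;

  std::function<void(int)> dfs = [&](int index) {
    if (!visited.insert(index).second)
      return;             // already emitted
    ret.push_back(index); // emit before producers, so the loss comes first
    for (int def : graph[index].producers)
      dfs(def);
  };

  dfs(loss_index);
  return ret;
}

int main()
{
  // toy graph: 0:input -> 1:conv -> 2:fc -> 3:loss, with 4:weights feeding 1 and 2
  std::vector<Op> g(5);
  g[1].producers = {0, 4};
  g[2].producers = {1, 4};
  g[3].producers = {2};

  auto order = btopolSort(g, 3); // {3, 2, 1, 0, 4}
  assert(order.front() == 3);    // the backward pass starts at the loss
  return 0;
}
```

As in the diff, operations are ordered by first discovery from the loss, so each gradient consumer is emitted before the producers it feeds gradients back to; `backwardImpl` then iterates `_backward_order` front to back instead of walking the forward code list in reverse.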