From 86c70adecfa05fdbbdeab978c69253992ef9fcd2 Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Fri, 26 Jul 2024 14:31:04 +0900 Subject: [PATCH] [onert] Apply usedef chains for training (#13462) This commit applies training usedefs to partial trainable graphs to be passed into backends. - Add a method to TrainableGraph that updates trainable graph dependency by using training usedefs - Call the method for partial trainable graphs ONE-DCO-1.0-Signed-off-by: ragmani --- .../core/include/ir/train/TrainableGraph.h | 3 + .../core/src/compiler/ExecutorFactory.cc | 15 ++++- .../src/compiler/train/TrainingCompiler.cc | 9 +++ .../onert/core/src/ir/train/TrainableGraph.cc | 59 +++++++++++++++++++ 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/runtime/onert/core/include/ir/train/TrainableGraph.h b/runtime/onert/core/include/ir/train/TrainableGraph.h index 1faf8c56714..230d335bb23 100644 --- a/runtime/onert/core/include/ir/train/TrainableGraph.h +++ b/runtime/onert/core/include/ir/train/TrainableGraph.h @@ -161,6 +161,9 @@ class TrainableGraph : public IGraph truncateBackwardOrder(std::vector backward_order, std::function truncating_cond) const; +public: + void updateGraphDependency(); + private: Graph _graph; Operands _backward_operands; diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc index 22af010c751..1c3749ad992 100644 --- a/runtime/onert/core/src/compiler/ExecutorFactory.cc +++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc @@ -695,6 +695,20 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor( external_operands.remove(index); } + const auto backend = pair.first; + // NOTE The builtin backend does not yet support initializing UseDefs for training + // because its graph does not have loss operation + // Without loss operation, we cannot call btopolSortOperations() or + // getEssentialBackwardOrder() + // TODO Modify checking the condition to check whether loss op exists + if 
(backend->config()->id() != "builtin") + { + // Initialize training def-uses + tgraph->updateGraphDependency(); + + tgraph->verify(); + } + // Set trainable context data backend::train::TrainableContextData tdata; tdata.tgraph = std::move(tgraph); @@ -706,7 +720,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor( tdata.optim_info = training_info.optimizerInfo(); // TODO Remove dynamic_cast - const auto backend = pair.first; const auto tbackend = dynamic_cast(backend); if (!tbackend) { diff --git a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc index d807d44fca9..aee23a6d992 100644 --- a/runtime/onert/core/src/compiler/train/TrainingCompiler.cc +++ b/runtime/onert/core/src/compiler/train/TrainingCompiler.cc @@ -166,6 +166,15 @@ std::shared_ptr TrainingCompiler::compile(void) dot_dumper.dump(*subg, nnfw::misc::str("after_loss_insertion-", subg_index.value())); } + for (auto &&[subg_index, subg] : trainable_subgraphs) + { + subg->updateGraphDependency(); + subg->verify(); + + dot_dumper.dump(*subg, + nnfw::misc::str("after_initializing_training_usedefs-", subg_index.value())); + } + // Change input shape according to batch_size for (auto &&pair : trainable_subgraphs) { diff --git a/runtime/onert/core/src/ir/train/TrainableGraph.cc b/runtime/onert/core/src/ir/train/TrainableGraph.cc index ce89dcf2284..9a3998d6c5c 100644 --- a/runtime/onert/core/src/ir/train/TrainableGraph.cc +++ b/runtime/onert/core/src/ir/train/TrainableGraph.cc @@ -17,6 +17,7 @@ #include "ir/train/TrainableGraph.h" #include "ir/OperandIndexMap.h" +#include "UseDefGenerator.h" #include "util/Utils.h" #include "util/Set.h" #include "../verifier/Verifier.h" @@ -26,6 +27,52 @@ #include #include +namespace +{ + +using namespace onert; +using namespace onert::ir; +using namespace onert::ir::train; + +void disableUnusedBackwardNodes(const UseDefChains &training_usedefs, TrainableGraph &tgraph) +{ + // Disable backward nodes 
that will be unused + const auto border = tgraph.btopolSortOperations(); + for (const auto &op_index : border) + { + const auto &node = tgraph.operations().at(op_index); + const auto &candidates = + (node.getInputs() + node.getOutputs()) | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + const bool is_backward_op_used = + std::any_of(candidates.begin(), candidates.end(), [&](const OperandIndex &operand) { + const auto training_op_index = TrainingOperationIndex{op_index, false}; + const auto forwarding_index = TrainingOperandIndex{operand, true}; + const auto &forwarding_uses = training_usedefs.at(forwarding_index).getTrainingUses(); + const auto backwarding_index = TrainingOperandIndex{operand, false}; + const auto &backwarding_uses = training_usedefs.at(backwarding_index).getTrainingUses(); + return forwarding_uses.find(training_op_index) != forwarding_uses.end() || + backwarding_uses.find(training_op_index) != backwarding_uses.end(); + }); + + // NOTE Backward op does not define any incoming operand in backwarding + const auto &inputs = node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED; + const bool is_backward_op_def = + std::any_of(inputs.begin(), inputs.end(), [&](const OperandIndex &input) { + const auto training_op_index = TrainingOperationIndex{op_index, false}; + const auto outcoming_index = TrainingOperandIndex{input, false}; + const auto &backwarding_defs = training_usedefs.at(outcoming_index).getTrainingUses(); + return backwarding_defs.find(training_op_index) != backwarding_defs.end(); + }); + + if (is_backward_op_used || is_backward_op_def) + tgraph.enableBackward(op_index); + else + tgraph.disableBackward(op_index); + } +} + +} // namespace + namespace onert { namespace ir @@ -332,6 +379,18 @@ OperandIndex TrainableGraph::getLossIndex(const IOIndex &pred_ioind) const return (itr == _losses.end()) ? 
OperandIndex{} : itr->second; } +void TrainableGraph::updateGraphDependency() +{ + _graph.verify(); + + // Initialize training usedefs + setTrainingUseDefs(UseDefGenerator{*this}()); + + disableUnusedBackwardNodes(_training_defuses, *this); + + verifyTrainingUseDefs(); +} + } // namespace train } // namespace ir } // namespace onert