diff --git a/runtime/onert/backend/train/TensorPlanner.cc b/runtime/onert/backend/train/TensorPlanner.cc index acfc289d351..221904e98eb 100644 --- a/runtime/onert/backend/train/TensorPlanner.cc +++ b/runtime/onert/backend/train/TensorPlanner.cc @@ -456,9 +456,67 @@ void TensorPlanner::planGradientTensors(TensorBuilder *tensor_builder) VERBOSE(BackendContext) << "Finish planning gradient tensors" << std::endl; } -void TensorPlanner::planDisposableBackPropTensors(TensorBuilder *) +void TensorPlanner::planDisposableBackPropTensors(TensorBuilder *tensor_builder) { - // TODO Plan diposable backprop tensors + VERBOSE(BackendContext) << "Start planning disposable back-prop tensors" << std::endl; + + for (const auto &op_index : _tgraph.essentialBackwardOrder()) + { + // NOTE Even if there are duplicate indices, the duplicate back-propagated tensors may need + // to be updated respectively. So we use a sequence instead of a set. + const auto &inputs = _tgraph.operation(op_index).getInputs(); + if (!(inputs == (inputs | ir::Remove::DUPLICATED))) + throw std::runtime_error("TensorPlanner: DispoableBackProp tensor does not support duplicate " + "inputs of an operation"); + + std::vector cur_seq; + const auto back_prop_indices = getOutgoingBackPropSeq(op_index, tensor_builder); + for (const auto &back_prop_index : back_prop_indices) + { + DisposableTensorIndex cur_index{op_index, back_prop_index}; + if (tensor_builder->isRegisteredDisposableBackwardTensor(cur_index)) + { + tensor_builder->notifyDisposableBackPropFirstUse(cur_index); + cur_seq.emplace_back(cur_index); + } + } + + for (const auto &cur_index : cur_seq) + { + tensor_builder->notifyDisposableBackPropLastUse(cur_index); + } + } + + VERBOSE(BackendContext) << "Finish planning disposable back-prop tensors" << std::endl; +} + +ir::OperandIndexSequence TensorPlanner::getOutgoingBackPropSeq(const ir::OperationIndex &op_index, + const TensorBuilder *tensor_builder) +{ + ir::OperandIndexSequence ret; + + const auto &op = _tgraph.operation(op_index); + for (const auto &input : (op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)) + { + if (_external_operands.contains(input)) + continue; + if (!tensor_builder->isRegisteredBackward(input)) + continue; + + const auto input_index = ir::train::TrainingOperandIndex{input, false}; + const auto training_op_index = ir::train::TrainingOperationIndex{op_index, false}; + const auto &training_usedefs = _tgraph.trainingUseDefs(); + const auto &usedefs = training_usedefs.at(input_index); + if (usedefs.operand().isConstant()) + continue; + + if (usedefs.getTrainingDefs().find(training_op_index) == usedefs.getTrainingDefs().end()) + continue; + + ret.append(input); + } + + return ret; } } // namespace train diff --git a/runtime/onert/backend/train/TensorPlanner.h b/runtime/onert/backend/train/TensorPlanner.h index dd19f09c1d2..61af802fda9 100644 --- a/runtime/onert/backend/train/TensorPlanner.h +++ b/runtime/onert/backend/train/TensorPlanner.h @@ -46,6 +46,10 @@ class TensorPlanner void planGradientTensors(TensorBuilder *tensor_builder); void planDisposableBackPropTensors(TensorBuilder *tensor_builder); +private: + ir::OperandIndexSequence getOutgoingBackPropSeq(const ir::OperationIndex &op_index, + const TensorBuilder *tensor_builder); + private: const ir::train::TrainableGraph &_tgraph; const util::Set &_external_operands;