[onert] Plan disposable tensors on train backend
This commit adds planning of the disposable tensors used on the train backend.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani committed Jul 29, 2024
1 parent 19b0b2b commit 0ea72a8
Showing 2 changed files with 64 additions and 2 deletions.
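At its core, the new pass is classic liveness planning: operations are walked in backward order, each disposable back-prop tensor gets a first-use notification when it is produced and a last-use notification once its producing operation has been handled, so the memory planner can later reuse buffers whose lifetimes do not overlap. A minimal standalone sketch of that first-use/last-use pattern (all names below are hypothetical; this is not onert's actual planner API):

#include <iostream>
#include <map>
#include <string>

struct Lifetime
{
  int first_use = -1; // step at which the tensor is first needed
  int last_use = -1;  // step after which its buffer may be reused
};

class LifetimePlanner
{
public:
  void notifyFirstUse(const std::string &tensor, int step)
  {
    auto &lt = _lifetimes[tensor];
    if (lt.first_use < 0)
      lt.first_use = step; // keep the earliest use
  }
  void notifyLastUse(const std::string &tensor, int step) { _lifetimes[tensor].last_use = step; }

  // Two tensors may share a buffer only if their [first_use, last_use] ranges are disjoint.
  void dump() const
  {
    for (const auto &[name, lt] : _lifetimes)
      std::cout << name << ": [" << lt.first_use << ", " << lt.last_use << "]\n";
  }

private:
  std::map<std::string, Lifetime> _lifetimes;
};

int main()
{
  LifetimePlanner planner;
  // Visit operations in backward order, as planDisposableBackPropTensors
  // does via _tgraph.essentialBackwardOrder().
  planner.notifyFirstUse("dL/dFC_out", 0);
  planner.notifyLastUse("dL/dFC_out", 1);
  planner.notifyFirstUse("dL/dConv_out", 2);
  planner.notifyLastUse("dL/dConv_out", 3);
  planner.dump(); // [0, 1] and [2, 3] are disjoint, so one buffer could serve both tensors
  return 0;
}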
62 changes: 60 additions & 2 deletions runtime/onert/backend/train/TensorPlanner.cc
@@ -456,9 +456,67 @@ void TensorPlanner::planGradientTensors(TensorBuilder *tensor_builder)
   VERBOSE(BackendContext) << "Finish planning gradient tensors" << std::endl;
 }

-void TensorPlanner::planDisposableBackPropTensors(TensorBuilder *)
+void TensorPlanner::planDisposableBackPropTensors(TensorBuilder *tensor_builder)
 {
-  // TODO Plan diposable backprop tensors
+  VERBOSE(BackendContext) << "Start planning disposable back-prop tensors" << std::endl;
+
+  for (const auto &op_index : _tgraph.essentialBackwardOrder())
+  {
+    // NOTE Even if there are duplicate indices, the duplicate back-propagated tensors may need
+    //      to be updated respectively. So we use a sequence instead of a set.
+    const auto &inputs = _tgraph.operation(op_index).getInputs();
+    if (!(inputs == (inputs | ir::Remove::DUPLICATED)))
+      throw std::runtime_error("TensorPlanner: DisposableBackProp tensor does not support "
+                               "duplicate inputs of an operation");
+
+    std::vector<DisposableTensorIndex> cur_seq;
+    const auto back_prop_indices = getOutgoingBackPropSeq(op_index, tensor_builder);
+    for (const auto &back_prop_index : back_prop_indices)
+    {
+      DisposableTensorIndex cur_index{op_index, back_prop_index};
+      if (tensor_builder->isRegisteredDisposableBackwardTensor(cur_index))
+      {
+        tensor_builder->notifyDisposableBackPropFirstUse(cur_index);
+        cur_seq.emplace_back(cur_index);
+      }
+    }
+
+    for (const auto &cur_index : cur_seq)
+    {
+      tensor_builder->notifyDisposableBackPropLastUse(cur_index);
+    }
+  }
+
+  VERBOSE(BackendContext) << "Finish planning disposable back-prop tensors" << std::endl;
 }

+ir::OperandIndexSequence TensorPlanner::getOutgoingBackPropSeq(const ir::OperationIndex &op_index,
+                                                               const TensorBuilder *tensor_builder)
+{
+  ir::OperandIndexSequence ret;
+
+  const auto &op = _tgraph.operation(op_index);
+  for (const auto &input : (op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED))
+  {
+    if (_external_operands.contains(input))
+      continue;
+    if (!tensor_builder->isRegisteredBackward(input))
+      continue;
+
+    const auto input_index = ir::train::TrainingOperandIndex{input, false};
+    const auto training_op_index = ir::train::TrainingOperationIndex{op_index, false};
+    const auto &training_usedefs = _tgraph.trainingUseDefs();
+    const auto &usedefs = training_usedefs.at(input_index);
+    if (usedefs.operand().isConstant())
+      continue;
+
+    if (usedefs.getTrainingDefs().find(training_op_index) == usedefs.getTrainingDefs().end())
+      continue;
+
+    ret.append(input);
+  }
+
+  return ret;
+}
+
 } // namespace train
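One subtlety in planDisposableBackPropTensors above is the duplicate-input guard. A DisposableTensorIndex is keyed by an (operation, operand) pair, so an operation consuming the same operand twice would presumably fold two distinct back-prop updates onto a single index; such operations are therefore rejected up front. A sketch of the dedup-comparison idiom, with std:: containers standing in for ir::OperandIndexSequence and a hypothetical helper in place of the ir::Remove::DUPLICATED filter:

#include <algorithm>
#include <cassert>
#include <vector>

// Order-preserving deduplication, standing in for `inputs | ir::Remove::DUPLICATED`.
std::vector<int> removeDuplicated(const std::vector<int> &seq)
{
  std::vector<int> out;
  for (int v : seq)
    if (std::find(out.begin(), out.end(), v) == out.end())
      out.push_back(v);
  return out;
}

int main()
{
  std::vector<int> inputs{3, 5, 3}; // operand 3 feeds the operation twice
  // Mirrors the guard: if deduplication changes the sequence, the operation
  // has duplicate inputs and planning throws.
  const bool has_duplicates = !(inputs == removeDuplicated(inputs));
  assert(has_duplicates);
  return 0;
}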
4 changes: 4 additions & 0 deletions runtime/onert/backend/train/TensorPlanner.h
@@ -46,6 +46,10 @@ class TensorPlanner
   void planGradientTensors(TensorBuilder *tensor_builder);
   void planDisposableBackPropTensors(TensorBuilder *tensor_builder);

+private:
+  ir::OperandIndexSequence getOutgoingBackPropSeq(const ir::OperationIndex &op_index,
+                                                  const TensorBuilder *tensor_builder);
+
 private:
   const ir::train::TrainableGraph &_tgraph;
   const util::Set<ir::OperandIndex> &_external_operands;
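The helper declared here and defined above in TensorPlanner.cc narrows an operation's inputs to the operands whose back-prop tensors that operation actually defines: external operands, operands with no registered backward tensor, and constants are skipped, and the training use-defs are consulted to confirm that the operation is a def. In spirit (a standalone sketch with stand-in types, not onert's real ir:: interfaces):

#include <set>
#include <vector>

using OperandIndex = int;
using OperationIndex = int;

struct OperandInfo
{
  bool is_external = false;
  bool has_backward_tensor = false;
  bool is_constant = false;
  std::set<OperationIndex> backprop_defs; // operations defining this operand's back-prop tensor
};

// Keep only the inputs whose back-prop tensor `op` defines, mirroring the
// filter chain in TensorPlanner::getOutgoingBackPropSeq.
std::vector<OperandIndex> outgoingBackPropSeq(OperationIndex op,
                                              const std::vector<OperandIndex> &inputs,
                                              const std::vector<OperandInfo> &operands)
{
  std::vector<OperandIndex> ret;
  for (OperandIndex input : inputs)
  {
    const OperandInfo &info = operands[input];
    if (info.is_external || !info.has_backward_tensor || info.is_constant)
      continue; // nothing to plan for this operand
    if (info.backprop_defs.find(op) == info.backprop_defs.end())
      continue; // this operation is not a def of the operand's back-prop tensor
    ret.push_back(input);
  }
  return ret;
}

int main()
{
  std::vector<OperandInfo> operands(2);
  operands[0] = {false, true, false, {7}}; // back-prop tensor defined by operation 7
  operands[1] = {false, false, false, {}}; // no backward tensor registered
  const auto seq = outgoingBackPropSeq(7, {0, 1}, operands);
  return seq.size() == 1 ? 0 : 1; // only operand 0 survives the filters
}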
