[onert] Introduce backwarding graph order for training (#12143)
This commit introduces a backwarding graph order for training.
It starts from the Loss op and finds the backwarding order by walking each operation's inputs back to their producers.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
jyoungyun authored Dec 1, 2023
1 parent e8e060c commit 62c5b51
Showing 5 changed files with 66 additions and 19 deletions.
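
The backwarding order introduced here is a depth-first traversal that starts at the Loss op and follows each operation's inputs back to their producing operations, emitting each operation the first time it is reached (see TrainableGraph::btopolSortOperations in the last file below). Here is a minimal standalone sketch of that traversal on a toy four-op chain; none of these names are onert code:

// backward_order_sketch.cc -- standalone illustration, not onert code.
// Starts at "Loss" and DFS-walks toward producers, emitting each op the
// first time it is reached, mirroring btopolSortOperations() below.
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

int main()
{
  // op -> producers of its inputs; a toy Conv -> ReLU -> FC -> Loss chain
  std::map<std::string, std::vector<std::string>> producers = {
    {"Loss", {"FC"}}, {"FC", {"ReLU"}}, {"ReLU", {"Conv"}}, {"Conv", {}}};

  std::vector<std::string> order;
  std::set<std::string> visited;
  std::function<void(const std::string &)> dfs = [&](const std::string &op) {
    if (!visited.insert(op).second)
      return;            // already emitted
    order.push_back(op); // emit before recursing: consumer precedes producers
    for (const auto &p : producers.at(op))
      dfs(p);
  };
  dfs("Loss"); // backwarding starts at the single Loss op

  for (const auto &op : order)
    std::cout << op << '\n'; // prints: Loss FC ReLU Conv
}

On a simple chain like this, the result is just the reverse of the forward topological order.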
2 changes: 1 addition & 1 deletion runtime/onert/core/include/ir/train/TrainableGraph.h
@@ -129,7 +129,7 @@ class TrainableGraph : public IGraph
 
 public:
   std::vector<ir::OperationIndex> topolSortOperations() const;
-  // TODO Support topological sort for backwarding
+  std::vector<ir::OperationIndex> btopolSortOperations() const;
 
 private:
   Graph _graph;
9 changes: 8 additions & 1 deletion runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -749,10 +749,16 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     (lowered_graph->graph().getInputs() + lowered_graph->graph().getOutputs()) |
     ir::Remove::DUPLICATED | ir::Remove::UNDEFINED);
 
-  // linearize
+  // linearize for forwarding
   auto order = Linear::linearize(*lowered_graph);
+  VERBOSE(ExecutorFactory) << "Linearize for forwarding order" << std::endl;
   Linear::dump(*lowered_graph, order);
 
+  // linearize for backwarding
+  auto backward_order = lowered_graph->trainable_graph().btopolSortOperations();
+  VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
+  Linear::dump(*lowered_graph, backward_order);
+
   for (auto &&pair : tbackend_contexts)
   {
     pair.second->genTensors();
@@ -885,6 +891,7 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     tensor_regs,
     std::move(code_map),
     order,
+    backward_order,
     tracing_ctx,
     training_info.lossInfo()};
 
30 changes: 15 additions & 15 deletions runtime/onert/core/src/exec/train/TrainableExecutor.cc
@@ -32,9 +32,13 @@ TrainableExecutor::TrainableExecutor(
   std::unique_ptr<compiler::train::LoweredTrainableGraph> lowered_graph,
   backend::train::TrainableBackendContexts &&backend_contexts,
   const compiler::train::TensorRegistries &tensor_regs,
-  compiler::train::TrainableCodeMap &&code_map, const std::vector<ir::OperationIndex> &order,
-  const util::TracingCtx *tracing_ctx, const ir::train::LossInfo &loss_info)
-  : _lowered_graph{std::move(lowered_graph)}, _backend_contexts{std::move(backend_contexts)},
+  compiler::train::TrainableCodeMap &&code_map,
+  const std::vector<ir::OperationIndex> &forward_order,
+  const std::vector<ir::OperationIndex> &backward_order, const util::TracingCtx *tracing_ctx,
+  const ir::train::LossInfo &loss_info)
+  : _code_map{std::move(code_map)}, _forward_order{std::move(forward_order)},
+    _backward_order{std::move(backward_order)}, _lowered_graph{std::move(lowered_graph)},
+    _backend_contexts{std::move(backend_contexts)},
     _trainable_graph{_lowered_graph->trainable_graph()}, _tensor_regs{std::move(tensor_regs)},
     _mutex(), _tracing_ctx(tracing_ctx), _loss_info(loss_info)
 {
@@ -50,12 +54,6 @@ TrainableExecutor::TrainableExecutor(
   };
   build_tensor_list(_trainable_graph.getInputs(), _input_tensors);
   build_tensor_list(_trainable_graph.getOutputs(), _output_tensors);
-
-  for (auto &&index : order)
-  {
-    auto &trainable_code = code_map.at(index);
-    _code.emplace_back(std::move(trainable_code));
-  }
 }
 
 void TrainableExecutor::execute(const std::vector<backend::IPortableTensor *> &,
@@ -110,8 +108,9 @@ void TrainableExecutor::forwardImpl(bool training)
     auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
 
     _subject.notifySubgraphBegin(profiling_subg_index);
-    for (auto &&code : _code)
+    for (auto &&index : _forward_order)
     {
+      const auto &code = _code_map.at(index);
       const auto backend = code.lower_info->backend();
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
@@ -128,8 +127,9 @@
   }
   else
   {
-    for (auto &&code : _code)
+    for (auto &&index : _forward_order)
     {
+      const auto &code = _code_map.at(index);
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
       ruy::profiler::ScopeLabel label(code.op->name());
@@ -157,9 +157,9 @@ void TrainableExecutor::backwardImpl(uint32_t training_step)
     auto profiling_subg_index = _tracing_ctx->getSubgraphIndex(&_trainable_graph.graph());
 
     _subject.notifySubgraphBegin(profiling_subg_index);
-    for (auto it = _code.rbegin(); it != _code.rend(); ++it)
+    for (auto &&index : _backward_order)
     {
-      const auto &code = *it;
+      const auto &code = _code_map.at(index);
       const auto backend = code.lower_info->backend();
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
@@ -176,9 +176,9 @@
   }
   else
   {
-    for (auto it = _code.rbegin(); it != _code.rend(); ++it)
+    for (auto &&index : _backward_order)
    {
-      const auto &code = *it;
+      const auto &code = _code_map.at(index);
       // TODO : Move ruy profiler into ExecutionObserver
 #ifdef RUY_PROFILER
       ruy::profiler::ScopeLabel label(code.op->name());
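
The executor change above replaces the single _code vector (walked forward, and via rbegin() for the backward pass) with a code map that is indexed through _forward_order or _backward_order. Here is a toy sketch of that lookup pattern, using invented stand-in types rather than onert's actual TrainableCodeMap:

// executor_order_sketch.cc -- toy illustration of the lookup pattern; all
// types, names, and indices are invented stand-ins, not onert code.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main()
{
  // stands in for TrainableCodeMap: operation index -> code to run
  std::unordered_map<int, std::string> code_map = {
    {0, "Conv"}, {1, "ReLU"}, {2, "FC"}, {3, "Loss"}};
  // explicit orders replace iterating one vector forward and in reverse
  std::vector<int> forward_order = {0, 1, 2, 3};
  std::vector<int> backward_order = {3, 2, 1, 0};

  for (int index : forward_order)
    std::cout << "forward:  " << code_map.at(index) << '\n';
  for (int index : backward_order)
    std::cout << "backward: " << code_map.at(index) << '\n';
}

This lets the forward and backward passes use independently computed orders instead of assuming one is the exact reverse of the other.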
7 changes: 5 additions & 2 deletions runtime/onert/core/src/exec/train/TrainableExecutor.h
@@ -49,7 +49,8 @@ class TrainableExecutor : public IExecutor
     backend::train::TrainableBackendContexts &&backend_contexts,
     const compiler::train::TensorRegistries &tensor_regs,
     compiler::train::TrainableCodeMap &&code_map,
-    const std::vector<ir::OperationIndex> &order,
+    const std::vector<ir::OperationIndex> &forward_order,
+    const std::vector<ir::OperationIndex> &backward_order,
     const util::TracingCtx *tracing_ctx, const ir::train::LossInfo &training_info);
 
 public:
@@ -90,7 +91,9 @@ class TrainableExecutor : public IExecutor
   void backwardImpl(uint32_t training_step);
 
 private:
-  std::vector<compiler::train::TrainableCodeAndInfo> _code;
+  compiler::train::TrainableCodeMap _code_map;
+  std::vector<ir::OperationIndex> _forward_order;
+  std::vector<ir::OperationIndex> _backward_order;
   ExecutionObservee _subject;
   std::shared_ptr<ir::OperationIndexMap<int64_t>> _indexed_ranks;
   std::unique_ptr<compiler::train::LoweredTrainableGraph> _lowered_graph;
37 changes: 37 additions & 0 deletions runtime/onert/core/src/ir/train/TrainableGraph.cc
@@ -16,6 +16,7 @@
 
 #include "ir/train/TrainableGraph.h"
 #include "util/Utils.h"
+#include "util/Set.h"
 
 #include <algorithm>
 #include <misc/polymorphic_downcast.h>
@@ -128,6 +129,42 @@ std::vector<ir::OperationIndex> TrainableGraph::topolSortOperations() const
   return _graph.topolSortOperations();
 }
 
+std::vector<ir::OperationIndex> TrainableGraph::btopolSortOperations() const
+{
+  std::vector<ir::OperationIndex> ret;
+  util::Set<ir::OperationIndex> unvisited;
+  ir::OperationIndex loss_idx;
+  operations().iterate([&](const ir::OperationIndex &index, const ir::IOperation &op) {
+    unvisited.add(index);
+    if (op.opcode() == ir::OpCode::Loss)
+    {
+      assert(!loss_idx.valid()); // Should be only one loss
+      loss_idx = index;
+    }
+  });
+
+  std::function<void(const ir::OperationIndex &, const ir::IOperation &)> dfs =
+    [&](const ir::OperationIndex &index, const ir::IOperation &op) -> void {
+    if (!unvisited.contains(index))
+      return;
+    unvisited.remove(index);
+    ret.push_back(index);
+
+    for (const auto &input : op.getInputs() | ir::Remove::DUPLICATED | ir::Remove::UNDEFINED)
+    {
+      const auto &operand = operands().at(input);
+      const auto &def = operand.getDef();
+      if (!def.valid())
+        return;
+      dfs(def, operations().at(def));
+    }
+  };
+
+  dfs(loss_idx, operations().at(loss_idx));
+
+  return ret;
+}
+
 void TrainableGraph::addLoss(const OperandIndex &loss_ind, const IOIndex &pred_ioind)
 {
   _losses.emplace(pred_ioind, loss_ind);
Expand Down
