Skip to content

Commit

Permalink
register etensors
Browse files Browse the repository at this point in the history
  • Loading branch information
zetwhite committed Jul 29, 2024
1 parent dde2bbf commit 5d258f4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
11 changes: 11 additions & 0 deletions runtime/onert/backend/train/BackendContext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,17 @@ FunctionMap BackendContext::genKernels()
// fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
// }

// register estra tensors
// from c++ 17, structured binding is supported
for (auto &[operation_index, fn_seq] : ret)
{
fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
extra_tensor_gen->register_tensors(operation_index, fn.requestExtraTensors());
});
}

// TODO : plan extra tensors

return ret;
}

Expand Down
3 changes: 2 additions & 1 deletion runtime/onert/backend/train/ExtraTensorGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph &tgra
std::shared_ptr<TensorRegistry> &tensor_registry)
: _tgraph(tgraph), _tensor_builder(tensor_builder), _tensor_reg(tensor_registry){};

void ExtraTensorGenerator::generate(ir::OperationIndex op_idx, const ExtraTensorRequests &reqs)
void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx,
const ExtraTensorRequests &reqs)
{
// save request, _idx_to_reuqests used for memory planning
_idx_to_requests[op_idx] = reqs;
Expand Down
2 changes: 1 addition & 1 deletion runtime/onert/backend/train/ExtraTensorGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ExtraTensorGenerator
std::shared_ptr<TensorRegistry> &tensor_registry);

public:
void generate(ir::OperationIndex idx, const ExtraTensorRequests &requests);
void register_tensors(ir::OperationIndex idx, const ExtraTensorRequests &requests);

private:
const ir::train::TrainableGraph &_tgraph;
Expand Down

0 comments on commit 5d258f4

Please sign in to comment.