From 5d258f42da0c96b6319e82a28d40607239450183 Mon Sep 17 00:00:00 2001 From: zetwhite Date: Tue, 30 Jul 2024 01:30:36 +0900 Subject: [PATCH] register etensors --- runtime/onert/backend/train/BackendContext.cc | 11 +++++++++++ runtime/onert/backend/train/ExtraTensorGenerator.cc | 3 ++- runtime/onert/backend/train/ExtraTensorGenerator.h | 2 +- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc index 8b45a624f4c..e207e6b6f6f 100644 --- a/runtime/onert/backend/train/BackendContext.cc +++ b/runtime/onert/backend/train/BackendContext.cc @@ -231,6 +231,17 @@ FunctionMap BackendContext::genKernels() // fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); }); // } + // register estra tensors + // from c++ 17, structured binding is supported + for (auto &[operation_index, fn_seq] : ret) + { + fn_seq->iterate([&](exec::train::ITrainableFunction &fn) { + extra_tensor_gen->register_tensors(operation_index, fn.requestExtraTensors()); + }); + } + + // TODO : plan extra tensors + return ret; } diff --git a/runtime/onert/backend/train/ExtraTensorGenerator.cc b/runtime/onert/backend/train/ExtraTensorGenerator.cc index 5f6124e1afb..7d23bc892fd 100644 --- a/runtime/onert/backend/train/ExtraTensorGenerator.cc +++ b/runtime/onert/backend/train/ExtraTensorGenerator.cc @@ -30,7 +30,8 @@ ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph &tgra std::shared_ptr &tensor_registry) : _tgraph(tgraph), _tensor_builder(tensor_builder), _tensor_reg(tensor_registry){}; -void ExtraTensorGenerator::generate(ir::OperationIndex op_idx, const ExtraTensorRequests &reqs) +void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx, + const ExtraTensorRequests &reqs) { // save request, _idx_to_reuqests used for memory planning _idx_to_requests[op_idx] = reqs; diff --git a/runtime/onert/backend/train/ExtraTensorGenerator.h b/runtime/onert/backend/train/ExtraTensorGenerator.h index 506c9f1b827..f403a7f9a53 100644 --- a/runtime/onert/backend/train/ExtraTensorGenerator.h +++ b/runtime/onert/backend/train/ExtraTensorGenerator.h @@ -39,7 +39,7 @@ class ExtraTensorGenerator std::shared_ptr &tensor_registry); public: - void generate(ir::OperationIndex idx, const ExtraTensorRequests &requests); + void register_tensors(ir::OperationIndex idx, const ExtraTensorRequests &requests); private: const ir::train::TrainableGraph &_tgraph;