diff --git a/runtime/onert/backend/train/BackendContext.cc b/runtime/onert/backend/train/BackendContext.cc
index f5bf9999671..6b5fab34157 100644
--- a/runtime/onert/backend/train/BackendContext.cc
+++ b/runtime/onert/backend/train/BackendContext.cc
@@ -136,22 +136,50 @@ getDisposableBackPropTensorList(const ir::train::TrainableGraph &tgraph,
 }
 } // namespace
 
-backend::ITensorRegistry *BackendContext::genTensors()
+FunctionMap BackendContext::gen()
 {
   planForwardTensors();
+  planBackwardTensors();
 
   _tensor_builder->allocate();
+  _tensor_builder->allocateBackward();
 
-  return _tensor_registry.get();
-}
+  auto codes = generateFunctionMap();
 
-backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
-{
-  planBackwardTensors();
+  // Initialize TrainableTensors
+  trainable_graph()->operands().iterate(
+    [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
+      if (external_operands().contains(ind) || !operand.isConstant())
+        return;
 
-  _tensor_builder->allocateBackward();
+      auto tensor = tensor_registry()->getNativeITensor(ind);
+      assert(tensor != nullptr);
+
+      VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
+
+      auto data = operand.shareData();
+      assert(data && data->base());
+      auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
 
-  return _tensor_registry.get();
+      if (trainable_tensor == nullptr)
+        throw std::runtime_error{"This tensor is not trainable tensor"};
+
+      trainable_tensor->fillBuffer(data);
+    });
+
+  // NOTE For memory optimization, we want to free some operand data
+  const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
+    .operands()
+    .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
+
+  // TODO Enable
+  // for (auto &&it : ret)
+  // {
+  //   auto &fn_seq = it.second;
+  //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
+  // }
+
+  return codes;
 }
 
 void BackendContext::planForwardTensors()
@@ -209,46 +237,6 @@ void BackendContext::planBackwardTensors()
   tensor_planner.planDisposableBackPropTensors(tensor_builder.get());
 }
 
-FunctionMap BackendContext::genKernels()
-{
-  auto ret = generateFunctionMap();
-
-  // Initialize TrainableTensors
-  trainable_graph()->operands().iterate(
-    [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
-      if (external_operands().contains(ind) || !operand.isConstant())
-        return;
-
-      auto tensor = tensor_registry()->getNativeITensor(ind);
-      assert(tensor != nullptr);
-
-      VERBOSE(FillOperandData) << "Fill data for " << ind << std::endl;
-
-      auto data = operand.shareData();
-      assert(data && data->base());
-      auto trainable_tensor = dynamic_cast<TrainableTensor *>(tensor);
-
-      if (trainable_tensor == nullptr)
-        throw std::runtime_error{"This tensor is not trainable tensor"};
-
-      trainable_tensor->fillBuffer(data);
-    });
-
-  // NOTE For memory optimization, we want to free some operand data
-  const_cast<ir::train::TrainableGraph &>(*_tdata->tgraph)
-    .operands()
-    .iterate([&](const ir::OperandIndex &, ir::Operand &obj) { obj.releaseData(); });
-
-  // TODO Enable
-  // for (auto &&it : ret)
-  // {
-  //   auto &fn_seq = it.second;
-  //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
-  // }
-
-  return ret;
-}
-
 FunctionMap BackendContext::generateFunctionMap()
 {
   train::FunctionMap ret;
diff --git a/runtime/onert/backend/train/BackendContext.h b/runtime/onert/backend/train/BackendContext.h
index 69d17d352c4..8e343aee403 100644
--- a/runtime/onert/backend/train/BackendContext.h
+++ b/runtime/onert/backend/train/BackendContext.h
@@ -68,16 +68,13 @@ class BackendContext : public onert::backend::train::TrainableBackendContext
   BackendContext &operator=(const BackendContext &) = delete;
 
 public:
-  backend::ITensorRegistry *genTensors() override;
-  backend::train::ITensorRegistry *genTrainingTensors() override;
+  FunctionMap gen() override;
 
 private:
   void planForwardTensors();
   void planBackwardTensors();
 
 public:
-  FunctionMap genKernels() override;
-
   std::shared_ptr<ExternalContext> external_context() { return _external_context; }
 
   const exec::train::optimizer::Optimizer *optimizer() const { return _optimizer.get(); }
diff --git a/runtime/onert/core/include/backend/train/TrainableBackendContext.h b/runtime/onert/core/include/backend/train/TrainableBackendContext.h
index b3a9cdd7d52..c2edf0deb79 100644
--- a/runtime/onert/core/include/backend/train/TrainableBackendContext.h
+++ b/runtime/onert/core/include/backend/train/TrainableBackendContext.h
@@ -76,9 +76,7 @@ class TrainableBackendContext
 
   std::shared_ptr<ITensorRegistry> tensor_registry() { return _tensor_registry; }
 
-  virtual ITensorRegistry *genTrainingTensors() = 0;
-  virtual backend::ITensorRegistry *genTensors() = 0;
-  virtual FunctionMap genKernels() = 0;
+  virtual FunctionMap gen() = 0;
 
 private:
   const ITrainableBackend *_backend{nullptr};
diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
index 69483eade12..ec3ce0f882f 100644
--- a/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
+++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.cc
@@ -28,29 +28,19 @@ namespace builtin
 namespace train
 {
 
-backend::ITensorRegistry *BackendContext::genTensors()
+backend::train::FunctionMap BackendContext::gen()
 {
-  // For now, there is no need to generate tensors for forwarding.
+  // For now, there is no need to generate tensors for forwarding and backwarding.
   // builtin train backend handles 3 operators: `Permute`, `IF`, `WHILE`.
   // `Permute`: Tensor generation is not required.
   // `IF`, `WHILE`: Not supported yet
-  return tensor_registry().get();
-}
 
-backend::train::ITensorRegistry *BackendContext::genTrainingTensors()
-{
-  // For now, there is no need to generate tensors for backwarding.
-  return tensor_registry().get();
-}
-
-backend::train::FunctionMap BackendContext::genKernels()
-{
-  backend::train::FunctionMap ret;
+  backend::train::FunctionMap codes;
 
   for (auto &&op_ind : _tdata->op_order)
   {
     auto tn_seq = kernel_gen->generate(op_ind);
-    ret.emplace(op_ind, std::move(tn_seq));
+    codes.emplace(op_ind, std::move(tn_seq));
   }
 
   trainable_graph()->operands().iterate(
@@ -69,7 +59,7 @@ backend::train::FunctionMap BackendContext::genKernels()
   //   fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
   // }
 
-  return ret;
+  return codes;
 }
 
 } // namespace train
diff --git a/runtime/onert/core/src/backend/builtin/train/BackendContext.h b/runtime/onert/core/src/backend/builtin/train/BackendContext.h
index 4782756c31c..eb88a4ddc8c 100644
--- a/runtime/onert/core/src/backend/builtin/train/BackendContext.h
+++ b/runtime/onert/core/src/backend/builtin/train/BackendContext.h
@@ -46,11 +46,7 @@ class BackendContext : public backend::train::TrainableBackendContext
   {
   }
 
-  backend::ITensorRegistry *genTensors() override;
-  backend::train::ITensorRegistry *genTrainingTensors() override;
-
-public:
-  backend::train::FunctionMap genKernels() override;
+  backend::train::FunctionMap gen() override;
 
   std::shared_ptr<ExternalContext> external_context() { return _external_context; }
 
diff --git a/runtime/onert/core/src/compiler/ExecutorFactory.cc b/runtime/onert/core/src/compiler/ExecutorFactory.cc
index 3cbe5f670eb..4b7fa687c3d 100644
--- a/runtime/onert/core/src/compiler/ExecutorFactory.cc
+++ b/runtime/onert/core/src/compiler/ExecutorFactory.cc
@@ -274,6 +274,21 @@ std::deque<std::pair<const backend::Backend *, backend::train::TrainableBackendContext *>> orderBackendContext(
   return ordered_contexts;
 }
 
+void extractCodes(backend::train::FunctionMap &codes,
+                  const compiler::train::LoweredTrainableGraph *lowered_graph,
+                  compiler::train::TrainableCodeMap &code_map)
+{
+  for (auto &&[op_ind, tn_seq] : codes)
+  {
+    auto &op = lowered_graph->trainable_graph().operation(op_ind);
+    const auto backend = lowered_graph->lower_info().operation.at(op_ind);
+
+    assert(code_map.find(op_ind) == code_map.end());
+    code_map.insert(
+      {op_ind, compiler::train::TrainableCodeAndInfo{op_ind, &op, backend, std::move(tn_seq)}});
+  }
+}
+
 } // namespace
 } // namespace onert
 
@@ -741,15 +756,16 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
   VERBOSE(ExecutorFactory) << "Linearize for backwarding order" << std::endl;
   Linear::dump(*lowered_graph, backward_order);
 
-  for (auto &&pair : tbackend_contexts)
+  train::TrainableCodeMap code_map;
+  // Generate tensors and kernels
+  for (auto &&[backend, context] : tbackend_contexts)
   {
-    pair.second->genTensors();
-  }
+    // builtin backend's kernel generator requires access to tensors in other backends.
+    if (backend->config()->id() == "builtin")
+      continue;
 
-  for (auto &&pair : tbackend_contexts)
-  {
-    auto tctx = pair.second.get();
-    tctx->genTrainingTensors();
+    auto codes = context->gen();
+    extractCodes(codes, lowered_graph.get(), code_map);
   }
 
   prepareMigrantTensors(*lowered_graph, tbackend_contexts);
@@ -767,6 +783,15 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     }
   }
 
+  for (auto &&[backend, context] : tbackend_contexts)
+  {
+    if (backend->config()->id() == "builtin")
+    {
+      auto codes = context->gen();
+      extractCodes(codes, lowered_graph.get(), code_map);
+    }
+  }
+
   // Adjust the order of backends for the upcoming iteration
   auto ordered_contexts = onert::orderBackendContext(tbackend_contexts);
 
@@ -845,22 +870,6 @@ exec::IExecutor *ExecutorFactory::createTrainableExecutor(
     }));
   }
 
-  train::TrainableCodeMap code_map;
-  // Generate kernels
-  for (auto &&pair : ordered_contexts)
-  {
-    auto codes = pair.second->genKernels();
-    for (auto &&[op_ind, tn_seq] : codes)
-    {
-      auto &op = lowered_graph->trainable_graph().operation(op_ind);
-      const auto backend = lowered_graph->lower_info().operation.at(op_ind);
-
-      assert(code_map.find(op_ind) == code_map.end());
-      code_map.insert(
-        {op_ind, train::TrainableCodeAndInfo{op_ind, &op, backend, std::move(tn_seq)}});
-    }
-  }
-
   if (order.size() != code_map.size())
   {
     throw std::runtime_error("ExecutorFactory: Some kernels are not generated");