From 7b9450b1a18fababc2ae348ab65e7c078ba6963e Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Tue, 23 Jul 2024 13:56:02 +0900 Subject: [PATCH] [onert] Add generating training usedefs for ElementwiseActivation op (#13455) This commit adds generating training usedefs for ElementwiseActivation operation. ONE-DCO-1.0-Signed-off-by: ragmani --- .../core/src/ir/train/UseDefGenerator.cc | 23 +++++++++++++++++++ .../onert/core/src/ir/train/UseDefGenerator.h | 1 + 2 files changed, 24 insertions(+) diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.cc b/runtime/onert/core/src/ir/train/UseDefGenerator.cc index b6d750c62af..3b1c65b8fe5 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.cc +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.cc @@ -147,6 +147,29 @@ void UseDefGenerator::visit(const train::operation::DepthwiseConv2D &node) } } +void UseDefGenerator::visit(const train::operation::ElementwiseActivation &node) +{ + if (node.param().op_type != operation::ElementwiseActivation::Type::RELU) + { + throw std::runtime_error{"UseDefGenerator: Not yet supported activation type"}; + } + assert(_node_to_idx.find(&node) != _node_to_idx.end()); + const auto &op_index = _node_to_idx.at(&node); + const auto backwarding_op_index = TrainingOperationIndex{op_index, false}; + + // Insert use of forwarding output + const auto &out_index = node.getOutputs().at(0); + const auto out_forwarding_index = TrainingOperandIndex{out_index, true}; + insertUse(out_forwarding_index, backwarding_op_index); + + // Set def of backwarding(backprop) inputs + for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED) + { + const auto outgoing_index = TrainingOperandIndex{in_index, false}; + insertBackPropDef(outgoing_index, backwarding_op_index); + } +} + void UseDefGenerator::visit(const train::operation::Loss &node) { assert(_node_to_idx.find(&node) != _node_to_idx.end()); diff --git a/runtime/onert/core/src/ir/train/UseDefGenerator.h b/runtime/onert/core/src/ir/train/UseDefGenerator.h index 2b7825f9b4b..02d19de8156 100644 --- a/runtime/onert/core/src/ir/train/UseDefGenerator.h +++ b/runtime/onert/core/src/ir/train/UseDefGenerator.h @@ -66,6 +66,7 @@ class UseDefGenerator : public UseDefGeneratorBase public: void visit(const train::operation::Conv2D &node) override; void visit(const train::operation::DepthwiseConv2D &node) override; + void visit(const train::operation::ElementwiseActivation &node) override; void visit(const train::operation::Loss &node) override; void visit(const train::operation::Reshape &node) override; void visit(const train::operation::Pad &node) override;