Skip to content

Commit

Permalink
[onert] Add generating training usedefs for ElementwiseActivation op (#…
Browse files Browse the repository at this point in the history
…13455)

This commit adds generating training usedefs for ElementwiseActivation operation.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 23, 2024
1 parent 89a1b73 commit 7b9450b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
23 changes: 23 additions & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,29 @@ void UseDefGenerator::visit(const train::operation::DepthwiseConv2D &node)
}
}

void UseDefGenerator::visit(const train::operation::ElementwiseActivation &node)
{
if (node.param().op_type != operation::ElementwiseActivation::Type::RELU)
{
throw std::runtime_error{"UseDefGenerator: Not yet supported activation type"};
}
assert(_node_to_idx.find(&node) != _node_to_idx.end());
const auto &op_index = _node_to_idx.at(&node);
const auto backwarding_op_index = TrainingOperationIndex{op_index, false};

// Insert use of forwarding output
const auto &out_index = node.getOutputs().at(0);
const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
insertUse(out_forwarding_index, backwarding_op_index);

// Set def of backwarding(backprop) inputs
for (const auto &in_index : node.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
{
const auto outgoing_index = TrainingOperandIndex{in_index, false};
insertBackPropDef(outgoing_index, backwarding_op_index);
}
}

void UseDefGenerator::visit(const train::operation::Loss &node)
{
assert(_node_to_idx.find(&node) != _node_to_idx.end());
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class UseDefGenerator : public UseDefGeneratorBase
public:
void visit(const train::operation::Conv2D &node) override;
void visit(const train::operation::DepthwiseConv2D &node) override;
void visit(const train::operation::ElementwiseActivation &node) override;
void visit(const train::operation::Loss &node) override;
void visit(const train::operation::Reshape &node) override;
void visit(const train::operation::Pad &node) override;
Expand Down

0 comments on commit 7b9450b

Please sign in to comment.