Skip to content

Commit

Permalink
[onert] Add generating training usedefs for Pool2D op (#13458)
Browse files Browse the repository at this point in the history
This commit adds generating training usedefs for Pool2D op.

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jul 23, 2024
1 parent 7b9450b commit 8a76a55
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
30 changes: 30 additions & 0 deletions runtime/onert/core/src/ir/train/UseDefGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,36 @@ void UseDefGenerator::visit(const train::operation::Pad &node)
insertBackPropDef(outgoing_index, backwarding_op_index);
}

void UseDefGenerator::visit(const train::operation::Pool2D &node)
{
if (node.param().op_type != ir::operation::Pool2D::PoolType::MAX)
{
throw std::runtime_error{"UseDefGenerator: Not yet supported pool type"};
}

assert(_node_to_idx.find(&node) != _node_to_idx.end());
const auto &op_index = _node_to_idx.at(&node);
const auto backwarding_op_index = TrainingOperationIndex{op_index, false};

// Insert uses of forwarding output
if (node.param().activation != ir::Activation::NONE)
{
const auto &out_index = node.getOutputs().at(0);
const auto out_forwarding_index = TrainingOperandIndex{out_index, true};
insertUse(out_forwarding_index, backwarding_op_index);
}

// Insert use of backwarding(backprop) output
const auto &out_index = node.getOutputs().at(0);
const auto incoming_index = TrainingOperandIndex{out_index, false};
insertUse(incoming_index, backwarding_op_index);

// Set def of backwarding(backprop) input
const auto &in_index = node.getInputs().at(train::operation::Pool2D::Input::INPUT);
const auto outgoing_index = TrainingOperandIndex{in_index, false};
insertBackPropDef(outgoing_index, backwarding_op_index);
}

void UseDefGenerator::visit(const train::operation::Reshape &node)
{
assert(_node_to_idx.find(&node) != _node_to_idx.end());
Expand Down
3 changes: 2 additions & 1 deletion runtime/onert/core/src/ir/train/UseDefGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,9 @@ class UseDefGenerator : public UseDefGeneratorBase
void visit(const train::operation::DepthwiseConv2D &node) override;
void visit(const train::operation::ElementwiseActivation &node) override;
void visit(const train::operation::Loss &node) override;
void visit(const train::operation::Reshape &node) override;
void visit(const train::operation::Pad &node) override;
void visit(const train::operation::Pool2D &node) override;
void visit(const train::operation::Reshape &node) override;

private:
void insertUse(const TrainingOperandIndex &operand_index, const TrainingOperationIndex &op_index);
Expand Down

0 comments on commit 8a76a55

Please sign in to comment.