Apply cpu::ops::PadLayer to PadLayer in train
YongseopKim committed Jan 29, 2024
1 parent 2a2d86d commit dab9c59
Showing 3 changed files with 6 additions and 57 deletions.
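
In short: the training backend's PadLayer now derives from the CPU backend's cpu::ops::PadLayer instead of duplicating its forward-pass code, so configure() and forward() delegate to the base class, and the base members become protected so the subclass (notably its depad() backward helper, which reads _padData and _padRank) can still reach the padding state. Below is a minimal standalone sketch of that pattern; it is a simplified analogue with made-up members (_pad_front, _pad_back) and console output rather than the actual onert classes, and it omits the ITrainableFunction interface the real training layer also implements.

#include <iostream>

namespace cpu_ops
{

// Stand-in for cpu::ops::PadLayer: owns the forward (inference) kernel and
// the padding metadata.
class PadLayer
{
public:
  PadLayer() : _pad_front(0), _pad_back(0) {}
  virtual ~PadLayer() = default;

  void configure(int pad_front, int pad_back)
  {
    _pad_front = pad_front;
    _pad_back = pad_back;
  }

  // Forward kernel; the real class dispatches to nnfw::cker::Pad here.
  void run() { std::cout << "pad forward with pads " << _pad_front << "/" << _pad_back << "\n"; }

protected: // protected rather than private, so a training subclass can reuse this state
  int _pad_front;
  int _pad_back;
};

} // namespace cpu_ops

namespace train_ops
{

// Stand-in for train::ops::PadLayer: adds a backward pass and delegates the
// forward pass to the inherited CPU kernel instead of duplicating it.
class PadLayer : public cpu_ops::PadLayer
{
public:
  void forward() { cpu_ops::PadLayer::run(); } // delegate, as in the new forward(bool)

  void backward()
  {
    // The real layer "depads" the incoming gradient; the padding metadata it
    // needs is exactly what the protected access in the base makes reachable.
    std::cout << "depad backward with pads " << _pad_front << "/" << _pad_back << "\n";
  }
};

} // namespace train_ops

int main()
{
  train_ops::PadLayer layer;
  layer.configure(1, 2); // configure() is inherited from the CPU layer
  layer.forward();
  layer.backward();
  return 0;
}
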
runtime/onert/backend/cpu/ops/PadLayer.h: 2 changes (1 addition, 1 deletion)
@@ -46,7 +46,7 @@ class PadLayer : public ::onert::exec::IFunction

void run() override;

-private:
+protected:
const IPortableTensor *_input;
IPortableTensor *_output;

runtime/onert/backend/train/ops/PadLayer.cc: 50 changes (3 additions, 47 deletions)
@@ -27,19 +27,11 @@ namespace train
namespace ops
{

-PadLayer::PadLayer()
-  : _input(nullptr), _output(nullptr), _padData(), _padRank(), _constantValueData(),
-    _back_prop_input{nullptr}, _back_prop_output{nullptr}
+PadLayer::PadLayer() : cpu::ops::PadLayer(), _back_prop_input{nullptr}, _back_prop_output{nullptr}
{
// DO NOTHING
}

-template <typename T> void PadLayer::padImpl(const T *constant_value_data)
-{
-  nnfw::cker::Pad<T>(_padData, _padRank, getShape(_input), getBuffer<T>(_input), getShape(_output),
-                     getBuffer<T>(_output), constant_value_data);
-}
-
template <typename T> void PadLayer::depad()
{
nnfw::cker::train::Depad<T>(_padData, _padRank, getShape(_back_prop_output),
@@ -51,48 +43,12 @@ void PadLayer::configure(const IPortableTensor *input, IPortableTensor *output,
const int32_t *padData, int32_t padRank, const void *constantValueData,
IPortableTensor *back_prop_input, const IPortableTensor *back_prop_output)
{
-  _input = input;
-  _output = output;
-  memcpy(_padData, padData, sizeof(_padData));
-  _padRank = padRank;
-  _constantValueData.v = constantValueData;
+  cpu::ops::PadLayer::configure(input, output, padData, padRank, constantValueData);
_back_prop_input = back_prop_input;
_back_prop_output = back_prop_output;
}

-void PadLayer::forward(bool)
-{
-  switch (_input->data_type())
-  {
-    case OperandType::FLOAT32:
-      padImpl<float>(_constantValueData.f);
-      break;
-    case OperandType::QUANT_UINT8_ASYMM:
-      if (_constantValueData.u8 == nullptr)
-      {
-        uint8_t pad_value = static_cast<uint8_t>(_output->data_zero_point());
-        padImpl<uint8_t>(&pad_value);
-      }
-      else
-      {
-        padImpl<uint8_t>(_constantValueData.u8);
-      }
-      break;
-    case OperandType::QUANT_INT8_ASYMM:
-      if (_constantValueData.i8 == nullptr)
-      {
-        int8_t pad_value = static_cast<int8_t>(_output->data_zero_point());
-        padImpl<int8_t>(&pad_value);
-      }
-      else
-      {
-        padImpl<int8_t>(_constantValueData.i8);
-      }
-      break;
-    default:
-      throw std::runtime_error{"Pad: unsupported data type"};
-  }
-}
+void PadLayer::forward(bool) { cpu::ops::PadLayer::run(); }

void PadLayer::backward()
{
runtime/onert/backend/train/ops/PadLayer.h: 11 changes (2 additions, 9 deletions)
@@ -17,6 +17,7 @@
#ifndef __ONERT_BACKEND_TRAIN_OPS_PADLAYER_H__
#define __ONERT_BACKEND_TRAIN_OPS_PADLAYER_H__

+#include <ops/PadLayer.h>
#include <backend/IPortableTensor.h>
#include "OperationUtils.h"

@@ -33,13 +34,12 @@ namespace ops

// Note, this is pad with mode=`CONSTANT`: it doesn't support `REFLECT` and
// `SYMMETRIC`
-class PadLayer : public ::onert::exec::train::ITrainableFunction
+class PadLayer : public ::onert::exec::train::ITrainableFunction, public cpu::ops::PadLayer
{
public:
PadLayer();

public:
-  template <typename T> void padImpl(const T *constant_value_data);
template <typename T> void depad();

void configure(const IPortableTensor *input, IPortableTensor *output, const int32_t *padData,
@@ -49,13 +49,6 @@ class PadLayer : public ::onert::exec::train::ITrainableFunction
void backward() override;

private:
-  const IPortableTensor *_input;
-  IPortableTensor *_output;
-
-  int32_t _padData[8];
-  int32_t _padRank;
-  ConstDataPtr _constantValueData;
-
IPortableTensor *_back_prop_input;
const IPortableTensor *_back_prop_output;
};
