Skip to content

Commit

Permalink
remove layout
Browse files Browse the repository at this point in the history
  • Loading branch information
zetwhite committed Jul 29, 2024
1 parent ed4a53d commit ca52249
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 17 deletions.
2 changes: 1 addition & 1 deletion runtime/onert/backend/train/ExtraTensorGenerator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void ExtraTensorGenerator::generate(ir::OperationIndex op_idx, const ExtraTensor
{
// register tensor
ExtraTensorIndex tensor_idx(op_idx, i);
_tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info, reqs[i].layout);
_tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info);

// return registered tensor
auto generated_tensor = _tensor_reg->getExtraTensor(tensor_idx);
Expand Down
5 changes: 2 additions & 3 deletions runtime/onert/backend/train/TensorBuilder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,11 @@ void TensorBuilder::registerDisposableBackwardTensorInfo(const DisposableTensorI
}

void TensorBuilder::registerExtraTensorInfo(const ExtraTensorIndex &index,
const ir::OperandInfo &info, ir::Layout layout)
const ir::OperandInfo &info)
{
assert(layout == ir::Layout::NHWC);
assert(!info.isDynamic());

auto extra_tensor = std::make_unique<ExtraTensor>(info, layout);
auto extra_tensor = std::make_unique<ExtraTensor>(info);
_tensor_reg->setExtraTensor(index, std::move(extra_tensor));
}

Expand Down
3 changes: 1 addition & 2 deletions runtime/onert/backend/train/TensorBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class TensorBuilder
void registerDisposableBackwardTensorInfo(const DisposableTensorIndex &index,
const ir::OperandInfo &info);

void registerExtraTensorInfo(const ExtraTensorIndex &index, const ir::OperandInfo &info,
ir::Layout layout);
void registerExtraTensorInfo(const ExtraTensorIndex &index, const ir::OperandInfo &info);

// TODO Support memory plan of all tensors
void notifyFirstUse(const ir::OperandIndex &);
Expand Down
8 changes: 4 additions & 4 deletions runtime/onert/backend/train/ops/ConvolutionLayer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ createTransposeTenosrRequest(const backend::IPortableTensor *origin,
backend::train::ExtraTensor **const addr)
{
return backend::train::ExtraTensorRequest(transposeOperandInfo(origin->get_info()),
origin->layout(),
backend::train::ExtraTensorLifeTime::BACKWARD, addr);
}

Expand Down Expand Up @@ -120,14 +119,15 @@ ExtraTensorRequests ConvolutionLayer::requestExtraTensors()
reqs.push_back(tr_weights);

auto conv_back_prop_output =
ExtraTensorRequest::createRequestLike(_back_prop_out, &_conv_back_prop_output);
ExtraTensorRequest::createRequestLike(_back_prop_output, &_conv_back_prop_output);
reqs.push_back(conv_back_prop_output);

auto tr_grad_weights = createTransposeTenosrRequest(weights, &_transposed_grad_weights);
auto tr_grad_weights = createTransposeTenosrRequest(_grad_weights, &_transposed_grad_weights);
reqs.push_back(tr_grad_weights);

if (_activation != ir::Activation::NONE)
reqs.push_back(ExtraTensorRequest::createRequestLike(back_prop_out, &_act_back_prop_output));
reqs.push_back(
ExtraTensorRequest::createRequestLike(_back_prop_output, &_act_back_prop_output));

return reqs;
}
Expand Down
11 changes: 4 additions & 7 deletions runtime/onert/core/include/backend/train/ExtraTensorRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ class ExtraTensor final : public basic::Tensor
ExtraTensor() = delete;

public:
ExtraTensor(const ir::OperandInfo &info, const ir::Layout layout)
: basic::Tensor(info, layout, nullptr)
ExtraTensor(const ir::OperandInfo &info) : basic::Tensor(info, nullptr)
{
// DO NOTHING
}
Expand All @@ -50,21 +49,19 @@ class ExtraTensorRequest
{

public:
ExtraTensorRequest(ir::OperandInfo info, ir::Layout layout, ExtraTensorLifeTime lt,
ExtraTensorRequest(ir::OperandInfo info, ExtraTensorLifeTime lt,
backend::train::ExtraTensor **addr)
: info(info), layout{layout}, lifetime(lt), address(addr)
: info(info), lifetime(lt), address(addr)
{
}

static ExtraTensorRequest createRequestLike(const IPortableTensor *origin,
backend::train::ExtraTensor **addr)
{
return ExtraTensorRequest(origin->get_info(), origin->layout(), ExtraTensorLifeTime::BACKWARD,
addr);
return ExtraTensorRequest(origin->get_info(), ExtraTensorLifeTime::BACKWARD, addr);
}

ir::OperandInfo info;
ir::Layout layout;
ExtraTensorLifeTime lifetime;
backend::train::ExtraTensor **address;

Expand Down

0 comments on commit ca52249

Please sign in to comment.