
Commit

Add infra to generate extra tensor
zetwhite committed Aug 7, 2024
1 parent 2a2dbc6 commit dc9b748
Showing 29 changed files with 834 additions and 188 deletions.
25 changes: 25 additions & 0 deletions runtime/onert/backend/train/BackendContext.cc
@@ -16,11 +16,13 @@

#include "BackendContext.h"

#include "ExtraTensorGenerator.h"
#include "TensorBuilder.h"
#include "KernelGenerator.h"
#include "ops/BackPropInitializer.h"

#include <backend/basic/train/TrainableBackendContextHelpers.h>
#include <ir/train/ITrainableOperation.h>
#include <misc/polymorphic_downcast.h>

#include <cassert>
@@ -229,6 +231,29 @@ FunctionMap BackendContext::genKernels()
// fn_seq->iterate([&](exec::IFunction &ifunc) { ifunc.prepare(); });
// }

ExtraTensorGenerator extra_tensor_gen(trainable_graph(), _tensor_builder, _tensor_registry);

const auto &ops = trainable_graph()->operations();

for (auto &pair : ret)
{
auto &op_idx = pair.first;
auto &fn_seq = pair.second;

const ir::IOperation *op = &ops.at(op_idx);
const auto trainable_op = dynamic_cast<const ir::train::TrainableOperation *>(op);
assert(trainable_op != nullptr);

if (not trainable_op->isRequiredForBackward())
continue;

fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
extra_tensor_gen.register_tensors(op_idx, fn.requestExtraTensors());
});
}
extra_tensor_gen.plan();
extra_tensor_gen.allocate();

return ret;
}
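The loop above asks every backward-enabled kernel for its extra-tensor requests before planning and allocation. For context, a minimal kernel-side sketch follows; the layer name and member types are hypothetical, and the request fields (info, lifetime, address) are inferred from how register_tensors() and plan() consume them.

// Hypothetical kernel-side sketch (not this commit's code): a layer asks for
// one scratch tensor that must survive from forward until its backward step.
ExtraTensorRequests SomeLayer::requestExtraTensors()
{
  ExtraTensorRequests reqs;

  ExtraTensorRequest req;
  req.info = _scratch_info;                                 // shape/dtype to allocate (assumed member)
  req.lifetime = ExtraTensorLifeTime::FORWARD_TO_BACKWARD;  // planned across both passes
  req.address = &_scratch;                                  // generator writes the tensor back here
  reqs.push_back(std::move(req));

  return reqs;
}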

109 changes: 109 additions & 0 deletions runtime/onert/backend/train/ExtraTensorGenerator.cc
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ExtraTensorGenerator.h"

#include "ExtraTensorIndex.h"

#include <ir/Operations.h>
#include <util/logging.h>
#include <cassert>
#include <memory>

namespace onert
{
namespace backend
{
namespace train
{

ExtraTensorGenerator::ExtraTensorGenerator(const ir::train::TrainableGraph *tgraph,
std::shared_ptr<TensorBuilder> &tensor_builder,
std::shared_ptr<ITensorRegistry> &tensor_registry)
: _tgraph(tgraph), _tensor_builder(tensor_builder)
{
_tensor_reg = std::dynamic_pointer_cast<TensorRegistry>(tensor_registry);
assert(_tensor_reg != nullptr); // the registry passed in must be the train backend's TensorRegistry
}

void ExtraTensorGenerator::register_tensors(ir::OperationIndex op_idx, ExtraTensorRequests &&reqs)
{
// Save the requests; _idx_to_requests is used later for memory planning
if (reqs.empty())
return;

_idx_to_requests.insert({op_idx, reqs});
auto &operations = _tgraph->operations();

for (size_t i = 0; i < reqs.size(); i++)
{
// register tensor
ExtraTensorIndex tensor_idx(op_idx, i);
_tensor_builder->registerExtraTensorInfo(tensor_idx, reqs[i].info);

std::stringstream op_info;
op_info << op_idx << "_" << operations.at(op_idx).name();
VERBOSE(ExtraTensorGenerator) << "register (idx:" << tensor_idx << ") requested from "
<< op_info.str() << std::endl;

// hand the generated tensor back to the requester through the request's address
auto generated_tensor = _tensor_reg->getExtraTensor(tensor_idx);
*reqs[i].address = generated_tensor;
}
}

void ExtraTensorGenerator::plan()
{
// Forward (topological) order: a FORWARD_TO_BACKWARD tensor comes alive when its op runs forward
const auto f_order = _tgraph->topolSortOperations();
for (const auto &op_index : f_order)
{
auto &reqs = _idx_to_requests[op_index];
for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD)
_tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i));
}
}

// Backward order: BACKWARD tensors come alive here, and both kinds die once the op's backward step ends
const auto b_order = _tgraph->essentialBackwardOrder();
for (const auto &op_index : b_order)
{
auto &reqs = _idx_to_requests[op_index];

for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
if (lt == ExtraTensorLifeTime::BACKWARD)
_tensor_builder->notifyFirstUse(ExtraTensorIndex(op_index, i));
}

for (auto i = 0u; i < reqs.size(); ++i)
{
auto &lt = reqs[i].lifetime;
if (lt == ExtraTensorLifeTime::FORWARD_TO_BACKWARD || lt == ExtraTensorLifeTime::BACKWARD)
_tensor_builder->notifyLastUse(ExtraTensorIndex(op_index, i));
}
}
}

void ExtraTensorGenerator::allocate() { _tensor_builder->allocateExtra(); }

} // namespace train
} // namespace backend
} // namespace onert
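To make the claim/release interleaving concrete, consider a hypothetical two-op graph (op 0 feeding op 1) where each op requests one FORWARD_TO_BACKWARD tensor f and one BACKWARD tensor b. plan() then issues, with indices printed op-sub:

claim   f(0-0)    // forward order: op 0, then op 1
claim   f(1-0)
claim   b(1-0)    // backward order: op 1 first
release f(1-0)
release b(1-0)
claim   b(0-0)
release f(0-0)
release b(0-0)

A FORWARD_TO_BACKWARD tensor therefore spans the whole round trip, while a BACKWARD tensor lives only inside its op's backward step, which lets the planner reuse its memory between ops.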
59 changes: 59 additions & 0 deletions runtime/onert/backend/train/ExtraTensorGenerator.h
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_GENERATOR_H__
#define __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_GENERATOR_H__

#include <backend/train/ExtraTensorRequest.h>
#include <ir/train/TrainableGraph.h>
#include <ir/Index.h>

#include "TensorBuilder.h"

namespace onert
{
namespace backend
{
namespace train
{

class ExtraTensorGenerator
{
public:
ExtraTensorGenerator() = delete;

ExtraTensorGenerator(const ir::train::TrainableGraph *tgraph,
std::shared_ptr<TensorBuilder> &tensor_builder,
std::shared_ptr<ITensorRegistry> &tensor_registry);

public:
// Since 'register' is a reserved keyword, use 'register_tensors' instead of 'register'
void register_tensors(ir::OperationIndex idx, ExtraTensorRequests &&requests);
void plan();
void allocate();

private:
const ir::train::TrainableGraph *_tgraph;
std::shared_ptr<TensorBuilder> _tensor_builder;
std::shared_ptr<TensorRegistry> _tensor_reg;
std::unordered_map<ir::OperationIndex, ExtraTensorRequests> _idx_to_requests;
};

} // namespace train
} // namespace backend
} // namespace onert

#endif // __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_GENERATOR_H__
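For reference, the intended call sequence, mirroring BackendContext::genKernels() above; the kernel-map iteration is schematic and the variable names are assumptions:

// Sketch: a backend context driving the generator.
ExtraTensorGenerator gen(tgraph, tensor_builder, tensor_registry);
for (auto &[op_idx, fn_seq] : kernel_map)                    // backward-enabled ops only
  fn_seq->iterate([&](exec::train::ITrainableFunction &fn) {
    gen.register_tensors(op_idx, fn.requestExtraTensors());  // collect requests per op
  });
gen.plan();      // claim/release in forward + backward order to compute offsets
gen.allocate();  // back every registered tensor with memory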
70 changes: 70 additions & 0 deletions runtime/onert/backend/train/ExtraTensorIndex.h
@@ -0,0 +1,70 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_INDEX_H__
#define __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_INDEX_H__

#include <ir/Index.h>

namespace onert
{
namespace backend
{
namespace train
{

class ExtraTensorIndex
{
public:
ExtraTensorIndex(ir::OperationIndex op, uint32_t sub) : op_index(op), sub_index(sub) {}

ir::OperationIndex op_index;
uint32_t sub_index;

bool operator==(const ExtraTensorIndex &other) const
{
return op_index == other.op_index && sub_index == other.sub_index;
}
};

inline std::ostream &operator<<(std::ostream &o, const ExtraTensorIndex &i)
{
o << i.op_index;
o << "-" << i.sub_index;
return o;
}

} // namespace train
} // namespace backend
} // namespace onert

namespace std
{

template <> struct hash<onert::backend::train::ExtraTensorIndex>
{
size_t operator()(const onert::backend::train::ExtraTensorIndex &index) const noexcept
{
const auto op_index = index.op_index;
const auto sub_index = index.sub_index;

// Pack both fields into one value; hashes stay distinct as long as sub_index fits in 16 bits
return (static_cast<size_t>(op_index.value()) << 16) | static_cast<size_t>(sub_index);
}
};

} // namespace std

#endif // __ONERT_BACKEND_TRAIN_EXTRA_TENSOR_INDEX_H__
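The std::hash specialization above is what lets ExtraTensorIndex key unordered containers such as the planner's memory-plan maps. An illustrative use (names are placeholders):

// Illustrative only: keyed through the packed-bits hash above.
std::unordered_map<ExtraTensorIndex, uint32_t> tensor_sizes;
tensor_sizes.emplace(ExtraTensorIndex(ir::OperationIndex{3}, 0), 1024);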
31 changes: 20 additions & 11 deletions runtime/onert/backend/train/MemoryManager.cc
@@ -17,6 +17,7 @@
#include "MemoryManager.h"

#include "MemoryPlannerFactory.h"
#include "ExtraTensorIndex.h"

#include <util/ConfigSource.h>

@@ -53,52 +54,60 @@ uint8_t *GradientMemoryManager::getOptVarBuffer(const ir::OperandIndex &ind, uin
return _var_mem_alloc->base() + var_offset + mem_blk.offset;
}

DisposableMemoryManager::DisposableMemoryManager() : _mem_planner{createMemoryPlanner()}
template <typename Index>
TrainMemoryManager<Index>::TrainMemoryManager() : _mem_planner{createMemoryPlanner()}
{
// DO NOTHING
}

DisposableMemoryManager::DisposableMemoryManager(const std::string planner_id)
template <typename Index>
TrainMemoryManager<Index>::TrainMemoryManager(const std::string planner_id)
: _mem_planner{createMemoryPlanner(planner_id)}
{
// DO NOTHING
}

basic::IMemoryPlanner<DisposableTensorIndex> *DisposableMemoryManager::createMemoryPlanner()
template <typename Index>
basic::IMemoryPlanner<Index> *TrainMemoryManager<Index>::createMemoryPlanner()
{
auto planner_id = util::getConfigString(util::config::CPU_MEMORY_PLANNER);
return MemoryPlannerFactory::get().create(planner_id);
return MemoryPlannerFactory<Index>::get().create(planner_id);
}

basic::IMemoryPlanner<DisposableTensorIndex> *
DisposableMemoryManager::createMemoryPlanner(const std::string planner_id)
template <typename Index>
basic::IMemoryPlanner<Index> *
TrainMemoryManager<Index>::createMemoryPlanner(const std::string planner_id)
{
return MemoryPlannerFactory::get().create(planner_id);
return MemoryPlannerFactory<Index>::get().create(planner_id);
}

void DisposableMemoryManager::claimPlan(const DisposableTensorIndex &ind, uint32_t size)
template <typename Index> void TrainMemoryManager<Index>::claimPlan(const Index &ind, uint32_t size)
{
_mem_planner->claim(ind, size);
}

void DisposableMemoryManager::releasePlan(const DisposableTensorIndex &ind)
template <typename Index> void TrainMemoryManager<Index>::releasePlan(const Index &ind)
{
_mem_planner->release(ind);
}

void DisposableMemoryManager::allocate(void)
template <typename Index> void TrainMemoryManager<Index>::allocate(void)
{
_mem_alloc = std::make_shared<basic::Allocator>(_mem_planner->capacity());
assert(_mem_alloc->base());
}

uint8_t *DisposableMemoryManager::getBuffer(const DisposableTensorIndex &ind) const
template <typename Index> uint8_t *TrainMemoryManager<Index>::getBuffer(const Index &ind) const
{
assert(_mem_planner->memory_plans().find(ind) != _mem_planner->memory_plans().end());
const auto &mem_blk = _mem_planner->memory_plans().at(ind);
return _mem_alloc->base() + mem_blk.offset;
}

// Explicit instantiation: member definitions live in this translation unit, so each index type used by the backend must be instantiated here
template class TrainMemoryManager<DisposableTensorIndex>;
template class TrainMemoryManager<ExtraTensorIndex>;

} // namespace train
} // namespace backend
} // namespace onert
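With the manager templated over the index type, the same plan/allocate flow now serves both disposable and extra tensors. A hypothetical stand-alone use follows; in the backend these calls are driven through TensorBuilder (notifyFirstUse / notifyLastUse / allocateExtra above), and the index values here are arbitrary:

// Sketch: plan two extra tensors with overlapping lifetimes, then allocate.
TrainMemoryManager<ExtraTensorIndex> mgr;
mgr.claimPlan(ExtraTensorIndex(ir::OperationIndex{0}, 0), 512);   // lifetime begins
mgr.claimPlan(ExtraTensorIndex(ir::OperationIndex{1}, 0), 1024);
mgr.releasePlan(ExtraTensorIndex(ir::OperationIndex{0}, 0));      // lifetime ends
mgr.releasePlan(ExtraTensorIndex(ir::OperationIndex{1}, 0));
mgr.allocate();                                                   // arena sized to the planned peak
uint8_t *buf = mgr.getBuffer(ExtraTensorIndex(ir::OperationIndex{1}, 0));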
