Skip to content

Commit

Permalink
[onert-micro] Introduce OMBackpropExecute entities (Samsung#13148)
Browse files Browse the repository at this point in the history
This pr introduces OMBackpropExecute, OMBackpropExecuteArgs and OMBackpropExecutionBuilder entities.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
  • Loading branch information
BalyshevArtem authored Jun 13, 2024
1 parent 8b33f84 commit 3eceee5
Show file tree
Hide file tree
Showing 5 changed files with 301 additions and 0 deletions.
41 changes: 41 additions & 0 deletions onert-micro/onert-micro/include/train/OMBackpropExecute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TRAIN_BACKPROP_EXECUTE_H
#define ONERT_MICRO_TRAIN_BACKPROP_EXECUTE_H

#include "OMStatus.h"
#include "core/OMRuntimeContext.h"
#include "core/OMRuntimeStorage.h"
#include "train/OMBackpropExecuteArgs.h"
#include "core/memory/OMRuntimeAllocator.h"

namespace onert_micro
{
namespace train
{

struct OMBackpropExecute
{
// Start execution of the backward graph
static OMStatus runBackward(const OMConfig &config, OMBackpropExecuteArgs &args,
core::memory::OMRuntimeAllocator &allocator);
};

} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_TRAIN_BACKPROP_EXECUTE_H
46 changes: 46 additions & 0 deletions onert-micro/onert-micro/include/train/OMBackpropExecuteArgs.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TRAIN_BACKPROP_ARGS_H
#define ONERT_MICRO_TRAIN_BACKPROP_ARGS_H

#include "OMStatus.h"
#include "core/OMRuntimeContext.h"
#include "core/OMRuntimeStorage.h"
#include "core/OMRuntimeModule.h"
#include "core/train/OMTrainingStorage.h"

namespace onert_micro
{
namespace train
{

/*
* Args to execute backpropagation graph
*/
struct OMBackpropExecuteArgs
{
core::OMRuntimeStorage &forward_storage;
core::OMRuntimeStorage &backward_storage;
core::OMRuntimeContext &backward_context;
bool is_last_layer;
uint16_t kernel_index;
};

} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_TRAIN_BACKPROP_ARGS_H
87 changes: 87 additions & 0 deletions onert-micro/onert-micro/include/train/OMBackpropExecutionBuilder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_BACKPROP_EXECUTION_BUILDER_H
#define ONERT_MICRO_BACKPROP_EXECUTION_BUILDER_H

#include "core/reader/OMCircleReader.h"
#include "core/OMKernelType.h"
#include "core/OMRuntimeStorage.h"
#include "core/OMRuntimeContext.h"
#include "train/OMBackpropExecuteArgs.h"

namespace onert_micro
{
namespace train
{

using KernelTrainFunc = OMStatus(const OMBackpropExecuteArgs &);

#define REGISTER_TRAIN_KERNEL(builtin_operator, name) \
OMStatus train_kernel_Circle##name(const OMBackpropExecuteArgs &);
#include "KernelsToTrain.lst"
#undef REGISTER_TRAIN_KERNEL

/*
* Class to registry kernels for execution of backward graph (calculation backpropagation)
* Kernels which will be register define in KernelsToTrain.lst current pal directory
*/
class KernelBuiltinTrainRegistry
{
public:
constexpr KernelBuiltinTrainRegistry() : _operator_train()
{
#define REGISTER_TRAIN_KERNEL(builtin_operator, name) \
registerKernelTrain(core::OMBuilderID::BuiltinOperator_##builtin_operator, \
train_kernel_Circle##name);

#include "KernelsToTrain.lst"

#undef REGISTER_TRAIN_KERNEL
}

public:
OMStatus getKernelTrainFunc(core::OMBuilderID builderID, KernelTrainFunc **train_func) const
{
const auto builder_id_opcode = size_t(builderID);
assert(builder_id_opcode < size_t(core::OMBuilderID::BuiltinOperatorsSize));
if (builder_id_opcode >= size_t(core::OMBuilderID::BuiltinOperatorsSize))
{
*train_func = nullptr;
return UnknownError;
}
*train_func = _operator_train[builder_id_opcode];
return Ok;
}

private:
constexpr void registerKernelTrain(core::OMBuilderID id, KernelTrainFunc *func)
{
assert(size_t(id) < size_t(core::OMBuilderID::BuiltinOperatorsSize));
_operator_train[size_t(id)] = func;
}

private:
KernelTrainFunc *_operator_train[size_t(core::OMBuilderID::BuiltinOperatorsSize)];
};

// Global constexpr kernel builtin train
constexpr KernelBuiltinTrainRegistry kernel_builtin_train;

} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_BACKPROP_EXECUTION_BUILDER_H
109 changes: 109 additions & 0 deletions onert-micro/onert-micro/src/train/OMBackpropExecute.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "train/OMBackpropExecute.h"
#include "train/OMBackpropExecutionBuilder.h"

using namespace onert_micro::train;
using namespace onert_micro;

/*
* Run backward graph to calculate gradients
*/
OMStatus OMBackpropExecute::runBackward(const OMConfig &config, OMBackpropExecuteArgs &args,
core::memory::OMRuntimeAllocator &allocator)
{
OMStatus status = Ok;

core::OMRuntimeContext &context = args.backward_context;
core::OMRuntimeStorage &forward_storage = args.forward_storage;
core::OMRuntimeStorage &backward_storage = args.backward_storage;

const core::reader::CircleOperators *operators = context.getCircleOperators();

const auto num_operators = operators->size();
const auto *op_codes = context.getCircleOpcodes();

uint32_t num_train_layers = config.training_context.num_of_train_layers == 0
? num_operators
: config.training_context.num_of_train_layers;
uint32_t last_node_pos = std::min(num_operators, num_train_layers);

for (uint32_t i = 0; i < last_node_pos; ++i)
{
uint32_t cur_op_index = num_operators - i - 1;
auto *cur_op = operators->operator[](cur_op_index);

status = allocator.allocate(i, &context, &backward_storage);

if (status != Ok)
return status;

core::OMBuilderID builder_id = core::OMBuilderID::Size;
const circle::Operator *op = operators->operator[](cur_op_index);
uint32_t index = op->opcode_index();

assert(index < op_codes->size());

const auto opcode = op_codes->operator[](index);

status = core::getBuilderId(opcode, builder_id);

assert(status == Ok);
if (status != Ok)
return status;

args.kernel_index = cur_op_index;

if (i == last_node_pos - 1)
args.is_last_layer = true;

// Calculate gradients
KernelTrainFunc *train_func = nullptr;
if (size_t(builder_id) < size_t(core::OMBuilderID::BuiltinOperatorsSize))
{
// Builtin operator
status = kernel_builtin_train.getKernelTrainFunc(builder_id, &train_func);
}
else
{
assert(false && "Unsupported kernel type for training");
return UnsupportedOp;
}

assert(train_func != nullptr);

if (status != Ok)
return status;

status = train_func(args);

assert(status == Ok);

if (status != Ok)
return status;

// Deallocate tensors data in backward storage
status = allocator.deallocate(i, &backward_storage);
if (status != Ok)
return status;

// Deallocate tensors data in forward storage
status = allocator.deallocate(i, &forward_storage);
}

return status;
}
18 changes: 18 additions & 0 deletions onert-micro/onert-micro/src/train/OMBackpropExecutionBuilder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "train/OMBackpropExecutionBuilder.h"
// Do nothing

0 comments on commit 3eceee5

Please sign in to comment.