Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert-micro] Add training unit tests #13229

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions onert-micro/onert-micro/include/train/tests/OMTestTrainBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TRAIN_TESTS_TEST_TRAIN_BASE_H
#define ONERT_MICRO_TRAIN_TESTS_TEST_TRAIN_BASE_H

#include <vector>

chunseoklee marked this conversation as resolved.
Show resolved Hide resolved
namespace onert_micro
{
namespace train
{
namespace test
{

// Abstract base class for a training unit-test task.
//
// T is the element type of the model inputs; U (defaults to T) is the element
// type of the training targets. Concrete tasks supply the model binary and the
// train/test datasets. The `size` and `offset` arguments of the read* methods
// are passed through from the test driver in OMTestUtils.h, which computes
// them in bytes (sizeof(T) / sizeof(U) factors included).
template <typename T, typename U = T> class OMTestTrainBase
{
public:
  virtual ~OMTestTrainBase() = default;

  // Get ptr to trained model
  virtual char *getModelPtr() = 0;

  // Return model size
  virtual size_t getModelSize() = 0;

  // Return num train samples
  virtual size_t getTrainNumSamples() = 0;

  // Return num test samples
  virtual size_t getTestNumSamples() = 0;

  // Read train input data with cur size and cur offset
  virtual std::vector<T> readTrainInputData(size_t size, size_t offset) = 0;

  // Read train target data with cur size and cur offset
  virtual std::vector<U> readTrainTargetData(size_t size, size_t offset) = 0;

  // Read test input data with cur size and cur offset
  virtual std::vector<T> readTestInputData(size_t size, size_t offset) = 0;

  // Read test target data with cur size and cur offset
  // NOTE(review): returns std::vector<T> while readTrainTargetData returns
  // std::vector<U> — likely a T/U typo; confirm before instantiating any
  // task with U != T.
  virtual std::vector<T> readTestTargetData(size_t size, size_t offset) = 0;
};

} // namespace test
} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_TRAIN_TESTS_TEST_TRAIN_BASE_H
130 changes: 130 additions & 0 deletions onert-micro/onert-micro/include/train/tests/OMTestUtils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TRAIN_TESTS_TEST_UTILS_H
#define ONERT_MICRO_TRAIN_TESTS_TEST_UTILS_H

#include "OMTrainingInterpreter.h"
#include "train/tests/OMTestTrainBase.h"

#include <vector>
#include <numeric>

namespace onert_micro
{
namespace train
{
namespace test
{

// Train training_interpreter using config params and OMTestTrainBase to obtain data.
//
// For each epoch the train data is processed in steps of
// config.training_context.batch_size samples. All sizes/offsets handed to the
// data readers are in BYTES (hence the sizeof(T) / sizeof(U) factors).
// Returns the first non-Ok status encountered, otherwise Ok.
template <typename T, typename U = T>
OMStatus train(OMTrainingInterpreter &train_interpreter, OMConfig &config,
               OMTestTrainBase<T, U> &test_base)
{
  OMStatus status = Ok;
  const uint32_t training_epochs = config.training_context.epochs;
  const uint32_t num_train_data_samples = test_base.getTrainNumSamples();
  const uint32_t batch_size = config.training_context.batch_size;
  const uint32_t input_size = train_interpreter.getInputSizeAt(0);
  const uint32_t target_size = train_interpreter.getOutputSizeAt(0);
  for (uint32_t e = 0; e < training_epochs; ++e)
  {
    // num_epoch is 1-based for the interpreter
    config.training_context.num_epoch = e + 1;
    const uint32_t num_steps = num_train_data_samples / batch_size;
    for (uint32_t i = 0; i < num_steps; ++i)
    {
      // Set batch size: the remaining number of samples for the last step.
      // Note: no "- 1" here — subtracting one dropped a sample from the last
      // step whenever num_train_data_samples was an exact multiple of
      // batch_size (e.g. 404 samples, batch 4 -> last step trained on 3).
      uint32_t cur_batch_size = std::min(batch_size, num_train_data_samples - batch_size * i);
      cur_batch_size = std::max(1u, cur_batch_size);

      config.training_context.batch_size = cur_batch_size;

      // Sizes and offsets of the current chunk, in bytes
      const uint32_t cur_input_size = sizeof(T) * input_size * cur_batch_size;
      const uint32_t cur_target_size = sizeof(U) * target_size * cur_batch_size;
      const uint32_t cur_input_offset = sizeof(T) * input_size * i * batch_size;
      const uint32_t cur_target_offset = sizeof(U) * target_size * i * batch_size;

      // Read input and target
      std::vector<T> input_data = test_base.readTrainInputData(cur_input_size, cur_input_offset);
      std::vector<U> target_data =
        test_base.readTrainTargetData(cur_target_size, cur_target_offset);

      // Set input and target
      train_interpreter.setInput(reinterpret_cast<uint8_t *>(input_data.data()), 0);
      train_interpreter.setTarget(reinterpret_cast<uint8_t *>(target_data.data()), 0);

      // Train with current batch size
      status = train_interpreter.trainSingleStep(config);
      assert(status == Ok);
      if (status != Ok)
      {
        // Restore the caller-provided batch size before bailing out
        config.training_context.batch_size = batch_size;
        return status;
      }
    }
  }

  // Restore the caller-provided batch size (it was overwritten per step)
  config.training_context.batch_size = batch_size;

  return status;
}

// Evaluate the trained model using `metric` and store the average of the
// per-sample metric values in `metric_res`.
//
// Evaluation runs one test sample at a time; sizes/offsets handed to the data
// readers are in BYTES. `config` is part of the common test signature but is
// not consulted here. Returns the first non-Ok status, otherwise Ok.
template <typename T, typename U = T>
OMStatus evaluate(OMTrainingInterpreter &train_interpreter, OMConfig &config,
                  OMTestTrainBase<T, U> &test_base, OMMetrics metric, U *metric_res)
{
  OMStatus status = Ok;

  // To store all calculated metrics values
  std::vector<U> result_v;

  const uint32_t num_test_data_samples = test_base.getTestNumSamples();
  const uint32_t batch_size = 1;
  const uint32_t input_size = train_interpreter.getInputSizeAt(0);
  const uint32_t target_size = train_interpreter.getOutputSizeAt(0);

  // Guard the final average against division by zero
  if (num_test_data_samples == 0)
  {
    *metric_res = U(0);
    return status;
  }

  for (uint32_t i = 0; i < num_test_data_samples; ++i)
  {
    // Size and offset of the current sample, in bytes
    const uint32_t cur_input_size = sizeof(T) * input_size;
    const uint32_t cur_target_size = sizeof(U) * target_size;
    const uint32_t cur_input_offset = sizeof(T) * input_size * i;
    const uint32_t cur_target_offset = sizeof(U) * target_size * i;

    // Read input and target
    std::vector<T> input_data = test_base.readTestInputData(cur_input_size, cur_input_offset);
    std::vector<T> target_data = test_base.readTestTargetData(cur_target_size, cur_target_offset);

    // Set input and target
    train_interpreter.setInput(reinterpret_cast<uint8_t *>(input_data.data()), 0);
    train_interpreter.setTarget(reinterpret_cast<uint8_t *>(target_data.data()), 0);

    U result = U(0);

    status =
      train_interpreter.evaluateMetric(metric, reinterpret_cast<void *>(&result), batch_size);
    assert(status == Ok);
    // Propagate failures instead of averaging in a garbage value
    // (mirrors the error handling in train() above)
    if (status != Ok)
      return status;
    result_v.push_back(result);
  }
  // Calculate and save average value
  *metric_res =
    static_cast<U>(std::accumulate(result_v.begin(), result_v.end(), U(0)) / result_v.size());

  return status;
}

} // namespace test
} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_TRAIN_TESTS_TEST_UTILS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TRAIN_TESTS_BOSTON_HOUSING_TASK_H
#define ONERT_MICRO_TRAIN_TESTS_BOSTON_HOUSING_TASK_H

#include "train/tests/OMTestTrainBase.h"
#include "train/tests/models/boston_housing.h"
#include "train/tests/boston_housing_task/data/test_target.h"
#include "train/tests/boston_housing_task/data/test_input.h"
#include "train/tests/boston_housing_task/data/train_target.h"
#include "train/tests/boston_housing_task/data/train_input.h"

#include <vector>
#include <cstring>

namespace onert_micro
{
namespace train
{
namespace test
{

// Number of samples in the Boston Housing train/test splits.
// constexpr at namespace scope already has internal linkage, so the anonymous
// namespace (discouraged in headers, C++ Core Guidelines SF.21) is unnecessary.
constexpr size_t NUM_TRAIN_SAMPLES = 404;
constexpr size_t NUM_TEST_SAMPLES = 102;

// Boston Housing regression task for the training unit tests.
//
// T is the input element type, U (defaults to T) the target element type.
// The model binary is copied into a mutable buffer so training can update the
// weights in place.
template <typename T, typename U = T> class BostonHousingTask : public OMTestTrainBase<T, U>
{
public:
  BostonHousingTask()
  {
    // Set model: copy the baked-in flatbuffer into a writable buffer
    _train_model_ptr.resize(models::boston_housing_model_size);
    std::memcpy(_train_model_ptr.data(), models::boston_housing_model,
                models::boston_housing_model_size);
  }

  // Get ptr to trained model
  char *getModelPtr() final { return _train_model_ptr.data(); }

  // Return model size
  size_t getModelSize() final { return _train_model_ptr.size(); }

  // Return num train samples
  size_t getTrainNumSamples() final { return NUM_TRAIN_SAMPLES; }

  // Return num test samples
  size_t getTestNumSamples() final { return NUM_TEST_SAMPLES; }

  // Read train input data with cur size and cur offset.
  // NOTE(review): callers pass `size`/`offset` in bytes, so the vector is
  // over-allocated (`size` elements of T for `size` bytes) and the pointer
  // arithmetic assumes the data arrays are byte-typed — confirm against the
  // generated data headers.
  std::vector<T> readTrainInputData(size_t size, size_t offset) final
  {
    std::vector<T> result(size);

    auto *cur_ptr = data::train_input + offset;

    std::memcpy(result.data(), cur_ptr, size);
    return result;
  }

  // Read train target data with cur size and cur offset
  std::vector<U> readTrainTargetData(size_t size, size_t offset) final
  {
    // Fixed: the local buffer must match the declared return type
    // std::vector<U>; it was std::vector<T>, which fails to compile as soon
    // as the class is instantiated with U != T.
    std::vector<U> result(size);

    auto *cur_ptr = data::train_target + offset;

    std::memcpy(result.data(), cur_ptr, size);
    return result;
  }

  // Read test input data with cur size and cur offset
  std::vector<T> readTestInputData(size_t size, size_t offset) final
  {
    std::vector<T> result(size);

    auto *cur_ptr = data::test_input + offset;

    std::memcpy(result.data(), cur_ptr, size);
    return result;
  }

  // Read test target data with cur size and cur offset
  // (returns std::vector<T> to match the OMTestTrainBase interface)
  std::vector<T> readTestTargetData(size_t size, size_t offset) final
  {
    std::vector<T> result(size);

    auto *cur_ptr = data::test_target + offset;

    std::memcpy(result.data(), cur_ptr, size);
    return result;
  }

private:
  // Mutable copy of the model; training updates the weights in place
  std::vector<char> _train_model_ptr = {};
};

} // namespace test
} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_TRAIN_TESTS_BOSTON_HOUSING_TASK_H
Loading
Loading