Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Balyshev committed Jun 11, 2024
1 parent 83fbeb1 commit 97976ac
Show file tree
Hide file tree
Showing 17 changed files with 2,743 additions and 45 deletions.
155 changes: 115 additions & 40 deletions onert-micro/eval-driver/TrainingDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

#define MODEL_TYPE float

#define CLASSIFICATION_TASK 0

namespace
{

Expand Down Expand Up @@ -177,13 +179,13 @@ int entry(int argc, char **argv)
config.wof_ptr = nullptr;

// Set user defined training settings
const uint32_t training_epochs = 50;
const float lambda = 0.001f;
const uint32_t training_epochs = 30;
const float lambda = 0.1f;
const uint32_t BATCH_SIZE = 32;
const uint32_t INPUT_SIZE = 180;
const uint32_t OUTPUT_SIZE = 4;
const uint32_t INPUT_SIZE = 13;
const uint32_t OUTPUT_SIZE = 1;
const uint32_t num_train_layers = 0;
const onert_micro::OMLoss loss = onert_micro::CROSS_ENTROPY;
const onert_micro::OMLoss loss = onert_micro::MSE;
const onert_micro::OMTrainOptimizer train_optim = onert_micro::ADAM;
const float beta = 0.9;
const float beta_squares = 0.999;
Expand All @@ -200,6 +202,7 @@ int entry(int argc, char **argv)
train_context.beta = beta;
train_context.beta_squares = beta_squares;
train_context.epsilon = epsilon;
train_context.epochs = training_epochs;

config.training_context = train_context;
}
Expand All @@ -214,9 +217,14 @@ int entry(int argc, char **argv)
// Note: here test size used with BATCH_SIZE = 1
float test_input[INPUT_SIZE];
float test_target[OUTPUT_SIZE];

std::vector<float> accuracy_v;
std::vector<float> cross_entropy_v;
std::vector<float> mse_v;
std::vector<float> mae_v;

float max_accuracy = std::numeric_limits<float>::min();
float min_mae = std::numeric_limits<float>::max();

for (uint32_t e = 0; e < training_epochs; ++e)
{
Expand All @@ -230,7 +238,6 @@ int entry(int argc, char **argv)
cur_batch_size = std::max(1u, cur_batch_size);

config.training_context.batch_size = cur_batch_size;
config.training_context.num_step = i + 1;

// Read current input and target data
readDataFromFile(input_input_train_data_path, reinterpret_cast<char *>(training_input),
Expand All @@ -253,22 +260,51 @@ int entry(int argc, char **argv)
float cross_entropy_metric = 0.f;
float accuracy = 0.f;

// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
cross_entropy_v.push_back(cross_entropy_metric);
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
cross_entropy_v.push_back(cross_entropy_metric);
}
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
mae_v.push_back(mae);
}
}

// Reset num step value
config.training_context.num_step = 0;

// Calculate and print average values
float sum_acc = std::accumulate(accuracy_v.begin(), accuracy_v.end(), 0.f);
float sum_ent = std::accumulate(cross_entropy_v.begin(), cross_entropy_v.end(), 0.f);
std::cout << "Train Average ACCURACY = " << sum_acc / accuracy_v.size() << "\n";
std::cout << "Train Average CROSS ENTROPY = " << sum_ent / cross_entropy_v.size() << "\n";
if (CLASSIFICATION_TASK)
{
float sum_acc = std::accumulate(accuracy_v.begin(), accuracy_v.end(), 0.f);
float sum_ent = std::accumulate(cross_entropy_v.begin(), cross_entropy_v.end(), 0.f);
std::cout << "Train Average ACCURACY = " << sum_acc / accuracy_v.size() << "\n";
std::cout << "Train Average CROSS ENTROPY = " << sum_ent / cross_entropy_v.size() << "\n";
}
else
{
float sum_mse = std::accumulate(mse_v.begin(), mse_v.end(), 0.f);
float sum_mae = std::accumulate(mae_v.begin(), mae_v.end(), 0.f);
std::cout << "Train Average MSE = " << sum_mse / mse_v.size() << "\n";
std::cout << "Train Average MAE = " << sum_mae / mae_v.size() << "\n";
}

// Run test for current epoch
std::cout << "Run test for epoch: " << e + 1 << "/" << training_epochs << "\n";
Expand Down Expand Up @@ -296,29 +332,68 @@ int entry(int argc, char **argv)
float cross_entropy_metric = 0.f;
float accuracy = 0.f;

train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);

accuracy_v.push_back(accuracy);
cross_entropy_v.push_back(cross_entropy_metric);
if (CLASSIFICATION_TASK)
{
// Evaluate cross_entropy and accuracy metrics
train_interpreter.evaluateMetric(onert_micro::CROSS_ENTROPY_METRICS,
reinterpret_cast<void *>(&cross_entropy_metric),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::ACCURACY, reinterpret_cast<void *>(&accuracy),
cur_batch_size);

// Save them into vectors
accuracy_v.push_back(accuracy);
cross_entropy_v.push_back(cross_entropy_metric);
}
else
{
// Evaluate mse and mae metrics
train_interpreter.evaluateMetric(onert_micro::MSE_METRICS, reinterpret_cast<void *>(&mse),
cur_batch_size);
train_interpreter.evaluateMetric(onert_micro::MAE_METRICS, reinterpret_cast<void *>(&mae),
cur_batch_size);

// Save them into vectors
mse_v.push_back(mse);
mae_v.push_back(mae);
}
}
// Calculate and print average values
sum_acc = std::accumulate(accuracy_v.begin(), accuracy_v.end(), 0.f);
sum_ent = std::accumulate(cross_entropy_v.begin(), cross_entropy_v.end(), 0.f);
std::cout << "Test Average ACCURACY = " << sum_acc / accuracy_v.size() << "\n";
std::cout << "Test Average CROSS ENTROPY = " << sum_ent / cross_entropy_v.size() << "\n";

float acc = sum_acc / accuracy_v.size();
if (acc > max_accuracy)
if (CLASSIFICATION_TASK)
{
float sum_acc = std::accumulate(accuracy_v.begin(), accuracy_v.end(), 0.f);
float sum_ent = std::accumulate(cross_entropy_v.begin(), cross_entropy_v.end(), 0.f);
std::cout << "Test Average ACCURACY = " << sum_acc / accuracy_v.size() << "\n";
std::cout << "Test Average CROSS ENTROPY = " << sum_ent / cross_entropy_v.size() << "\n";

// Best checkpoint part
float acc = sum_acc / accuracy_v.size();
if (acc > max_accuracy)
{
// Save best checkpoint
train_interpreter.saveCheckpoint(config, checkpoints_path);
max_accuracy = acc;
std::cout << "Found new max Test ACCURACY = " << max_accuracy << " in epoch = " << e + 1
<< " / " << training_epochs << "\n";
}
}
else
{
// Save best checkpoint
train_interpreter.saveCheckpoint(config, checkpoints_path);
max_accuracy = acc;
std::cout << "Found new max Test ACCURACY = " << max_accuracy << " in epoch = " << e + 1
<< " / " << training_epochs << "\n";
float sum_mse = std::accumulate(mse_v.begin(), mse_v.end(), 0.f);
float sum_mae = std::accumulate(mae_v.begin(), mae_v.end(), 0.f);
std::cout << "Test Average MSE = " << sum_mse / mse_v.size() << "\n";
std::cout << "Test Average MAE = " << sum_mae / mae_v.size() << "\n";

// Best checkpoint part
float acc = sum_mae / mae_v.size();
if (acc < min_mae)
{
// Save best checkpoint
train_interpreter.saveCheckpoint(config, checkpoints_path);
min_mae = acc;
std::cout << "Found new min Test MAE = " << min_mae << " in epoch = " << e + 1 << " / "
<< training_epochs << "\n";
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion onert-micro/onert-micro/include/OMConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ struct OMTrainingContext
float beta = 0.9f;
float beta_squares = 0.9f;
float epsilon = 10e-8;
uint32_t num_step = batch_size;
uint32_t num_step = 1;
uint32_t num_epoch = 0;
uint32_t epochs = 0;
};

/*
Expand Down
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/include/OMTrainingInterpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class OMTrainingInterpreter
// -> calculate error (with target data which was set in SetTarget) ->
// -> run backward graph -> update optimizer state -> after batch_size steps update weights
// Warning: before using trainSingleStep call: 1) importTrainModel; 2) setInput; 3) setTarget
OMStatus trainSingleStep(const OMConfig &config);
OMStatus trainSingleStep(OMConfig &config);

// Reset all states and data saved into OMTrainingInterpreter (trained weights will not be reset)
OMStatus reset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class OMTrainingRuntimeModule : public OMRuntimeModule
// -> calculate error (with target data which was set in SetTarget) ->
// -> run backward graph -> update optimizer state -> after batch_size steps update weights
// Warning: before using trainSingleStep call: 1) importTrainModel; 2) setInput; 3) setTarget
OMStatus trainSingleStep(const OMConfig &config);
OMStatus trainSingleStep(OMConfig &config);

// Reset all states and data saved into OMTrainingInterpreter (trained weights will not be reset)
OMStatus reset();
Expand Down
63 changes: 63 additions & 0 deletions onert-micro/onert-micro/include/train/tests/OMTestTrainBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_TRAIN_TESTS_TEST_TRAIN_BASE_H
#define ONERT_MICRO_TRAIN_TESTS_TEST_TRAIN_BASE_H

#include <vector>

namespace onert_micro
{
namespace train
{
namespace test
{

// Abstract fixture interface for training tests.
// T is the model input element type; U is the target element type
// (defaults to T for models whose input and target share a type).
// Concrete tests implement the data-access hooks below so the common
// training-test driver can stream samples without knowing their origin.
template <typename T, typename U = T> class OMTestTrainBase
{
public:
  virtual ~OMTestTrainBase() = default;

  // Get ptr to trained model
  virtual char *getModelPtr() = 0;

  // Return model size
  virtual size_t getModelSize() = 0;

  // Return num train samples
  virtual size_t getTrainNumSamples() = 0;

  // Return num test samples
  virtual size_t getTestNumSamples() = 0;

  // Read train input data with cur size and cur offset
  virtual std::vector<T> readTrainInputData(size_t size, size_t offset) = 0;

  // Read train target data with cur size and cur offset
  virtual std::vector<U> readTrainTargetData(size_t size, size_t offset) = 0;

  // Read test input data with cur size and cur offset
  virtual std::vector<T> readTestInputData(size_t size, size_t offset) = 0;

  // Read test target data with cur size and cur offset
  // Note: returns the target type U (was std::vector<T>, which was
  // inconsistent with readTrainTargetData and wrong whenever U != T).
  virtual std::vector<U> readTestTargetData(size_t size, size_t offset) = 0;
};

} // namespace test
} // namespace train
} // namespace onert_micro

#endif // ONERT_MICRO_TRAIN_TESTS_TEST_TRAIN_BASE_H
Loading

0 comments on commit 97976ac

Please sign in to comment.