Skip to content

Commit

Permalink
[tests/onert_train] Change the type number of loss and optimizer (#12425
Browse files Browse the repository at this point in the history
)

This commit changes the type number of loss and optimizer to match
the enum value of `nnfw_experimental.h`.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun authored Jan 10, 2024
1 parent 3dc37ed commit ee4487c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions tests/tools/onert_train/src/args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,18 +224,18 @@ void Args::Initialize(void)
("epoch", po::value<int>()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)")
("batch_size", po::value<int>()->default_value(32)->notifier([&](const auto &v) { _batch_size = v; }), "Batch size (default: 32)")
("learning_rate", po::value<float>()->default_value(0.001)->notifier([&](const auto &v) { _learning_rate = v; }), "Learning rate (default: 0.001)")
("loss", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_type = v; }),
("loss", po::value<int>()->default_value(1)->notifier([&] (const auto &v) { _loss_type = v; }),
"Loss type\n"
"0: MEAN_SQUARED_ERROR (default)\n"
"1: CATEGORICAL_CROSSENTROPY")
"1: MEAN_SQUARED_ERROR (default)\n"
"2: CATEGORICAL_CROSSENTROPY")
("loss_reduction_type", po::value<int>()->default_value(1)->notifier([&] (const auto &v) { _loss_reduction_type = v; }),
"Loss Reduction type\n"
"1: SUM_OVER_BATCH_SIZE(default)\n"
"2: SUM")
("optimizer", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }),
("optimizer", po::value<int>()->default_value(1)->notifier([&] (const auto &v) { _optimizer_type = v; }),
"Optimizer type\n"
"0: SGD (default)\n"
"1: Adam")
"1: SGD (default)\n"
"2: Adam")
("metric", po::value<int>()->default_value(-1)->notifier([&] (const auto &v) { _metric_type = v; }),
"Metric type\n"
"Simply calculates the metric value using the variables (default: none)\n"
Expand Down
8 changes: 4 additions & 4 deletions tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ int main(const int argc, char **argv)
auto convertLossType = [](int type) {
switch (type)
{
case 0:
return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
case 1:
return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
case 2:
return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
default:
std::cerr << "E: not supported loss type" << std::endl;
Expand All @@ -150,9 +150,9 @@ int main(const int argc, char **argv)
auto convertOptType = [](int type) {
switch (type)
{
case 0:
return NNFW_TRAIN_OPTIMIZER_SGD;
case 1:
return NNFW_TRAIN_OPTIMIZER_SGD;
case 2:
return NNFW_TRAIN_OPTIMIZER_ADAM;
default:
std::cerr << "E: not supported optimizer type" << std::endl;
Expand Down

0 comments on commit ee4487c

Please sign in to comment.