Skip to content

Commit

Permalink
[tests/onert_train] Change the type number of loss and optimizer
Browse files Browse the repository at this point in the history
This commit changes the type number of loss and optimizer to match
the enum value of `nnfw_experimental.h`.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
  • Loading branch information
jyoungyun committed Jan 9, 2024
1 parent c925d8a commit 33f5624
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
12 changes: 6 additions & 6 deletions tests/tools/onert_train/src/args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,19 @@ void Args::Initialize(void)
("epoch", po::value<int>()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)")
("batch_size", po::value<int>()->default_value(32)->notifier([&](const auto &v) { _batch_size = v; }), "Batch size (default: 32)")
("learning_rate", po::value<float>()->default_value(0.001)->notifier([&](const auto &v) { _learning_rate = v; }), "Learning rate (default: 0.001)")
("loss", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_type = v; }),
("loss", po::value<int>()->default_value(1)->notifier([&] (const auto &v) { _loss_type = v; }),
"Loss type\n"
"0: MEAN_SQUARED_ERROR (default)\n"
"1: CATEGORICAL_CROSSENTROPY\n")
"1: MEAN_SQUARED_ERROR (default)\n"
"2: CATEGORICAL_CROSSENTROPY\n")
("loss_reduction_type", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _loss_reduction_type = v; }),
"Loss Reduction type\n"
"0: AUTO (default)\n"
"1: SUM_OVER_BATCH_SIZE\n"
"2: SUM\n")
("optimizer", po::value<int>()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }),
("optimizer", po::value<int>()->default_value(1)->notifier([&] (const auto &v) { _optimizer_type = v; }),
"Optimizer type\n"
"0: SGD (default)\n"
"1: Adam\n")
"1: SGD (default)\n"
"2: Adam\n")
("metric", po::value<int>()->default_value(-1)->notifier([&] (const auto &v) { _metric_type = v; }),
"Metricy type\n"
" Simply calculates the metric value using the variables (default: none)\n"
Expand Down
8 changes: 4 additions & 4 deletions tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ int main(const int argc, char **argv)
auto convertLossType = [](int type) {
switch (type)
{
case 0:
return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
case 1:
return NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR;
case 2:
return NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
default:
std::cerr << "E: not supported loss type" << std::endl;
Expand All @@ -152,9 +152,9 @@ int main(const int argc, char **argv)
auto convertOptType = [](int type) {
switch (type)
{
case 0:
return NNFW_TRAIN_OPTIMIZER_SGD;
case 1:
return NNFW_TRAIN_OPTIMIZER_SGD;
case 2:
return NNFW_TRAIN_OPTIMIZER_ADAM;
default:
std::cerr << "E: not supported optimizer type" << std::endl;
Expand Down

0 comments on commit 33f5624

Please sign in to comment.