diff --git a/tests/tools/onert_train/src/args.cc b/tests/tools/onert_train/src/args.cc index b8fa920f0da..f26a291b0c2 100644 --- a/tests/tools/onert_train/src/args.cc +++ b/tests/tools/onert_train/src/args.cc @@ -228,10 +228,10 @@ void Args::Initialize(void) "Loss type\n" "0: MEAN_SQUARED_ERROR (default)\n" "1: CATEGORICAL_CROSSENTROPY\n") - ("loss_reduction_type", po::value()->default_value(0)->notifier([&] (const auto &v) { _loss_reduction_type = v; }), + ("loss_reduction_type", po::value()->default_value(1)->notifier([&] (const auto &v) { _loss_reduction_type = v; }), "Loss Reduction type\n" - "0: SUM_OVER_BATCH_SIZE(default)\n" - "1: SUM\n") + "1: SUM_OVER_BATCH_SIZE(default)\n" + "2: SUM\n") ("optimizer", po::value()->default_value(0)->notifier([&] (const auto &v) { _optimizer_type = v; }), "Optimizer type\n" "0: SGD (default)\n" diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc index 858662b049c..84ea3671b8b 100644 --- a/tests/tools/onert_train/src/onert_train.cc +++ b/tests/tools/onert_train/src/onert_train.cc @@ -137,9 +137,9 @@ int main(const int argc, char **argv) auto convertLossReductionType = [](int type) { switch (type) { - case 0: - return NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE; case 1: + return NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE; + case 2: return NNFW_TRAIN_LOSS_REDUCTION_SUM; default: std::cerr << "E: not supported loss reduction type" << std::endl;