From bb706761a4992245e35ca67fa413ee6b44b6d65a Mon Sep 17 00:00:00 2001 From: Hyeongseok Oh Date: Wed, 31 Jul 2024 13:51:35 +0900 Subject: [PATCH] [onert/onert_train] Use arser for onert_train (#13562) This commit updates onert_train to use arser for argument parsing instead of boost::program_options. ONE-DCO-1.0-Signed-off-by: Hyeongseok Oh --- tests/tools/onert_train/CMakeLists.txt | 4 +- tests/tools/onert_train/src/args.cc | 348 ++++++++++----------- tests/tools/onert_train/src/args.h | 7 +- tests/tools/onert_train/src/onert_train.cc | 6 - 4 files changed, 176 insertions(+), 189 deletions(-) diff --git a/tests/tools/onert_train/CMakeLists.txt b/tests/tools/onert_train/CMakeLists.txt index 75dbd2fe352..ae3140d0e71 100644 --- a/tests/tools/onert_train/CMakeLists.txt +++ b/tests/tools/onert_train/CMakeLists.txt @@ -14,7 +14,6 @@ list(APPEND ONERT_TRAIN_SRCS "src/rawformatter.cc") list(APPEND ONERT_TRAIN_SRCS "src/rawdataloader.cc") list(APPEND ONERT_TRAIN_SRCS "src/metrics.cc") -nnfw_find_package(Boost REQUIRED program_options) nnfw_find_package(HDF5 QUIET) if (HDF5_FOUND) @@ -32,11 +31,10 @@ else() endif(HDF5_FOUND) target_include_directories(onert_train PRIVATE src) -target_include_directories(onert_train PRIVATE ${Boost_INCLUDE_DIRS}) target_link_libraries(onert_train nnfw_lib_tflite jsoncpp) target_link_libraries(onert_train nnfw-dev) -target_link_libraries(onert_train ${Boost_PROGRAM_OPTIONS_LIBRARY}) +target_link_libraries(onert_train arser) target_link_libraries(onert_train nnfw_lib_benchmark) install(TARGETS onert_train DESTINATION bin) diff --git a/tests/tools/onert_train/src/args.cc b/tests/tools/onert_train/src/args.cc index dd0ce55403d..2970732d829 100644 --- a/tests/tools/onert_train/src/args.cc +++ b/tests/tools/onert_train/src/args.cc @@ -19,7 +19,7 @@ #include "misc/to_underlying.h" #include -#include +#include #include #include #include @@ -140,209 +140,207 @@ Args::Args(const int argc, char **argv) void Args::Initialize(void) { - auto process_nnpackage = [&](const std::string &package_filename) { - _package_filename = package_filename; + _arser.add_argument("path").type(arser::DataType::STR).help("NN Package or NN Modelfile path"); + + _arser.add_argument("--version") + .nargs(0) + .default_value(false) + .help("Print version and exit immediately"); + _arser.add_argument("--nnpackage") + .type(arser::DataType::STR) + .help("NN Package file(directory) name"); + _arser.add_argument("--modelfile").type(arser::DataType::STR).help("NN Model filename"); + _arser.add_argument("--export_circle").type(arser::DataType::STR).help("Path to export circle"); + _arser.add_argument("--export_circleplus") + .type(arser::DataType::STR) + .help("Path to export circle+"); + _arser.add_argument("--load_input:raw") + .type(arser::DataType::STR) + .help({"NN Model Raw Input data file", "The datafile must have data for each input number.", + "If there are 3 inputs, the data of input0 must exist as much as data_length,", + "and the data for input1 and input2 must be held sequentially as data_length."}); + _arser.add_argument("--load_expected:raw") + .type(arser::DataType::STR) + .help({"NN Model Raw Expected data file", "(Same data policy with load_input:raw)"}); + _arser.add_argument("--mem_poll", "-m") + .nargs(0) + .default_value(false) + .help("Check memory polling (default: false)"); + _arser.add_argument("--epoch") + .type(arser::DataType::INT32) + .default_value(5) + .help("Epoch number (default: 5)"); + _arser.add_argument("--batch_size") + .type(arser::DataType::INT32) + .help({"Batch size", "If not given, model's hyper parameter is used"}); + _arser.add_argument("--learning_rate") + .type(arser::DataType::FLOAT) + .help({"Learning rate", "If not given, model's hyper parameter is used"}); + _arser.add_argument("--loss").type(arser::DataType::INT32).help("Loss type"); + _arser.add_argument("--loss_reduction_type") + .type(arser::DataType::INT32) + .help("Loss reduction type"); + _arser.add_argument("--optimizer").type(arser::DataType::INT32).help("Optimizer type"); + _arser.add_argument("--metric") + .type(arser::DataType::INT32) + .default_value(-1) + .help({"Metric type", "Simply calculates the metric value using the variables (default: none)", + "0: CATEGORICAL_ACCURACY"}); + _arser.add_argument("--validation_split") + .type(arser::DataType::FLOAT) + .default_value(0.0f) + .help("Float between 0 and 1(0 < float < 1). Fraction of the training data to be used as " + "validation data."); + _arser.add_argument("--verbose_level", "-v") + .type(arser::DataType::INT32) + .default_value(0) + .help({"Verbose level", "0: prints the only result. Messages btw run don't print", + "1: prints result and message btw run", "2: prints all of messages to print"}); + _arser.add_argument("--output_sizes") + .type(arser::DataType::STR) + .help({"The output buffer size in JSON 1D array", + "If not given, the model's output sizes are used", + "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80."}); + _arser.add_argument("--num_of_trainable_ops") + .type(arser::DataType::INT32) + .help({"Number of the layers to be trained from the back of the model.", + "\"-1\" means that all layers will be trained.", + "\"0\" means that no layer will be trained."}); +} + +void Args::Parse(const int argc, char **argv) +{ + try + { + _arser.parse(argc, argv); + + if (_arser.get("--version")) + { + _print_version = true; + return; + } - std::cerr << "Package Filename " << _package_filename << std::endl; - checkPackage(package_filename); - }; + // Require modelfile, nnpackage, or path + if (!_arser["--nnpackage"] && !_arser["--modelfile"] && !_arser["path"]) + { + std::cerr << "Require one of options modelfile, nnpackage, or path." << std::endl; + exit(1); + } - auto process_modelfile = [&](const std::string &model_filename) { - _model_filename = model_filename; + // Cannot use both single model file and nnpackage at once + if (_arser["--nnpackage"] && _arser["--modelfile"]) + { + std::cerr << "Cannot use both single model file and nnpackage at once." << std::endl; + exit(1); + } - std::cerr << "Model Filename " << _model_filename << std::endl; - checkModelfile(model_filename); + if (_arser["--nnpackage"]) + { + std::cout << "Package Filename " << _package_filename << std::endl; + _package_filename = _arser.get("--nnpackage"); + } - _use_single_model = true; - }; + if (_arser["--modelfile"]) + { + std::cout << "Model Filename " << _model_filename << std::endl; + _model_filename = _arser.get("--modelfile"); + } - auto process_path = [&](const std::string &path) { - struct stat sb; - if (stat(path.c_str(), &sb) == 0) + if (_arser["path"]) { - if (sb.st_mode & S_IFDIR) + auto path = _arser.get("path"); + struct stat sb; + if (stat(path.c_str(), &sb) == 0) { - _package_filename = path; - checkPackage(path); - std::cerr << "Package Filename " << path << std::endl; + if (sb.st_mode & S_IFDIR) + { + _package_filename = path; + checkPackage(path); + std::cout << "Package Filename " << path << std::endl; + } + else + { + _model_filename = path; + checkModelfile(path); + std::cout << "Model Filename " << path << std::endl; + _use_single_model = true; + } } else { - _model_filename = path; - checkModelfile(path); - std::cerr << "Model Filename " << path << std::endl; - _use_single_model = true; + std::cerr << "Cannot find: " << path << "\n"; + exit(1); } } - else + + if (_arser["--export_circle"]) + _export_circle_filename = _arser.get("--export_circle"); + if (_arser["--export_circleplus"]) + _export_circleplus_filename = _arser.get("--export_circleplus"); + if (_arser["--load_input:raw"]) { - std::cerr << "Cannot find: " << path << "\n"; - exit(1); + _load_raw_input_filename = _arser.get("--load_input:raw"); + checkModelfile(_load_raw_input_filename); } - }; - - auto process_export_circle = [&](const std::string &path) { _export_circle_filename = path; }; - auto process_export_circleplus = [&](const std::string &path) { - _export_circleplus_filename = path; - }; - - auto process_load_raw_inputfile = [&](const std::string &input_filename) { - _load_raw_input_filename = input_filename; - - std::cerr << "Model Input Filename " << _load_raw_input_filename << std::endl; - checkModelfile(_load_raw_input_filename); - }; - - auto process_load_raw_expectedfile = [&](const std::string &expected_filename) { - _load_raw_expected_filename = expected_filename; - - std::cerr << "Model Expected Filename " << _load_raw_expected_filename << std::endl; - checkModelfile(_load_raw_expected_filename); - }; - - auto process_validation_split = [&](const float v) { - if (v < 0.f || v > 1.f) + if (_arser["--load_expected:raw"]) { - std::cerr << "Invalid validation_split. Float between 0 and 1." << std::endl; - exit(1); + _load_raw_expected_filename = _arser.get("--load_expected:raw"); + checkModelfile(_load_raw_expected_filename); } - _validation_split = v; - }; - auto process_output_sizes = [&](const std::string &output_sizes_json_str) { - Json::Value root; - Json::Reader reader; - if (!reader.parse(output_sizes_json_str, root, false)) + _mem_poll = _arser.get("--mem_poll"); + _epoch = _arser.get("--epoch"); + + if (_arser["--batch_size"]) + _batch_size = _arser.get("--batch_size"); + if (_arser["--learning_rate"]) + _learning_rate = _arser.get("--learning_rate"); + if (_arser["--loss"]) + _loss_type = checkValidation("loss", valid_loss, _arser.get("--loss")); + if (_arser["--loss_reduction_type"]) + _loss_reduction_type = checkValidation("loss_reduction_type", valid_loss_rdt, + _arser.get("--loss_reduction_type")); + if (_arser["--optimizer"]) + _optimizer_type = checkValidation("optimizer", valid_optim, _arser.get("--optimizer")); + _metric_type = _arser.get("--metric"); + + _validation_split = _arser.get("--validation_split"); + if (_validation_split < 0.f || _validation_split > 1.f) { - std::cerr << "Invalid JSON format for output_sizes \"" << output_sizes_json_str << "\"\n"; + std::cerr << "Invalid validation_split. Float between 0 and 1." << std::endl; exit(1); } - auto arg_map = argArrayToMap(root); - for (auto &pair : arg_map) + _verbose_level = _arser.get("--verbose_level"); + + if (_arser["--output_sizes"]) { - uint32_t key = pair.first; - Json::Value &val_json = pair.second; - if (!val_json.isUInt()) + auto output_sizes_json_str = _arser.get("--output_sizes"); + Json::Value root; + Json::Reader reader; + if (!reader.parse(output_sizes_json_str, root, false)) { - std::cerr << "All the values in `output_sizes` must be unsigned integers\n"; + std::cerr << "Invalid JSON format for output_sizes \"" << output_sizes_json_str << "\"\n"; exit(1); } - uint32_t val = val_json.asUInt(); - _output_sizes[key] = val; - } - }; - - // General options - po::options_description general("General options", 100); - - // clang-format off - general.add_options() - ("help,h", "Print available options") - ("version", "Print version and exit immediately") - ("nnpackage", po::value()->notifier(process_nnpackage), "NN Package file(directory) name") - ("modelfile", po::value()->notifier(process_modelfile), "NN Model filename") - ("path", po::value()->notifier(process_path), "NN Package or NN Modelfile path") - ("export_circle", po::value()->notifier(process_export_circle), "Path to export circle") - ("export_circleplus", po::value()->notifier(process_export_circleplus), "Path to export circle+") - ("load_input:raw", po::value()->notifier(process_load_raw_inputfile), - "NN Model Raw Input data file\n" - "The datafile must have data for each input number.\n" - "If there are 3 inputs, the data of input0 must exist as much as data_length, " - "and the data for input1 and input2 must be held sequentially as data_length." - ) - ("load_expected:raw", po::value()->notifier(process_load_raw_expectedfile), - "NN Model Raw Expected data file\n" - "(Same data policy with load_input:raw)" - ) - ("mem_poll,m", po::value()->default_value(false)->notifier([&](const auto &v) { _mem_poll = v; }), "Check memory polling (default: false)") - ("epoch", po::value()->default_value(5)->notifier([&](const auto &v) { _epoch = v; }), "Epoch number (default: 5)") - ("batch_size", po::value()->notifier([&](const auto &v) { _batch_size = v; }), - "Batch size\n" - "If not given, model's hyper parameter is used") - ("learning_rate", po::value()->notifier([&](const auto &v) { _learning_rate = v; }), - "Learning rate\n" - "If not given, model's hyper parameter is used") - ("loss", po::value() - ->notifier([&](const auto& v){_loss_type = checkValidation("loss", valid_loss, v);}), - genHelpMsg("Loss type", valid_loss).c_str() - ) - ("loss_reduction_type", po::value() - ->notifier([&](const auto &v){_loss_reduction_type = checkValidation("loss_reduction_type", valid_loss_rdt, v);}), - genHelpMsg("Loss Reduction type", valid_loss_rdt).c_str() - ) - ("optimizer", po::value() - ->notifier([&](const auto& v){_optimizer_type = checkValidation("optimizer", valid_optim, v);}), - genHelpMsg("Optimizer type", valid_optim).c_str() - ) - ("metric", po::value()->default_value(-1)->notifier([&] (const auto &v) { _metric_type = v; }), - "Metric type\n" - "Simply calculates the metric value using the variables (default: none)\n" - "0: CATEGORICAL_ACCURACY") - ("validation_split", po::value()->default_value(0.0f)->notifier(process_validation_split), - "Float between 0 and 1(0 < float < 1). Fraction of the training data to be used as validation data.") - ("verbose_level,v", po::value()->default_value(0)->notifier([&](const auto &v) { _verbose_level = v; }), - "Verbose level\n" - "0: prints the only result. Messages btw run don't print\n" - "1: prints result and message btw run\n" - "2: prints all of messages to print") - ("output_sizes", po::value()->notifier(process_output_sizes), - "The output buffer size in JSON 1D array\n" - "If not given, the model's output sizes are used\n" - "e.g. '[0, 40, 2, 80]' to set 0th tensor to 40 and 2nd tensor to 80.") - ("num_of_trainable_ops", po::value()->notifier([&](const auto &ops_num) { _num_of_trainable_ops = ops_num; }), - "Number of the layers to be trained from the back of the model. \"-1\" means that all layers will be trained. " - "\"0\" means that no layer will be trained.") - ; - // clang-format on - - _options.add(general); - _positional.add("path", -1); -} - -void Args::Parse(const int argc, char **argv) -{ - po::variables_map vm; - po::store(po::command_line_parser(argc, argv).options(_options).positional(_positional).run(), - vm); - if (vm.count("help")) - { - std::cout << "onert_train\n\n"; - std::cout << "Usage: " << argv[0] << " [model path] []\n\n"; - std::cout << _options; - std::cout << "\n"; - - exit(0); - } - - if (vm.count("version")) - { - _print_version = true; - return; - } - - { - auto conflicting_options = [&](const std::string &o1, const std::string &o2) { - if ((vm.count(o1) && !vm[o1].defaulted()) && (vm.count(o2) && !vm[o2].defaulted())) + auto arg_map = argArrayToMap(root); + for (auto &pair : arg_map) { - throw boost::program_options::error(std::string("Two options '") + o1 + "' and '" + o2 + - "' cannot be given at once."); + uint32_t key = pair.first; + Json::Value &val_json = pair.second; + if (!val_json.isUInt()) + { + std::cerr << "All the values in `output_sizes` must be unsigned integers\n"; + exit(1); + } + uint32_t val = val_json.asUInt(); + _output_sizes[key] = val; } - }; - - // Cannot use both single model file and nnpackage at once - conflicting_options("modelfile", "nnpackage"); - - // Require modelfile, nnpackage, or path - if (!vm.count("modelfile") && !vm.count("nnpackage") && !vm.count("path")) - throw boost::program_options::error( - std::string("Require one of options modelfile, nnpackage, or path.")); - } + } - try - { - po::notify(vm); + if (_arser["--num_of_trainable_ops"]) + _num_of_trainable_ops = _arser.get("--num_of_trainable_ops"); } catch (const std::bad_cast &e) { diff --git a/tests/tools/onert_train/src/args.h b/tests/tools/onert_train/src/args.h index f9d840543a2..e260584b1ae 100644 --- a/tests/tools/onert_train/src/args.h +++ b/tests/tools/onert_train/src/args.h @@ -22,13 +22,11 @@ #include #include #include -#include +#include #include "nnfw_experimental.h" #include "types.h" -namespace po = boost::program_options; - namespace onert_train { @@ -97,8 +95,7 @@ class Args }; private: - po::positional_options_description _positional; - po::options_description _options; + arser::Arser _arser; std::string _package_filename; std::string _model_filename; diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc index 5f66e9c0823..fefb4183328 100644 --- a/tests/tools/onert_train/src/onert_train.cc +++ b/tests/tools/onert_train/src/onert_train.cc @@ -28,7 +28,6 @@ #include "rawdataloader.h" #include "metrics.h" -#include #include #include #include @@ -384,11 +383,6 @@ int main(const int argc, char **argv) return 0; } - catch (boost::program_options::error &e) - { - std::cerr << "E: " << e.what() << std::endl; - exit(-1); - } catch (std::runtime_error &e) { std::cerr << "E: Fail to run by runtime error:" << e.what() << std::endl;