Skip to content

Commit

Permalink
[onert/test] Get training info from session in onert_train
Browse files Browse the repository at this point in the history
This PR adds logic of getting training information from the session in onert_train.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn <[email protected]>
  • Loading branch information
zetwhite committed Jan 5, 2024
1 parent f165eaf commit 4164b7c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,11 @@ int main(const int argc, char **argv)
return "accuracy";
};

// prepare training info
// get training information
nnfw_train_info tri;
NNPR_ENSURE_STATUS(nnfw_train_get_traininfo(session, &tri));

// overwrite training information using the arguments
tri.batch_size = args.getBatchSize();
tri.learning_rate = args.getLearningRate();
tri.loss_info.loss = convertLossType(args.getLossType());
Expand All @@ -186,6 +189,7 @@ int main(const int argc, char **argv)
std::cout << tri;
std::cout << "========================" << std::endl;

// set training information
NNPR_ENSURE_STATUS(nnfw_train_set_traininfo(session, &tri));

// prepare execution
Expand Down

0 comments on commit 4164b7c

Please sign in to comment.