Skip to content

Commit

Permalink
rebase fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Balyshev committed Jun 17, 2024
1 parent 01d7bae commit 130afb4
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion onert-micro/eval-driver/TrainingDriver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ int entry(int argc, char **argv)
onert_micro::OMTrainingContext train_context;
train_context.batch_size = BATCH_SIZE;
train_context.num_of_train_layers = num_train_layers;
train_context.lambda = lambda;
train_context.learning_rate = lambda;
train_context.loss = loss;
train_context.optimizer = train_optim;
train_context.beta = beta;
Expand Down
1 change: 1 addition & 0 deletions onert-micro/onert-micro/include/OMConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct OMTrainingContext
float epsilon = 10e-8;
uint32_t num_step = 0;
uint32_t num_epoch = 0;
uint32_t epochs = 0;
};

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class BostonHousingTaskTest : public ::testing::Test
onert_micro::OMTrainingContext train_context;
train_context.batch_size = batch_size;
train_context.num_of_train_layers = num_train_layers;
train_context.lambda = lambda;
train_context.learning_rate = lambda;
train_context.loss = loss;
train_context.optimizer = train_optim;
train_context.beta = beta;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ OMStatus Adam::updateWeights(const onert_micro::OMTrainingContext &training_conf
return UnknownError;

float *f_weight_data = reinterpret_cast<float *>(weight_data);
float lambda = training_config.lambda;
float lambda = training_config.learning_rate;
float num_step = static_cast<float>(training_config.num_step);
float beta_in_pow_batch = std::pow(beta, num_step);
float beta_square_in_pow_batch = std::pow(beta_squares, num_step);
Expand Down
2 changes: 1 addition & 1 deletion onert-micro/onert-micro/src/train/train_optimizers/SGD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ OMStatus SGD::updateWeights(const onert_micro::OMTrainingContext &training_confi
return UnknownError;

float *f_weight_data = reinterpret_cast<float *>(weight_data);
float lambda = training_config.lambda;
float lambda = training_config.learning_rate;
const uint32_t batch_size = training_config.batch_size;
for (uint32_t i = 0; i < flat_size; ++i)
{
Expand Down

0 comments on commit 130afb4

Please sign in to comment.