Skip to content

Commit

Permalink
introduced CirclePlusGen::markAllOpsAsTrainable
Browse files Browse the repository at this point in the history
  • Loading branch information
mbencer committed Jun 14, 2024
1 parent 9a9ee93 commit 2c0b710
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 32 deletions.
5 changes: 5 additions & 0 deletions tests/nnfw_api/lib/CircleGen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ uint32_t CircleGen::nextSubgraph()
return ind;
}

uint32_t CircleGen::getCurrentSubgraphOpsSize() const
{
return _subgraph_contexts.back().operators.size();
}

CircleBuffer CircleGen::finish()
{
std::vector<flatbuffers::Offset<circle::SubGraph>> subgraphs;
Expand Down
1 change: 1 addition & 0 deletions tests/nnfw_api/lib/CircleGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class CircleGen
uint32_t addTensor(const TensorParams &params, const SparsityParams &sp);
void setInputsAndOutputs(const std::vector<int> &inputs, const std::vector<int> &outputs);
uint32_t nextSubgraph();
uint32_t getCurrentSubgraphOpsSize() const;
CircleBuffer finish();

// ===== Add Operator methods begin (SORTED IN ALPHABETICAL ORDER) =====
Expand Down
8 changes: 8 additions & 0 deletions tests/nnfw_api/lib/CirclePlusGen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#include "CirclePlusGen.h"

#include <numeric>

CircleBuffers CirclePlusGen::finish()
{
CircleBuffers cbufs;
Expand All @@ -26,6 +28,12 @@ CircleBuffers CirclePlusGen::finish()

void CirclePlusGen::addTrainInfo(const TrainInfo &info) { _info = info; }

void CirclePlusGen::markAllOpsAsTrainable()
{
_info.trainable_ops.resize(getCurrentSubgraphOpsSize());
std::iota(std::begin(_info.trainable_ops), std::end(_info.trainable_ops), 0);
}

CircleBuffer CirclePlusGen::createModelTraining()
{
circle::OptimizerOptions optimizer_opt_type = circle::OptimizerOptions::OptimizerOptions_NONE;
Expand Down
2 changes: 2 additions & 0 deletions tests/nnfw_api/lib/CirclePlusGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class CirclePlusGen : public CircleGen
public:
void addTrainInfo(const TrainInfo &info);

void markAllOpsAsTrainable();

// NOTE: this is overriden from CircleGen::finish()
CircleBuffers finish();

Expand Down
17 changes: 10 additions & 7 deletions tests/nnfw_api/src/GenModelTests/one_op_trains/Conv2D.test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace
{
CircleBuffers gen_conv2d_test_model(const std::vector<int32_t> &trainable_ops)
CirclePlusGen gen_conv2d_test_model()
{
CirclePlusGen cgen;

Expand All @@ -34,17 +34,19 @@ CircleBuffers gen_conv2d_test_model(const std::vector<int32_t> &trainable_ops)

float learning_rate = 0.01f;
int32_t batch_size = 1;
cgen.addTrainInfo(
{circle::Optimizer::Optimizer_SGD, learning_rate, circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size, trainable_ops});
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD, learning_rate,
circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size});

return cgen.finish();
return cgen;
}
} // namespace

TEST_F(GenModelTrain, OneOp_Conv2D_training_enabled)
{
_context = std::make_unique<GenModelTrainContext>(gen_conv2d_test_model({0}));
auto cgen = gen_conv2d_test_model();
cgen.markAllOpsAsTrainable();
_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->addTrainCase(
uniformTCD<float>({{{4, 0, -5, 1, 0, 4, -1, 1, -1, -3, 3, -2, -4,
1, -2, 2, 4, -4, 2, 2, 0, 4, -1, -2, 4}}}, // input dataset
Expand All @@ -62,7 +64,8 @@ TEST_F(GenModelTrain, OneOp_Conv2D_training_enabled)

TEST_F(GenModelTrain, OneOp_Conv2D_training_disabled)
{
_context = std::make_unique<GenModelTrainContext>(gen_conv2d_test_model({}));
auto cgen = gen_conv2d_test_model();
_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->addTrainCase(
uniformTCD<float>({{{4, 0, -5, 1, 0, 4, -1, 1, -1, -3, 3, -2, -4,
1, -2, 2, 4, -4, 2, 2, 0, 4, -1, -2, 4}}}, // input dataset
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@ TEST_F(GenModelTrain, OneOp_FullyConnected)

float learning_rate = 0.01f;
int32_t batch_size = 1;
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD,
learning_rate,
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD, learning_rate,
circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize,
batch_size,
{0}});
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size});
cgen.markAllOpsAsTrainable();

_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->addTrainCase(
Expand Down Expand Up @@ -67,12 +65,10 @@ TEST_F(GenModelTrain, OneOp_FullyConnected_OptionalBias)

float learning_rate = 0.01f;
int32_t batch_size = 2;
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD,
learning_rate,
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD, learning_rate,
circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize,
batch_size,
{0}});
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size});
cgen.markAllOpsAsTrainable();

_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->addTrainCase(
Expand Down Expand Up @@ -104,12 +100,10 @@ TEST_F(GenModelTrain, neg_OneOp_FullyConnected_FourOperand)

float learning_rate = 0.01f;
int32_t batch_size = 1;
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD,
learning_rate,
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD, learning_rate,
circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize,
batch_size,
{0}});
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size});
cgen.markAllOpsAsTrainable();

_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->setBackends({"train"});
Expand All @@ -134,12 +128,10 @@ TEST_F(GenModelTrain, neg_OneOp_FullyConnected_InvalidWeightShape)

float learning_rate = 0.01f;
int32_t batch_size = 1;
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD,
learning_rate,
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD, learning_rate,
circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize,
batch_size,
{0}});
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size});
cgen.markAllOpsAsTrainable();

_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->setBackends({"train"});
Expand All @@ -161,12 +153,10 @@ TEST_F(GenModelTrain, neg_OneOp_FullyConnected_NoBias)

float learning_rate = 0.01f;
int32_t batch_size = 1;
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD,
learning_rate,
cgen.addTrainInfo({circle::Optimizer::Optimizer_SGD, learning_rate,
circle::LossFn::LossFn_MEAN_SQUARED_ERROR,
circle::LossReductionType::LossReductionType_SumOverBatchSize,
batch_size,
{0}});
circle::LossReductionType::LossReductionType_SumOverBatchSize, batch_size});
cgen.markAllOpsAsTrainable();

_context = std::make_unique<GenModelTrainContext>(cgen.finish());
_context->setBackends({"train"});
Expand Down

0 comments on commit 2c0b710

Please sign in to comment.