Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Vyacheslav Bazhenov committed Jan 20, 2025
1 parent 9b17a98 commit 05b2e7a
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 21 deletions.
19 changes: 16 additions & 3 deletions compiler/luci/pass/include/luci/Pass/QuantizationParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,23 @@ struct LayerInfo
QuantizationGranularity granularity;
};

enum struct QuantizationAlgorithm
enum struct QuantizationAlgorithmType
{
Common = 0,
MinimumMSE = 1,
Base = 0,
MinimumMSE = 1
};

// Bundle of settings describing which weight-quantization algorithm to use
// and how its search is tuned. Defaults select the Base algorithm.
struct QuantizationAlgorithmParams
{
  // Which quantization algorithm to run.
  QuantizationAlgorithmType type = QuantizationAlgorithmType::Base;

  // Parameters of the Golden-section search algorithm.
  // NOTE(review): presumably only consulted when type is MinimumMSE -
  // confirm against the code that consumes this struct.

  // Number of iterations of the Golden-section search.
  size_t iterations_num = 100;

  // Relative half-width of the searched scaling-factor interval:
  //   scaling_factor_max = scaling_factor_base * (1 + range)
  //   scaling_factor_min = scaling_factor_base * (1 - range)
  // Written as a float literal so no implicit double-to-float conversion occurs.
  float range = 0.1f;
};

} // namespace luci
Expand Down
6 changes: 3 additions & 3 deletions compiler/luci/pass/include/luci/Pass/QuantizeWeightsPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class QuantizeWeightsPass : public logo::Pass
loco::DataType input_model_dtype = loco::DataType::Unknown;
loco::DataType output_model_dtype = loco::DataType::Unknown;
QuantizationGranularity granularity = QuantizationGranularity::ChannelWise;
QuantizationAlgorithm algorithm = QuantizationAlgorithm::Common;
QuantizationAlgorithmParams algorithm_params;
};

public:
Expand All @@ -48,14 +48,14 @@ class QuantizeWeightsPass : public logo::Pass

public:
QuantizeWeightsPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype,
QuantizationGranularity granularity, QuantizationAlgorithm algorithm)
QuantizationGranularity granularity, QuantizationAlgorithmParams algorithm_params)
{
_ctx = std::make_unique<Context>();
{
_ctx->input_model_dtype = input_model_dtype;
_ctx->output_model_dtype = output_model_dtype;
_ctx->granularity = granularity;
_ctx->algorithm = algorithm;
_ctx->algorithm_params = algorithm_params;
}
}
virtual const char *name(void) const { return "luci::QuantizeWeightsPass"; }
Expand Down
6 changes: 3 additions & 3 deletions compiler/luci/pass/src/QuantizeWeightsOnly.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ namespace luci
struct QuantizeWeightsOnly final : public luci::CircleNodeMutableVisitor<void>
{
QuantizeWeightsOnly(loco::DataType input, loco::DataType output, QuantizationGranularity gr,
QuantizationAlgorithm alg)
: input_type(input), output_type(output), granularity(gr), algorithm(alg)
QuantizationAlgorithmParams alg_par)
: input_type(input), output_type(output), granularity(gr), algorithm_params(alg_par)
{
}

loco::DataType input_type;
loco::DataType output_type;
QuantizationGranularity granularity;
QuantizationAlgorithm algorithm;
QuantizationAlgorithmParams algorithm_params;

private:
void quantize_weights(luci::CircleConst *weights);
Expand Down
2 changes: 1 addition & 1 deletion compiler/luci/pass/src/QuantizeWeightsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ bool QuantizeWeightsPass::run(loco::Graph *g)
{
auto circle_node = loco::must_cast<luci::CircleNode *>(node);
QuantizeWeightsOnly qw(_ctx->input_model_dtype, _ctx->output_model_dtype, _ctx->granularity,
_ctx->algorithm);
_ctx->algorithm_params);
circle_node->accept(&qw);
}

Expand Down
33 changes: 22 additions & 11 deletions compiler/luci/pass/src/QuantizeWeightsPass.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ struct QuantizeWeightsPassTest : public ::testing::Test
weight->dtype(loco::DataType::FLOAT32);
weight->shape({C, H, W, C});
weight->size<loco::DataType::FLOAT32>(C * H * W * C);
for(uint32_t i = 0; i < weight->size<loco::DataType::FLOAT32>(); ++i)
{
weight->at<loco::DataType::FLOAT32>(i) = 1.0 * i;
}
conv->filter(weight);
conv->padding(luci::Padding::SAME);
conv->fusedActivationFunction(luci::FusedActFunc::NONE);
Expand All @@ -86,21 +90,24 @@ struct QuantizeWeightsPassTest : public ::testing::Test

TEST_F(QuantizeWeightsPassTest, name)
{
luci::QuantizationAlgorithmParams params;
params.type = luci::QuantizationAlgorithmType::Base;
luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8,
luci::QuantizationGranularity::ChannelWise,
luci::QuantizationAlgorithm::Common);
luci::QuantizationGranularity::ChannelWise, params);
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

TEST_F(QuantizeWeightsPassTest, name_ctx)
{
luci::QuantizationAlgorithmParams params;
params.type = luci::QuantizationAlgorithmType::Base;
auto ctx = std::make_unique<luci::QuantizeWeightsPass::Context>();
{
ctx->input_model_dtype = loco::DataType::FLOAT32;
ctx->output_model_dtype = loco::DataType::S8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
ctx->algorithm = luci::QuantizationAlgorithm::Common;
ctx->algorithm_params = params;
}

luci::QuantizeWeightsPass pass(std::move(ctx));
Expand All @@ -110,34 +117,38 @@ TEST_F(QuantizeWeightsPassTest, name_ctx)

TEST_F(QuantizeWeightsPassTest, run_minimum_mse_s8)
{
luci::QuantizationAlgorithmParams params;
params.type = luci::QuantizationAlgorithmType::MinimumMSE;
luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::S8,
luci::QuantizationGranularity::ChannelWise,
luci::QuantizationAlgorithm::MinimumMSE);
luci::QuantizationGranularity::ChannelWise, params);
pass.run(&_g);
}

TEST_F(QuantizeWeightsPassTest, run_input_U8_mse_NEG)
{
luci::QuantizationAlgorithmParams params;
params.type = luci::QuantizationAlgorithmType::MinimumMSE;
luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8,
luci::QuantizationGranularity::ChannelWise,
luci::QuantizationAlgorithm::MinimumMSE);
luci::QuantizationGranularity::ChannelWise, params);
EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

TEST_F(QuantizeWeightsPassTest, run_input_U8_NEG)
{
loco::Graph g;
luci::QuantizationAlgorithmParams params;
params.type = luci::QuantizationAlgorithmType::Base;
luci::QuantizeWeightsPass pass(loco::DataType::U8, loco::DataType::S8,
luci::QuantizationGranularity::ChannelWise,
luci::QuantizationAlgorithm::Common);
luci::QuantizationGranularity::ChannelWise, params);
EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

TEST_F(QuantizeWeightsPassTest, run_output_f32_NEG)
{
loco::Graph g;
luci::QuantizationAlgorithmParams params;
params.type = luci::QuantizationAlgorithmType::Base;
luci::QuantizeWeightsPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32,
luci::QuantizationGranularity::ChannelWise,
luci::QuantizationAlgorithm::Common);
luci::QuantizationGranularity::ChannelWise, params);
EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

0 comments on commit 05b2e7a

Please sign in to comment.