Skip to content

Commit

Permalink
[luci/pass] Remove Unused Constructor and Update Test Cases
Browse files Browse the repository at this point in the history
This commit removes the unused constructor `QuantizeDequantizeWeightsWithGPTQPass(loco::DataType input_model_dtype, loco::DataType output_model_dtype, QuantizationGranularity granularity)` and updates the test cases to use the remaining constructors.

ONE-DCO-1.0-Signed-off-by: y01000.you <[email protected]>
  • Loading branch information
y01000.you committed Dec 18, 2024
1 parent 2a75faa commit 77d43d2
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ class QuantizeDequantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVi
QuantizationGranularity _granularity;
std::unordered_map<const luci::CircleNode *, std::vector<float>> *_hessian_map;

void fake_quantize(luci::CircleConst *weights) const
void fake_quantize(luci::CircleConst *weights)
{
if (_granularity != luci::QuantizationGranularity::ChannelWise)
{
Expand Down Expand Up @@ -577,7 +577,7 @@ class QuantizeDequantizeWeightsWithGPTQ final : public luci::CircleNodeMutableVi
weights->quantparam(std::move(quantparam));
}

void fake_quantize_with_gptq(luci::CircleConst *weights, std::vector<float> &hessian) const
void fake_quantize_with_gptq(luci::CircleConst *weights, std::vector<float> &hessian)
{
if (_granularity != luci::QuantizationGranularity::ChannelWise)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ struct QuantizeWeightsWithGPTQPassTest : public ::testing::Test

TEST_F(QuantizeWeightsWithGPTQPassTest, name)
{
luci::QuantizeDequantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
luci::QuantizationGranularity::ChannelWise);
auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;

luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}
Expand All @@ -87,40 +89,23 @@ TEST_F(QuantizeWeightsWithGPTQPassTest, name_ctx)
ctx->output_model_dtype = loco::DataType::U8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;

luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx));
luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
auto const name = pass.name();
ASSERT_NE(nullptr, name);
}

// Negative test: Unsupported granularity - Invalid value
TEST_F(QuantizeWeightsWithGPTQPassTest, run_granularity_invalid_NEG)
{
auto invalid_granularity = static_cast<luci::QuantizationGranularity>(999);
luci::QuantizeDequantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
invalid_granularity);
ASSERT_EXIT(pass.run(&_g), ::testing::KilledBySignal(SIGSEGV), ".*");
}

// Negative test: Unsupported output data type - FLOAT32
TEST_F(QuantizeWeightsWithGPTQPassTest, run_output_f32_NEG)
{
luci::QuantizeDequantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::FLOAT32,
luci::QuantizationGranularity::ChannelWise);
// Since output type is FLOAT32, an exception is expected
EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

// Negative test: Provide an empty hessian map
TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_empty_hessian_map_NEG)
{
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;
auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
{
ctx->input_model_dtype = loco::DataType::FLOAT32;
ctx->output_model_dtype = loco::DataType::U8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;

luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
// Expect no exception, pass should handle empty hessian map gracefully
Expand All @@ -147,15 +132,30 @@ TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_non_float_weights_NEG)
// Set dtype to INT32
weight->dtype(loco::DataType::S32);

luci::QuantizeDequantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::U8,
luci::QuantizationGranularity::ChannelWise);
auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
{
ctx->input_model_dtype = loco::DataType::FLOAT32;
ctx->output_model_dtype = loco::DataType::U8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;

luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
// The pass should skip this node without exception
EXPECT_NO_THROW(pass.run(&_g));
}

// Positive test: Run pass with valid hessian map
TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_valid_hessian)
{

auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
{
ctx->input_model_dtype = loco::DataType::FLOAT32;
ctx->output_model_dtype = loco::DataType::U8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}

// Create a hessian map with valid data
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;
// Find the conv node
Expand All @@ -181,29 +181,37 @@ TEST_F(QuantizeWeightsWithGPTQPassTest, run_with_valid_hessian)

hessian_map[conv] = hessian;

auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
{
ctx->input_model_dtype = loco::DataType::FLOAT32;
ctx->output_model_dtype = loco::DataType::U8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}

luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
EXPECT_NO_THROW(pass.run(&_g));
}

// Negative test: Input model data type is U8 (unsupported)
TEST_F(QuantizeWeightsWithGPTQPassTest, run_input_U8_NEG)
{
luci::QuantizeDequantizeWeightsWithGPTQPass pass(loco::DataType::U8, loco::DataType::U8,
luci::QuantizationGranularity::ChannelWise);

auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
{
ctx->input_model_dtype = loco::DataType::U8;
ctx->output_model_dtype = loco::DataType::U8;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;

luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

// Negative test: Output model data type is S32 (unsupported)
TEST_F(QuantizeWeightsWithGPTQPassTest, run_output_S32_NEG)
{
luci::QuantizeDequantizeWeightsWithGPTQPass pass(loco::DataType::FLOAT32, loco::DataType::S32,
luci::QuantizationGranularity::ChannelWise);

auto ctx = std::make_unique<luci::QuantizeDequantizeWeightsWithGPTQPass::Context>();
{
ctx->input_model_dtype = loco::DataType::FLOAT32;
ctx->output_model_dtype = loco::DataType::S32;
ctx->granularity = luci::QuantizationGranularity::ChannelWise;
}
std::unordered_map<const luci::CircleNode *, std::vector<float>> hessian_map;
luci::QuantizeDequantizeWeightsWithGPTQPass pass(std::move(ctx), &hessian_map);
EXPECT_THROW(pass.run(&_g), std::runtime_error);
}

0 comments on commit 77d43d2

Please sign in to comment.