Skip to content

Commit

Permalink
[luci] Add INT4 weights quantization (#12815)
Browse files Browse the repository at this point in the history
This commit adds INT4 weights quantization to QuantizeWeightsOnly.cpp.

ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov <[email protected]>

Co-authored-by: Vyacheslav Bazhenov <[email protected]>
  • Loading branch information
SlavikMIPT and Vyacheslav Bazhenov authored Mar 29, 2024
1 parent f1ea454 commit f2c6558
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 7 deletions.
18 changes: 15 additions & 3 deletions compiler/luci/pass/src/QuantizationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,26 @@ void symmetric_wquant_with_minmax_per_layer(CircleConst *node, float min, float
}
}

int32_t max_for_sym_quant(const loco::DataType &type)
{
if (type == loco::DataType::S4)
return std::numeric_limits<int8_t>::max() >> 4;
else if (type == loco::DataType::S8)
return std::numeric_limits<int8_t>::max();
else if (type == loco::DataType::S16)
return std::numeric_limits<int16_t>::max();
else
throw std::runtime_error("Unsupported dtype for symmetric quantization");
};

void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min,
float &nudged_max, loco::DataType out_type)
{
assert(min <= max);
assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16);
assert(out_type == loco::DataType::S4 || out_type == loco::DataType::S8 ||
out_type == loco::DataType::S16);

const int32_t kMaxScale = (out_type == loco::DataType::S16) ? std::numeric_limits<int16_t>::max()
: std::numeric_limits<int8_t>::max();
const int32_t kMaxScale = max_for_sym_quant(out_type);
const int32_t kMinScale = -kMaxScale;
const double qmin_double = kMinScale;
const double qmax_double = kMaxScale;
Expand Down
3 changes: 3 additions & 0 deletions compiler/luci/pass/src/QuantizationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
namespace luci
{

// Return the max value of dtype for symmetric quantization (int4/int8/int16)
int32_t max_for_sym_quant(const loco::DataType &type);

// Compute scale using given min/max for symmetric quantization (int8/int16)
void compute_sym_scale(float min, float max, float &scaling_factor, float &nudged_min,
float &nudged_max, loco::DataType out_type = loco::DataType::S16);
Expand Down
14 changes: 10 additions & 4 deletions compiler/luci/pass/src/QuantizeWeightsOnly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ void sym_wquant_per_channel(CircleConst *node, std::vector<float> &min, std::vec
std::vector<float> &nudged_max, int32_t &channel_dim_index)
{
assert(node->dtype() == loco::DataType::FLOAT32);
assert(out_type == loco::DataType::S8 || out_type == loco::DataType::S16);
const int32_t kMaxScale = (out_type == loco::DataType::S8) ? std::numeric_limits<int8_t>::max()
: std::numeric_limits<int16_t>::max();
assert(out_type == loco::DataType::S4 || out_type == loco::DataType::S8 ||
out_type == loco::DataType::S16);

const int32_t kMaxScale = max_for_sym_quant(out_type);
const int32_t kMinScale = -kMaxScale;

uint32_t size = node->size<loco::DataType::FLOAT32>();
Expand Down Expand Up @@ -163,7 +164,12 @@ void QuantizeWeightsOnly::quantize_weights(luci::CircleConst *weights)
std::vector<float> scaling_factor(min.size());
std::vector<int64_t> zp(min.size());

if (output_type == loco::DataType::S8)
if (output_type == loco::DataType::S4)
{
sym_wquant_per_channel<loco::DataType::S4>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
}
else if (output_type == loco::DataType::S8)
{
sym_wquant_per_channel<loco::DataType::S8>(weights, min, max, scaling_factor, nudged_min,
nudged_max, channel_dim_index);
Expand Down

0 comments on commit f2c6558

Please sign in to comment.