Commit c78bf8e

[luci-interpreter] Support Hybrid INT4 weights FullyConnected kernel (Samsung#12801)

This commit adds support for INT4-quantized weights in the FullyConnected kernel.

ONE-DCO-1.0-Signed-off-by: Vyacheslav Bazhenov <[email protected]>

Co-authored-by: Vyacheslav Bazhenov <[email protected]>
SlavikMIPT and Vyacheslav Bazhenov authored Mar 27, 2024
1 parent 41000f9 commit c78bf8e
Showing 4 changed files with 126 additions and 14 deletions.
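At a high level, the hybrid path dequantizes the INT4 weights into a float scratch tensor and then reuses the existing float kernel. A minimal sketch of that idea, assuming sign-extended INT4 codes and a single per-tensor scale (the diff below shows the actual kernel):

#include <cstdint>
#include <vector>

// Dequantize sign-extended INT4 weight codes (range [-8, 7], stored one per
// int8_t) to float with a single per-tensor scale; a float GEMM can then
// run on the result unchanged.
std::vector<float> dequantize_int4(const std::vector<int8_t> &codes, float scale)
{
  std::vector<float> out(codes.size());
  for (size_t i = 0; i < codes.size(); ++i)
    out[i] = scale * static_cast<float>(codes[i]); // real_value = scale * q
  return out;
}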
67 changes: 55 additions & 12 deletions compiler/luci-interpreter/src/kernels/FullyConnected.cpp
@@ -54,6 +54,13 @@ void FullyConnected::configure()
LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::S8);
LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::S32);
}
else if (weights()->element_type() == DataType::S4)
{
// TODO support other combinations when needed
LUCI_INTERPRETER_CHECK(input()->element_type() == DataType::FLOAT32);
LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
LUCI_INTERPRETER_CHECK(!bias() || bias()->element_type() == DataType::FLOAT32);
}
else
{
throw std::runtime_error("luci-intp FullyConnected(1) Unsupported type.");
@@ -89,19 +96,30 @@ void FullyConnected::configure()

void FullyConnected::execute() const
{
const bool is_hybrid =
  (input()->element_type() == DataType::FLOAT32 && weights()->element_type() == DataType::S4 &&
   output()->element_type() == DataType::FLOAT32 &&
   (!bias() || bias()->element_type() == DataType::FLOAT32));
if (is_hybrid)
{
  evalHybridWI4AF32();
}
else
{
switch (input()->element_type())
{
case DataType::U8:
evalQuantized();
break;
case DataType::S8:
evalQuantizedS8();
break;
case DataType::FLOAT32:
evalFloat();
break;
default:
throw std::runtime_error("luci-intp FullyConnected(2) Unsupported type.");
}
}
}

@@ -188,5 +206,30 @@ void FullyConnected::evalQuantizedS8() const
getTensorShape(output()), getTensorData<int8_t>(output()));
}

void FullyConnected::evalHybridWI4AF32() const
{
float activation_min{};
float activation_max{};
calculateActivationRange(_params.activation, &activation_min, &activation_max);

tflite::FullyConnectedParams params{};
params.float_activation_min = activation_min;
params.float_activation_max = activation_max;
params.weights_format = tflite::FullyConnectedWeightsFormat::kDefault;

const int8_t *weights_int4 = getTensorData<int8_t>(weights());
float *weights_float = getTensorData<float>(scratch());
const Shape &weights_shape = weights()->shape();
for (int32_t i = 0; i < weights_shape.num_elements(); ++i)
{
// Dequantize: INT4 is sign-extended into the int8_t (1 bit for sign, 3 bits for value)
weights_float[i] = weights()->scale() * weights_int4[i];
}
tflite::reference_ops::FullyConnected(
params, getTensorShape(input()), getTensorData<float>(input()), getTensorShape(scratch()),
getTensorData<float>(scratch()), getTensorShape(bias()), getTensorData<float>(bias()),
getTensorShape(output()), getTensorData<float>(output()));
}

} // namespace kernels
} // namespace luci_interpreter
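For intuition, dequantization is real_value = scale * q, where q is the sign-extended INT4 code in [-8, 7]. A minimal self-check, assuming the per-tensor scale of 0.5 used in the test below:

#include <cassert>

int main()
{
  const float scale = 0.5f;       // per-tensor scale, as in the SimpleS4 test below
  assert(scale * -8.0f == -4.0f); // lowest INT4 code
  assert(scale * 7.0f == 3.5f);   // highest INT4 code
  assert(scale * -1.0f == -0.5f); // the -1 weights in the test dequantize to -0.5
  return 0;
}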
10 changes: 9 additions & 1 deletion compiler/luci-interpreter/src/kernels/FullyConnected.h
@@ -30,11 +30,17 @@ class FullyConnected : public KernelWithParams<FullyConnectedParams>
public:
FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output,
const FullyConnectedParams &params);

FullyConnected(const Tensor *input, const Tensor *weights, const Tensor *bias, Tensor *output,
Tensor *scratch, const FullyConnectedParams &params)
: FullyConnected(input, weights, bias, output, params)
{
_scratch = scratch;
}
const Tensor *input() const { return _inputs[0]; }
const Tensor *weights() const { return _inputs[1]; }
const Tensor *bias() const { return _inputs[2]; }
Tensor *output() const { return _outputs[0]; }
Tensor *scratch() const { return _scratch; }

void configure() override;
void execute() const override;
@@ -43,6 +49,8 @@ class FullyConnected : public KernelWithParams<FullyConnectedParams>
void evalFloat() const;
void evalQuantized() const;
void evalQuantizedS8() const;
void evalHybridWI4AF32() const;
Tensor *_scratch = nullptr;
};

} // namespace kernels
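For callers, the overload is a drop-in extension of the existing constructor: pass a FLOAT32 scratch tensor shaped like the weights alongside the usual arguments. A hypothetical call (tensor setup omitted; the SimpleS4 test below is a complete example):

// Hypothetical usage of the new scratch-aware constructor; 'scratch' is
// assumed to be a FLOAT32 tensor with the same shape as 'weights'.
kernels::FullyConnected kernel(&input, &weights, &bias, &output, &scratch, params);
kernel.configure();
kernel.execute();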
50 changes: 50 additions & 0 deletions compiler/luci-interpreter/src/kernels/FullyConnected.test.cpp
@@ -154,6 +154,56 @@ TYPED_TEST(FullyConnectedTest, Simple)
});
}

TEST(FullyConnectedTest, SimpleS4)
{
std::initializer_list<int32_t> input_shape{1, 2};
std::initializer_list<int32_t> weights_shape{4, 2};
std::initializer_list<int32_t> bias_shape{4};
std::initializer_list<int32_t> output_shape{1, 4};
std::initializer_list<float> input_data{
1, 3, // batch = 0
};
std::initializer_list<int8_t> weights_initializer{
0, 1, // unit = 0
0, 0, // unit = 1
-1, -1, // unit = 2
0, 0, // unit = 3
};
std::initializer_list<float> bias_data{0, 1, 2, 3};
std::initializer_list<float> output_data{
1.5, 1, 0, 3, // batch = 0
};
std::unique_ptr<IMemoryManager> memory_manager = std::make_unique<TestMemoryManager>();

Tensor input_tensor =
makeInputTensor<DataType::FLOAT32>(input_shape, input_data, memory_manager.get());
std::vector<int8_t> quantized_data(weights_initializer);
Tensor weights_tensor(DataType::S4, weights_shape, {{0.5}, {0}}, "");
memory_manager->allocate_memory(weights_tensor);
weights_tensor.writeData(quantized_data.data(), quantized_data.size() * sizeof(int8_t));
Tensor weights_scratch(DataType::FLOAT32, weights_shape, {}, "");
memory_manager->allocate_memory(weights_scratch);

Tensor bias_tensor =
makeInputTensor<DataType::FLOAT32>(bias_shape, bias_data, memory_manager.get());
Tensor output_tensor = makeOutputTensor(DataType::FLOAT32);

const float quantized_tolerance = getTolerance(-8, 7, 15);

FullyConnectedParams params{};
params.activation = Activation::RELU;

FullyConnected kernel(&input_tensor, &weights_tensor, &bias_tensor, &output_tensor,
&weights_scratch, params);
kernel.configure();
memory_manager->allocate_memory(output_tensor);
kernel.execute();

EXPECT_THAT(extractTensorShape(output_tensor), ::testing::ElementsAreArray(output_shape));
EXPECT_THAT(extractTensorData<float>(output_tensor),
FloatArrayNear(output_data, quantized_tolerance));
}
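The expected outputs follow directly from the data above: dequantize the weights with scale 0.5, take the dot product with the input [1, 3], add the bias, and apply RELU:

unit 0: relu((1*0 + 3*1) * 0.5 + 0) = 1.5
unit 1: relu((1*0 + 3*0) * 0.5 + 1) = 1
unit 2: relu((1*(-1) + 3*(-1)) * 0.5 + 2) = relu(0) = 0
unit 3: relu((1*0 + 3*0) * 0.5 + 3) = 3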

TEST(FullyConnectedTest, InvalidBiasType_NEG)
{
Shape input_shape{3, 2, 2, 1};
13 changes: 12 additions & 1 deletion compiler/luci-interpreter/src/loader/nodes/FullyConnected.cpp
@@ -35,7 +35,18 @@ std::unique_ptr<Kernel> build_kernel_CircleFullyConnected(const luci::CircleNode
FullyConnectedParams params{};
params.activation = node->fusedActivationFunction();
params.keep_num_dims = node->keep_num_dims();

if (weights->element_type() == loco::DataType::S4)
{
auto scratchpad =
std::make_unique<Tensor>(input->element_type(), weights->shape(), AffineQuantization{}, "");
scratchpad->set_observable(false);
scratchpad->set_data_buffer(nullptr);
Tensor *scratchpad_tmp =
helper.getRuntimeGraph(node->graph())->addTensor(std::move(scratchpad));
helper.getRuntimeGraph(node->graph())->configureAllocations(scratchpad_tmp);
return std::make_unique<kernels::FullyConnected>(input, weights, bias, output, scratchpad_tmp,
params);
}
return std::make_unique<kernels::FullyConnected>(input, weights, bias, output, params);
}
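Design note: the scratchpad is created with the input's element type (FLOAT32 on the hybrid path) but the weights' shape, since it holds the dequantized copy of the weights. Marking it non-observable keeps it out of the graph's observable outputs, and registering it through addTensor/configureAllocations hands its lifetime to the runtime graph rather than to the kernel.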

