From 63d7cf6f608c3490b3d9b6bb7928225845416698 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tarek=20Ziad=C3=A9?=
Date: Sun, 22 Dec 2024 16:03:03 +0100
Subject: [PATCH] naive impl

---
 .../quantization/dynamic_quantize_matmul.cc   | 157 +++++++++++++++++-
 1 file changed, 156 insertions(+), 1 deletion(-)

diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
index 69eabcfe2654a..b108d9e9f907c 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
@@ -9,6 +9,7 @@
 #include "core/providers/cpu/quantization/matmul_integer_base.h"
 #include "core/util/math_cpuonly.h"
 #include "core/util/qmath.h"
+#include <iostream>
 
 #include <algorithm>
 
@@ -16,6 +17,20 @@
 namespace onnxruntime {
 namespace contrib {
 namespace {
+
+
+using Index = uint32_t;
+extern "C" void
+    __attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
+    int8Multiply(const uint8_t* input_A,
+                 float zero_point_A,
+                 const int8_t* input_B,
+                 //const uint8_t* zero_point_B,
+                 Index rows_A,
+                 Index width,
+                 Index cols_B,
+                 float* output);
+
 void ScaleOutput(const Tensor& scale, Tensor& output) {
   ProcessBroadcastSpanFuncs funcs{
       [](BroadcastHelper& per_iter_bh) {
@@ -51,12 +66,65 @@ class MatMulIntegerToFloatBase : public MatMulIntegerBase {
                        float a_scale,
                        uint8_t a_zp,
                        bool a_is_signed,
-                       const Tensor* b_tensor,
+const Tensor* b_tensor,
                        const Tensor* b_scale,
                        const Tensor* b_zp,
                        const Tensor* bias_tensor) const;
 };
 
+void MatMulFull(const uint8_t* inputMatrixA,
+                const int8_t* inputMatrixB,
+                float* output,
+                size_t rowsA,
+                size_t width,
+                size_t colsB,
+                uint8_t zeroPointA,
+                const uint8_t* zeroPointB,
+                const float* b_scale_data,
+                bool is_b_scale_per_column) {
+
+  float matrixScale = is_b_scale_per_column ? 0.0f : b_scale_data[0];
+  int32_t matrixZeroPointB = is_b_scale_per_column ? 0 : static_cast<int32_t>(zeroPointB[0]);
+
+  for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) {
+    const uint8_t* aRow = inputMatrixA + rowIndex * width;  // Start of row in A
+    for (size_t colIndex = 0; colIndex < colsB; ++colIndex) {
+      int32_t tempResult = 0;
+
+      for (size_t k = 0; k < width; ++k) {
+        // Row-major access
+        uint8_t aValue = aRow[k];
+
+        // Column-major access for B
+        int8_t bValue = inputMatrixB[k * colsB + colIndex];
+
+        // Adjust for zero-point offsets
+        int32_t adjustedA = static_cast<int32_t>(aValue) - static_cast<int32_t>(zeroPointA);
+        int32_t adjustedB = static_cast<int32_t>(bValue);
+
+        if (is_b_scale_per_column) {
+          adjustedB -= static_cast<int32_t>(zeroPointB[colIndex]);
+        } else {
+          adjustedB -= matrixZeroPointB;
+        }
+        // Accumulate product
+        tempResult += adjustedA * adjustedB;
+      }
+
+      float scaledResult = tempResult;
+      if (is_b_scale_per_column) {
+        scaledResult *= b_scale_data[colIndex];
+      }
+      else {
+        scaledResult *= matrixScale;
+      }
+
+      // Store the scaled result in y_data
+      output[rowIndex * colsB + colIndex] = scaledResult;
+    }
+  }
+}
+
 Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
                                                const uint8_t* a_data,
                                                const TensorShape& a_shape,
@@ -150,11 +218,95 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
       params.ldc = gemm_shape.N;
     }
 
+  // rowsA = M
+  // width = K
+  // colsB = N
+  size_t rowsA = static_cast<size_t>(helper.M());
+  size_t width = static_cast<size_t>(helper.K());
+  size_t colsB = static_cast<size_t>(helper.N());
+  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());
+
+  #if 0
+  size_t total_elements = rowsA * colsB;
+  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
+  std::vector<float> y_data_2(rowsA * colsB, 0.0f);
+  std::cout << "Calling MatMulFull with the following parameters:\n";
+  std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n";
+  std::cout << "a_zp: " << static_cast<int>(a_zp) << "\n";
+  std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n";
+  std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << "\n";
+  std::cout << "b_scale_data sample: [";
+  for (size_t i = 0; i < 25; ++i) {
+    if (i > 0) std::cout << ", ";
+    std::cout << b_scale_data[i];
+  }
+  std::cout << "]\n";
+  std::cout << "b_zero point sample: [";
+  for (size_t i = 0; i < 25; ++i) {
+    if (i > 0) std::cout << ", ";
+    std::cout << static_cast<int>(b_zp_ptr[i]) << ", ";
+  }
+  std::cout << "]\n";
+
+  if (bias_data != nullptr) {
+    size_t bias_size = static_cast<size_t>(bias_tensor->Shape().Size());  // Get the total size of bias_data
+    size_t display_limit = std::min(bias_size, static_cast<size_t>(100));
+    std::cout << "First " << display_limit << " elements of bias_data: [";
+    for (size_t i = 0; i < display_limit; ++i) {
+      if (i > 0) std::cout << ", ";
+      std::cout << bias_data[i];
+    }
+    std::cout << "]" << std::endl;
+  }
+  std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl;
+  std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
+  #endif
+
+  MatMulFull(a_data, b_data, y_data_2.data(), rowsA, width, colsB, a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);
+
+  #if 0
   MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
 
+  // Compare y_data and y_data_2
+  bool mismatch_found = false;
+  for (size_t i = 0; i < total_elements; ++i) {
+    if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) {  // Tolerance for floating-point comparison
+      std::cerr << "Mismatch at index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << std::endl;
+      mismatch_found = true;
+      break;
+    }
+  }
+
+  if (mismatch_found) {
+    std::cerr << "Displaying the first 100 elements of y_data and y_data_2:" << std::endl;
+    std::cerr << "[";
+    for (size_t i = 0; i < display_limit; ++i) {
+      std::cerr << "(Index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << ")";
+      if (i != display_limit - 1) {
+        std::cerr << ", ";
+      }
+    }
+    std::cerr << "]" << std::endl;
+    std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl;
+    assert(false && "Validation failed: y_data and y_data_2 are not equal.");
+  }
+  #endif
 
   return Status::OK();
 }
 
+/*
+  int8Multiply(
+      reinterpret_cast<const uint8_t*>(a_data),
+      a_zp,
+      b_data,
+      //reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
+      rowsA,
+      width,
+      colsB,
+      reinterpret_cast<float*>(y_data)
+  );
+*/
+
 class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
  public:
   DynamicQuantizeMatMul(const OpKernelInfo& info) : MatMulIntegerToFloatBase(info) {}
@@ -234,6 +386,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
                                 ctx->Input<Tensor>(IN_BIAS)));
 
   if (!is_b_scale_supported) {
+    std::cout << "dynamic quantize matmul: b scale is not supported\n";
     ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
   }
 
@@ -289,9 +442,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
                                 ctx->Input<Tensor>(IN_BIAS)));
 
   if (!is_a_scale_scalar) {
+    std::cout << "dynamic quantize matmul: a scale is not scalar\n";
     ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
   }
   if (!is_b_scale_supported) {
+    std::cout << "dynamic quantize matmul: b scale is not supported\n";
     ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
   }