From 387d5c80f248163350d8ab6fecb7f94fb3eded3c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tarek=20Ziad=C3=A9?=
Date: Wed, 11 Dec 2024 14:16:16 +0100
Subject: [PATCH] use the real values

---
 .../quantization/firefox_matmul_integer.cc    | 74 ++++++++++---------
 onnxruntime/wasm/pre.js                       |  4 +-
 2 files changed, 43 insertions(+), 35 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
index 8c71c88b25a51..9877191a37106 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
@@ -131,45 +131,50 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
 
 #ifdef __EMSCRIPTEN__
   //MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
-  // moz gemmology will be called here...
-  Index rows_A = 4;
-  Index width = 64;   // Must be a multiple of 64
-  Index cols_B = 8;   // Must be a multiple of 8
-
-  // Generate example data for A and B
-  std::vector<int8_t> A(rows_A * width, 1);  // Example data for matrix A
-  std::vector<int8_t> B(width * cols_B, 1);  // Example data for matrix B
-  std::vector<float> bias(cols_B, 0.0f);     // Example bias, set to 0
-  // Prepare output buffer
-  std::vector<float> output(rows_A * cols_B, 0.0f);
-
-  // Quantization parameters
-  float scale_A = 0.1f;             // Example scale factor for A
-  float zero_point_A = 0.0f;        // Example zero point for A
-  float scale_B = 0.2f;             // Example scale factor for B
-  float zero_point_B = 0.0f;        // Example zero point for B
-  float unquant_multiplier = 1.0f;  // Example multiplier
+  std::vector<float> float_output(helper.M() * helper.N(), 0.0f);
 
   // Call the function
-  int8MultiplyAndAddBias(A.data(),
-                         scale_A,
-                         zero_point_A,
-                         B.data(),
-                         scale_B,
-                         zero_point_B,
-                         bias.data(),
-                         unquant_multiplier,
-                         rows_A,
-                         width,
-                         cols_B,
-                         output.data());
+  // matrix A (M x K) * matrix B (K x N)
+  // yields matrix C (M x N)
+  size_t rows_a = static_cast<size_t>(helper.M());
+  size_t cols_b = static_cast<size_t>(helper.N());
+  size_t width = static_cast<size_t>(helper.K());
+
+  int8MultiplyAndAddBias(reinterpret_cast<const int8_t*>(a_data),
+                         1.0f,     // scale factor for A
+                         a_offset,
+                         reinterpret_cast<const int8_t*>(b_data),
+                         1.0f,     // scale factor for B
+                         0,        // b_zero_point
+                         0,        // we don't have any bias
+                         1.0f,     // quantization multiplier
+                         rows_a,   // rows A
+                         width,    // width
+                         cols_b,   // cols B
+                         float_output.data());
+
+  // temporarily convert to int32
+  size_t num_elements = rows_a * cols_b;
+
+  for (size_t i = 0; i < num_elements; ++i) {
+    // Round in double, then clamp to the int32 range *before* the cast:
+    // casting an out-of-range floating-point value to int32_t is
+    // undefined behavior, so clamping after the cast would be too late.
+    double rounded = std::round(static_cast<double>(float_output[i]));
+    rounded = std::clamp(rounded, -2147483648.0, 2147483647.0);
+    y_data[i] = static_cast<int32_t>(rounded);
+  }
 
   // Print the output
+
   std::cout << "Output matrix:\n";
-  for (Index i = 0; i < rows_A; ++i) {
-    for (Index j = 0; j < cols_B; ++j) {
-      std::cout << output[i * cols_B + j] << " ";
+  for (size_t i = 0; i < rows_a; ++i) {
+    for (size_t j = 0; j < cols_b; ++j) {
+      std::cout << y_data[i * cols_b + j] << " ";
     }
     std::cout << "\n";
   }
@@ -204,6 +209,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
     std::cout << std::endl;  // Move to the next row
   }
 
+  // The result is dequantized to float here...
   gemmology::Shift::Multiply(
       reinterpret_cast<const uint8_t*>(casted_a_data),
       casted_b_data,
@@ -213,6 +219,8 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
       gemmology::callbacks::Write(reinterpret_cast<float*>(y_data))
   );
 
+  // ...but we want int32.
+
   //
   // Get the shape of the tensor
   std::cout << "y data result:" << std::endl;

diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js
index 2d786a844f259..d80a2011059f4 100644
--- a/onnxruntime/wasm/pre.js
+++ b/onnxruntime/wasm/pre.js
@@ -112,8 +112,8 @@ function fallbackGemm(gemmToFallbackFunctionsMap) {
   var INITIAL_MEMORY = 16777216;
   var gemmWasmMemory = new WebAssembly.Memory({
     "initial": INITIAL_MEMORY / 65536,
-    "maximum": 4294967296 / 65536,
-    "shared": false
+    "maximum": 65536,  // maximum number of 64 KiB pages: 65536 pages = 4 GiB
+    "shared": true     // shared memories must declare an explicit maximum
   });
   const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), {
     "": {
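
For context on the first hunk: int8MultiplyAndAddBias is expected to compute an int8 matmul of A (M x K) by B (K x N) and dequantize the result into a float buffer. The sketch below is a naive plain-C++ reference for that arithmetic under the usual asymmetric-quantization convention; the helper name and the exact parameter semantics (scales, zero points, unquantization multiplier) are assumptions mirroring the call site, not gemmology's implementation. With both scales and the multiplier at 1.0 and no bias, as in the patch, the output is just the raw accumulator stored as float, which is why the kernel can round it straight back to int32 afterwards.

#include <cstddef>
#include <cstdint>

// Hypothetical reference (NOT the gemmology kernel): row-major A (M x K),
// row-major B (K x N), optional bias of length N, float output C (M x N).
void naive_int8_matmul(const int8_t* A, float scale_a, float zp_a,
                       const int8_t* B, float scale_b, float zp_b,
                       const float* bias, float unquant_multiplier,
                       size_t M, size_t K, size_t N, float* C) {
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        acc += (static_cast<float>(A[m * K + k]) - zp_a) *
               (static_cast<float>(B[k * N + n]) - zp_b);
      }
      // Dequantize: the two scales fold into one factor, then apply the
      // extra multiplier and the optional per-column bias.
      C[m * N + n] = acc * scale_a * scale_b * unquant_multiplier +
                     (bias != nullptr ? bias[n] : 0.0f);
    }
  }
}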
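
The float-to-int32 narrowing in the same hunk has one subtlety worth keeping in mind: the clamp has to happen while the value is still floating point, because casting an out-of-range float to int32_t is undefined behavior, and clamping the already-cast integer is a no-op. A minimal self-contained sketch, assuming a float source buffer and an int32 destination as in the patch:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Round float GEMM output to int32 with saturation. Rounding and clamping
// are done in double, where the int32 bounds are exactly representable.
void float_to_int32(const float* src, int32_t* dst, size_t n) {
  for (size_t i = 0; i < n; ++i) {
    double v = std::round(static_cast<double>(src[i]));
    v = std::clamp(v, -2147483648.0, 2147483647.0);
    dst[i] = static_cast<int32_t>(v);
  }
}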
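
On the pre.js hunk: WebAssembly memory is sized in 64 KiB pages, so the new literal maximum of 65536 pages is the same 4 GiB the old expression (4294967296 / 65536) computed; the substantive change is "shared": true, since shared memories are required to declare a maximum. The compile-time checks below just pin down that page arithmetic:

#include <cstdint>

constexpr std::uint64_t kWasmPageBytes = 65536;    // 64 KiB per wasm page
constexpr std::uint64_t kMaxPages = 65536;         // "maximum" in pre.js
constexpr std::uint64_t kInitialBytes = 16777216;  // INITIAL_MEMORY in pre.js

static_assert(kMaxPages * kWasmPageBytes == 4294967296ULL,
              "65536 pages == 4 GiB");
static_assert(kInitialBytes / kWasmPageBytes == 256,
              "initial memory == 256 pages");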