diff --git a/build.sh b/build.sh index bf799ac8b7211..0b293effe6330 100755 --- a/build.sh +++ b/build.sh @@ -1,21 +1,24 @@ #!/bin/bash # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +set -ex # Get directory this script is in -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" OS=$(uname -s) if [ "$OS" = "Darwin" ]; then - DIR_OS="MacOS" + DIR_OS="MacOS" else - DIR_OS="Linux" + DIR_OS="Linux" fi if [[ "$*" == *"--ios"* ]]; then - DIR_OS="iOS" + DIR_OS="iOS" elif [[ "$*" == *"--android"* ]]; then - DIR_OS="Android" + DIR_OS="Android" fi -python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" +PYTHON="${PYTHON:-python3}" + +$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 66268cefac9ef..3a5575b163b35 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -382,6 +382,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index c742cd1e95bdd..f42fd54c3bc00 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -62,7 +62,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); - +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8); // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool); @@ -285,6 +286,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // add more kernels here BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -364,7 +367,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - #ifdef ENABLE_ATEN BuildKernelCreateInfo, #endif diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 69eabcfe2654a..9b9d258a9c947 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -9,13 +9,35 @@ #include "core/providers/cpu/quantization/matmul_integer_base.h" #include "core/util/math_cpuonly.h" 
#include "core/util/qmath.h" +#include +#include #include namespace onnxruntime { namespace contrib { namespace { + + +using Index = uint32_t; + +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float"))) + GeckoMatmulIntegerToFloat( + const uint8_t* a_data, + float zero_point_A, + const int8_t* input_B, + const uint8_t* zero_point_B, + uint32_t rows_A, + uint32_t width, + uint32_t cols_B, + const float* b_scale_data, + float is_b_scale_per_column, + float* output +); + + void ScaleOutput(const Tensor& scale, Tensor& output) { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { @@ -51,12 +73,65 @@ class MatMulIntegerToFloatBase : public MatMulIntegerBase { float a_scale, uint8_t a_zp, bool a_is_signed, - const Tensor* b_tensor, +const Tensor* b_tensor, const Tensor* b_scale, const Tensor* b_zp, const Tensor* bias_tensor) const; }; +void MatMulFull(const uint8_t* inputMatrixA, + const int8_t* inputMatrixB, + float* output, + size_t rowsA, + size_t width, + size_t colsB, + uint8_t zeroPointA, + const uint8_t* zeroPointB, + const float* b_scale_data, + bool is_b_scale_per_column) { + + float matrixScale = is_b_scale_per_column ? 0.0f : b_scale_data[0]; + int32_t matrixZeroPointB = is_b_scale_per_column ? 0 : static_cast(zeroPointB[0]); + + for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) { + const uint8_t* aRow = inputMatrixA + rowIndex * width; // Start of row in A + for (size_t colIndex = 0; colIndex < colsB; ++colIndex) { + int32_t tempResult = 0; + + for (size_t k = 0; k < width; ++k) { + // Row-major access + uint8_t aValue = aRow[k]; + + // Column-major access for B + int8_t bValue = inputMatrixB[k * colsB + colIndex]; + + // Adjust for zero-point offsets + int32_t adjustedA = static_cast(aValue) - static_cast(zeroPointA); + int32_t adjustedB = static_cast(bValue); + + if (is_b_scale_per_column) { + adjustedB -= static_cast(zeroPointB[colIndex]); + } else { + adjustedB -= matrixZeroPointB; + } + // Accumulate product + tempResult += adjustedA * adjustedB; + } + + float scaledResult = tempResult; + if (is_b_scale_per_column) { + scaledResult *= b_scale_data[colIndex]; + } + else { + scaledResult *= matrixScale; + } + + // Store the scaled result in y_data + output[rowIndex * colsB + colIndex] = scaledResult; + } + } +} + Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, const uint8_t* a_data, const TensorShape& a_shape, @@ -150,8 +225,107 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx, params.ldc = gemm_shape.N; } + #if 0 + std::vector y_data_2(rowsA * colsB, 0.0f); + if (rowsA > 1) { + std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n"; + std::cout << "a_zp: " << static_cast(a_zp) << "\n"; + std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n"; + std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << "\n"; + std::cout << "b_scale_data sample: ["; + for (size_t i = 0; i < 25; ++i) { + if (i > 0) std::cout << ", "; + std::cout << b_scale_data[i]; + } + std::cout << "]\n"; + std::cout << "b_zero point sample: ["; + for (size_t i = 0; i < 25; ++i) { + if (i > 0) std::cout << ", "; + std::cout << static_cast(b_zp_ptr[i]) << ", "; + } + std::cout << "]\n"; + + if (bias_data != nullptr) { + size_t bias_size = static_cast(bias_tensor->Shape().Size()); // Get the total size of bias_data + size_t display_limit = std::min(bias_size, static_cast(100)); + std::cout << "First " << display_limit << " elements of bias_data: 
["; + for (size_t i = 0; i < display_limit; ++i) { + if (i > 0) std::cout << ", "; + std::cout << bias_data[i]; + } + std::cout << "]" << std::endl; + } + std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl; + std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl; + } + #endif + //auto start = std::chrono::steady_clock::now(); + //std::cout << "Calling f32Multiply\n"; + // should split in parts and call ctx.ParallelFor just on the rows part + + // rowsA = M + // width = K + // colsB = N + size_t rowsA = static_cast(helper.M()); + size_t width = static_cast(helper.K()); + size_t colsB = static_cast(helper.N()); + + const int8_t* b_data = static_cast(b_tensor->DataRaw()); + + GeckoMatmulIntegerToFloat(a_data, + a_zp, + b_data, + b_zp_ptr, + rowsA, + width, + colsB, + b_scale_data, + is_b_scale_per_column, + y_data + ); + + // MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool()); + /* + auto end = std::chrono::steady_clock::now(); + auto duration = std::chrono::duration_cast(end - start).count(); + std::cout << "Done calling f32Multiply. Duration: " << duration << " nano\n"; + + std::cout << "Calling MlasGemmBatch\n"; + auto start2 = std::chrono::steady_clock::now(); MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool()); + auto end2 = std::chrono::steady_clock::now(); + auto duration2 = std::chrono::duration_cast(end2 - start2).count(); + std::cout << "Done calling MlasGemmBatch. Duration: " << duration2 << " nano\n"; + */ + /* + + // Compare y_data and y_data_2 + + size_t total_elements = rowsA * colsB; + size_t display_limit = std::min(total_elements, static_cast(100)); + bool mismatch_found = false; + for (size_t i = 0; i < total_elements; ++i) { + if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) { // Tolerance for floating-point comparison + std::cerr << "Mismatch at index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << std::endl; + mismatch_found = true; + break; + } + } + if (mismatch_found) { + std::cerr << "Displaying the first 100 elements of y_data and y_data_2:" << std::endl; + std::cerr << "["; + for (size_t i = 0; i < display_limit; ++i) { + std::cerr << "(Index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << ")"; + if (i != display_limit - 1) { + std::cerr << ", "; + } + } + std::cerr << "]" << std::endl; + std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl; + assert(false && "Validation failed: y_data and y_data_2 are not equal."); + } + */ return Status::OK(); } @@ -221,6 +395,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { ParQuantizeLinearStd(a_data, a_data_quant, narrow(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool()); bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? 
b->Shape() : b_shape_);
+
+  //std::cout << "dynamic quantize matmul calling ComputeCommon" << std::endl;
+
+
   ORT_RETURN_IF_ERROR(ComputeCommon(
       ctx,
       a_data_quant,
@@ -234,6 +412,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
       ctx->Input<Tensor>(IN_BIAS)));
 
   if (!is_b_scale_supported) {
+    //std::cout << "dynamic quantize matmul: b scale is not supported\n";
     ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
   }
@@ -275,6 +454,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
     a_zero_point = *(static_cast<const uint8_t*>(a_zero_point_tensor->DataRaw()));
   }
 
+  //std::cout << "matmul integer float calling ComputeCommon" << std::endl;
   const Tensor* b_zp_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
   ORT_RETURN_IF_ERROR(ComputeCommon(
       ctx,
@@ -289,9 +469,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
       ctx->Input<Tensor>(IN_BIAS)));
 
   if (!is_a_scale_scalar) {
+    //std::cout << "dynamic quantize matmul: a scale is not scalar\n";
     ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
   }
   if (!is_b_scale_supported) {
+    //std::cout << "dynamic quantize matmul: b scale is not supported\n";
     ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
   }
diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
new file mode 100644
index 0000000000000..4acbaa2a6b1fc
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
@@ -0,0 +1,212 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include <cstddef>
+#include <iostream>
+#include <vector>
+#include "firefox_matmul_integer.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/common.h"
+#include "core/util/math_cpuonly.h"
+#include "core/util/qmath.h"
+#include <chrono>  // For time measurement
+
+
+// Define aliases for convenience
+using Clock = std::chrono::high_resolution_clock;
+using Microseconds = std::chrono::microseconds;
+using Index = std::size_t;
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    FirefoxMatMulInteger8,
+    kMSDomain,
+    1,
+    uint8_t,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    FirefoxMatMulInteger8);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    FirefoxMatMulInteger8,
+    kMSDomain,
+    1,
+    int8_t,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    FirefoxMatMulInteger8);
+
+
+Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
+  const auto* a = ctx->Input<Tensor>(IN_A);
+  const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
+  uint8_t a_offset = 0;
+  const auto* a_zero_point = ctx->Input<Tensor>(IN_A_ZERO_POINT);
+
+  if (a_zero_point != nullptr) {
+    ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
+                "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
+    a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
+  }
+
+  uint8_t b_default_offset = 0;
+  const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
+  bool b_is_signed;
+  const uint8_t* b_offset_ptr = &b_default_offset;
+  bool is_b_zp_per_column = false;
+
+  if (b_zero_point != nullptr) {
+    ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_),
+                "MatmulInteger : B zero point is not valid");
+    is_b_zp_per_column = !IsScalarOr1ElementVector(b_zero_point);
+    b_offset_ptr = static_cast<const uint8_t*>(b_zero_point->DataRaw());
+  }
+
+  MatMulComputeHelper helper;
+
+  const uint8_t* b_data;
+  if (nullptr != b) {
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape(), nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
+    b_data = static_cast<const uint8_t*>(b->DataRaw());
+    b_is_signed = b->IsDataType<int8_t>();
+  } else {
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_, nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
+    b_data = static_cast<const uint8_t*>(packed_b_.get());
+    b_is_signed = b_is_signed_;
+  }
+
+
+  size_t M = static_cast<size_t>(helper.M());
+  size_t K = static_cast<size_t>(helper.K());
+  size_t N = static_cast<size_t>(helper.N());
+
+  Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
+  if (y->Shape().Size() == 0) {
+    return Status::OK();
+  }
+
+  const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
+  auto* y_data = y->MutableData<int32_t>();
+
+  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
+  gemm_shape.M = M;
+  gemm_shape.N = N;
+  gemm_shape.K = K;
+  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
+  gemm_shape.BIsSigned = b_is_signed;
+
+  const size_t batch_size = helper.OutputOffsets().size();
+
+  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
+
+  for (size_t batch = 0; batch < batch_size; batch++) {
+    auto& gemm_params = gemm_data_vec[batch];
+    gemm_params.lda = gemm_shape.K;
+    gemm_params.ZeroPointA = a_offset;
+    gemm_params.ldb = gemm_shape.N;
+    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
+    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
+    gemm_params.ldc = gemm_shape.N;
+    gemm_params.BIsPacked = bool(packed_b_);
+    gemm_params.A = a_data + helper.LeftOffsets()[batch];
+    gemm_params.B = b_data + helper.RightOffsets()[batch];
+    gemm_params.C = y_data + helper.OutputOffsets()[batch];
+  }
+  #if 0
+  std::cout << "Matrix A (sample):\n";
+  for (size_t i = 0; i < 5; ++i) {
+    for (size_t j = 0; j < 5; ++j) {
+      std::cout << static_cast<int>(a_data[i * helper.K() + j]) << " ";
+    }
+    std::cout << "\n";
+  }
+  std::cout << "\n";
+  std::cout << "Matrix B (sample):\n";
+  for (size_t i = 0; i < 5; ++i) {
+    for (size_t j = 0; j < 5; ++j) {
+      std::cout << static_cast<int>(b_data[i * helper.N() + j]) << " ";
+    }
+    std::cout << "\n";
+  }
+
+  std::cout << "b_zero_point content: \n";
+  if (b_zero_point != nullptr) {
+    size_t b_zero_point_size = static_cast<size_t>(b_zero_point->Shape()[0]);
+    const uint8_t* b_zp_data = static_cast<const uint8_t*>(b_zero_point->DataRaw());
+    for (size_t i = 0; i < b_zero_point_size; ++i) {
+      std::cout << static_cast<int>(b_zp_data[i]) << " ";
+    }
+    std::cout << "\n";
+  } else {
+    std::cout << "b_zero_point is null\n";
+  }
+  #endif
+  //auto start_matmul = Clock::now();
+  /*
+  int8Multiply(
+      reinterpret_cast<const uint8_t*>(a->DataRaw()),
+      a_offset,
+      reinterpret_cast<const int8_t*>(b->DataRaw()),
+      reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
+      M,
+      K,
+      N,
+      reinterpret_cast<float*>(y_data)
+  );
+  */
+  //auto end_matmul = Clock::now();
+  //auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();
+
+  // rowsA = M
+  // width = K
+  // colsB = N
+  #if 0
+  for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) {
+    const uint8_t* aRow = inputMatrixAPtr + rowIndex * width;  // Start of row in A
+    for (size_t colIndex = 0; colIndex < colsB; ++colIndex) {
+      int32_t tempResult = 0;
+
+      for (size_t k = 0; k < width; ++k) {
+        // Row-major access
+        uint8_t aValue = aRow[k];
+
+        // Column-major access for B
+        int8_t bValue = inputMatrixBPtr[k * colsB + colIndex];
+
+        // Adjust 
for zero-point offsets + int32_t adjustedA = static_cast(aValue) - static_cast(a_offset); + int32_t adjustedB = static_cast(bValue); // - static_cast(b_offset_ptr[colIndex]); + + // Accumulate product + tempResult += adjustedA * adjustedB; + } + + // Write result to the output array + outputPtr[rowIndex * colsB + colIndex] = tempResult; + } + } + + // Mlas (will fallback if we don't meet requirements) + auto start_mblas = Clock::now(); + MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); + auto end_mblas = Clock::now(); + auto mblas_time = std::chrono::duration_cast(end_mblas - start_mblas).count(); + // Output timing results + std::cout << "Timing (microseconds):\n"; + std::cout << "MatMulFull: " << matmul_time << "\n"; + std::cout << "MlasGemmBatch: " << mblas_time << "\n"; + +#endif + MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h new file mode 100644 index 0000000000000..26a8aedfe1531 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/quantization/matmul_integer_base.h" +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" + +namespace onnxruntime { +namespace contrib { + +class FirefoxMatMulInteger8 final : public MatMulIntegerBase { + public: + FirefoxMatMulInteger8(const OpKernelInfo& info) : MatMulIntegerBase(info) {} + Status Compute(OpKernelContext* context) const override; + + enum InputTensors : int { + IN_A = 0, + IN_B = 1, + IN_A_ZERO_POINT = 2, + IN_B_ZERO_POINT = 3 + }; + + enum OutputTensors : int { OUT_Y = 0 }; + + protected: + int GetBIdx() const override { return IN_B; } +}; + +/** + * Headers for the gemmology functions + */ +#ifdef __EMSCRIPTEN__ +#include + + + +#include + +#if 0 +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("int8_multiply"))) + int8Multiply(const uint8_t* input_A, + float zero_point_A, + const int8_t* input_B, + const uint8_t* zero_point_B, + float rows_A, + float width, + float cols_B, + float* output); +#endif + +#endif + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0d0b22ff61e01..0571942a2b30a 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1100,7 +1100,10 @@ static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_ver NodePlacementMap& node_placements, NodePlacementSet& node_placement_provider_set) { for (const auto& node : graph.Nodes()) { + const auto& node_provider = node.GetExecutionProviderType(); + printf("%s(%s), provider: %s\n", node.OpType().c_str(), node.Name().c_str(), node_provider.c_str()); + if (node_provider.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Could not find an implementation for ", diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 09a4a77780916..02bfed2b9660f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ 
b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1980,6 +1980,56 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-
         ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
       }));
+
+constexpr const char* FirefoxMatMulInteger_doc = R"DOC(
+Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
+)DOC";
+
+
+
+ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger8, 1,
+                            OpSchema()
+                                .SetDoc(FirefoxMatMulInteger_doc)
+                                .Input(0, "A", "N-dimensional matrix A", "T1")
+                                .Input(1, "B", "N-dimensional matrix B", "T2")
+                                .Input(2, "a_zero_point",
+                                       "Zero point tensor for input 'A'. It's optional and default value is 0. It could be a scalar or a 1-D "
+                                       "tensor, "
+                                       "which means a per-tensor or per-row quantization. If it's a 1-D tensor, its number "
+                                       "of elements should be equal to the number of rows of input 'A'.",
+                                       "T1", OpSchema::Optional)
+                                .Input(3, "b_zero_point",
+                                       "Zero point tensor for input 'B'. It's optional and default value is 0. It could be a scalar or a 1-D "
+                                       "tensor, "
+                                       "which means a per-tensor or per-column quantization. If it's a 1-D tensor, its number "
+                                       "of elements should be equal to the number of columns of input 'B'.",
+                                       "T2", OpSchema::Optional)
+                                .Output(0, "Y", "Matrix multiply results from A * B", "T3")
+                                .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data types as 8-bit integer tensor")
+                                .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data types as 8-bit integer tensor")
+                                .TypeConstraint("T3",
+                                                {"tensor(int32)", "tensor(uint32)"},
+                                                "Constrain output Y data types as 32-bit integer tensor. "
+                                                "T3 must be tensor(uint32) when both T1 and T2 are tensor(uint8), "
+                                                "or must be tensor(int32) when either T1 or T2 is tensor(int8).")
+                                .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+                                  auto a_type = ctx.getInputType(0);
+                                  auto b_type = ctx.getInputType(1);
+                                  auto y_type = ctx.getOutputType(0);
+                                  if (nullptr == a_type || nullptr == b_type || nullptr == y_type ||
+                                      a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
+                                      b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
+                                    fail_type_inference(
+                                        "inputs are expected to have tensor type and output type should not be null.");
+                                  }
+
+                                  // Right now we only support int32
+                                  y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32);
+
+                                  ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
+                                }));
+
+
 /**
  * @brief Shape inference for MatMul with right hand side matrix quantized into int4
  * @param ctx
@@ -3780,6 +3830,7 @@ Having this op allows runtime to do operator re-ordering to reduce compute FLOPs
 #endif
 
+
 #ifndef _OPSCHEMA_LIB_
   // Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) { diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a9a89f756b071..5d98743f4078d 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -79,6 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif @@ -189,6 +190,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); #ifndef ORT_MINIMAL_BUILD fn(GetOpSchema()); #endif diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 859fcd049ac7d..026a1215af42c 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -17,6 +17,7 @@ Module Name: #include "mlasi.h" #include "qgemm.h" +#include // // Define the parameters to execute segments of a QGEMM operation on worker @@ -144,6 +145,8 @@ MlasGemmBatch( const double Complexity = double(M) * double(N) * double(K) * double(BatchN); + //std::cout << "Complexity: " << Complexity << std::endl; + ptrdiff_t TargetThreadCount; if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { @@ -194,10 +197,16 @@ MlasGemmBatch( WorkBlock.ThreadCountN = 1; } TargetThreadCount = ThreadsPerGemm * BatchN; + //std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << std::endl; + //std::cout << "TargetThreadCount: " << TargetThreadCount << std::endl; + //std::cout << "MaximumThreadCount: " << MaximumThreadCount << std::endl; + + MlasTrySimpleParallel(ThreadPool, TargetThreadCount, [&](ptrdiff_t tid) { const auto gemm_i = tid / ThreadsPerGemm; const auto blk_i = tid % ThreadsPerGemm; + //std::cout << "gemm_i: " << gemm_i << " blk_i: " << blk_i << std::endl; MlasGemmQuantThreaded(&WorkBlock, &Shape, &DataParams[gemm_i], blk_i); }); } @@ -277,6 +286,13 @@ MlasSymmQgemmBatch( const size_t ThreadCountM = MlasDivRoundup(M, StrideM); const size_t ThreadCountN = MlasDivRoundup(N, StrideN); ThreadsPerGemm = ThreadCountM * ThreadCountN; + + /* + std::cout << "ThreadsPerGemm" << ThreadsPerGemm << std::endl; + std::cout << "TargetThreadCount " < #include +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("mlas_gemm_u8x8"))) + xMlasGemmU8X8MultiplyAccumulateRowWasmSimd( + const float* A, + const float* B, + const float* C + ); + + // // Define the default striding parameters used for the quantized integer // matrix/matrix multiply operation. 
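Editor's note (illustration, not part of the patch): the extern "C" declarations added in this patch (GeckoMatmulIntegerToFloat, GeckoDequantizeLinear, xMlasGemm, xMlasGemmU8X8MultiplyAccumulateRowWasmSimd) use Emscripten's import_module/import_name attributes, so the symbols are resolved from the "wasm_gemm" import namespace when the module is instantiated (see the instantiateWasm hook in onnxruntime/wasm/pre.js near the end of this patch). The sketch below shows one hedged way such a declaration could be guarded so non-Emscripten builds still link, using a scalar fallback that mirrors the MatMulFull reference loop from dynamic_quantize_matmul.cc; the #else branch and its body are illustrative assumptions, not code from this patch.

// Hypothetical guard for a "wasm_gemm" host import (sketch only).
#include <cstdint>

#if defined(__EMSCRIPTEN__)
// Resolved at instantiation time from the "wasm_gemm" import namespace.
extern "C" void
    __attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
    GeckoMatmulIntegerToFloat(const uint8_t* a_data, float zero_point_A,
                              const int8_t* input_B, const uint8_t* zero_point_B,
                              uint32_t rows_A, uint32_t width, uint32_t cols_B,
                              const float* b_scale_data, float is_b_scale_per_column,
                              float* output);
#else
// Scalar fallback with the same signature, mirroring MatMulFull from this patch,
// so native builds and unit tests can exercise the same call site.
extern "C" void GeckoMatmulIntegerToFloat(const uint8_t* a_data, float zero_point_A,
                                          const int8_t* input_B, const uint8_t* zero_point_B,
                                          uint32_t rows_A, uint32_t width, uint32_t cols_B,
                                          const float* b_scale_data, float is_b_scale_per_column,
                                          float* output) {
  const bool per_column = is_b_scale_per_column != 0.0f;
  for (uint32_t m = 0; m < rows_A; ++m) {
    for (uint32_t n = 0; n < cols_B; ++n) {
      int32_t acc = 0;
      for (uint32_t k = 0; k < width; ++k) {
        // Subtract the A and B zero points before accumulating, as MatMulFull does.
        int32_t va = static_cast<int32_t>(a_data[m * width + k]) - static_cast<int32_t>(zero_point_A);
        int32_t vb = static_cast<int32_t>(input_B[k * cols_B + n]) -
                     static_cast<int32_t>(per_column ? zero_point_B[n] : zero_point_B[0]);
        acc += va * vb;
      }
      // Apply the per-column or per-tensor B scale to produce the float output.
      output[m * cols_B + n] = static_cast<float>(acc) * (per_column ? b_scale_data[n] : b_scale_data[0]);
    }
  }
}
#endif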
diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp index 1f33d77adf4b9..a8e563155ff4e 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_wasmsimd.cpp @@ -17,6 +17,7 @@ Module Name: #include "mlasi.h" #include "qgemm.h" + // wasm implementation of "_mm_unpacklo_epi8" v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i8x16_unpacklo(v128_t a, v128_t b) { return wasm_i8x16_shuffle(a, b, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23); @@ -324,23 +325,58 @@ MlasGemmQuantCopyPackB( MLAS_FORCEINLINE void -MlasGemmU8X8MultiplyAccumulateRowWasmSimd( +localMlasGemmU8X8MultiplyAccumulateRowWasmSimd( v128_t ABroadcast, const int16_t* B, v128_t Accumulators[2] ) { - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = wasm_v128_load(&B[8]); + v128_t BElements0 = wasm_v128_load(&B[0]); + v128_t BElements1 = wasm_v128_load(&B[8]); - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); + Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); + Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); } -template<> -size_t -MlasGemmQuantKernel( +#include +#include + + + +MLAS_FORCEINLINE +void MlasGemmU8X8MultiplyAccumulateRowWasmSimd( + v128_t ABroadcast, + const int16_t* B, + v128_t Accumulators[2] +) { + + v128_t BElements0 = wasm_v128_load(&B[0]); + v128_t BElements1 = wasm_v128_load(&B[8]); + + Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); + Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); +} + +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("mlas_gemm"))) + xMlasGemm( + const uint8_t* A, + const uint8_t* B, + int32_t* C, + float PackedCountK, + float CountM, + float CountN, + float ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + float ZeroMode + ); + + + +size_t MlasGemmWasm( const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, int32_t* C, @@ -351,11 +387,7 @@ MlasGemmQuantKernel( const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode - ) -{ - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); + bool ZeroMode) { while (CountN > 0) { @@ -499,6 +531,58 @@ MlasGemmQuantKernel( return 1; } +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + //MLAS_UNREFERENCED_PARAMETER(CountM); + //MLAS_UNREFERENCED_PARAMETER(ldc); + //std::cout << "Calling MlasGemmQuantKernel" << std::endl; + + MlasGemmWasm( + A, + B, + C, + PackedCountK, + CountM, + CountN, + ldc, + RowSumBuffer, + ColumnSumBuffer, + ZeroPointB, + ZeroMode + ); + + /* + xMlasGemm( + reinterpret_cast(A), + reinterpret_cast(B), + reinterpret_cast(C), + static_cast(PackedCountK), + static_cast(CountM), + static_cast(CountN), + static_cast(ldc), + RowSumBuffer, + ColumnSumBuffer, + 
ZeroPointB, + static_cast(ZeroMode)); + */ + return 1; +} + + const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { MlasGemmQuantOperation, nullptr, diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 3d3e831a12d13..ed482c4f82eff 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -11,6 +11,20 @@ #include "core/mlas/inc/mlas.h" #include "core/util/qmath.h" + +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("onnx_dequantize_linear"))) +GeckoDequantizeLinear( + float M, + float K, + float N, + uint32_t input, + uint32_t scale, + uint32_t output, + uint32_t zero_point); + + + namespace onnxruntime { template @@ -265,6 +279,8 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( template struct DequantizeLinearApply; + + // The dimensions before quantize axis and after quantize axis can be flattened. // After flattening, the tensor can be represented by a rank-3 tensor. // If the quantization happens on the first or last axis, the flattened tensor is @@ -284,8 +300,13 @@ struct DequantizeLinearApply { * @param[out] output same shape as input * @param[in] zero_point same shape as scale */ + + // T= uint8, OutT ==float void op(size_t M, size_t K, size_t N, const T* input, const OutT* scale, OutT* output, const T* zero_point) { + // 1 1 136134656 + // 1 1 896 + //auto start = std::chrono::high_resolution_clock::now(); for (size_t m = 0; m < M; m++) { for (size_t k = 0; k < K; k++) { auto zp = zero_point ? static_cast(zero_point[k]) : 0; @@ -295,6 +316,8 @@ struct DequantizeLinearApply { } } } + //auto end = std::chrono::high_resolution_clock::now(); + //std::cout << "time: " << std::chrono::duration_cast(end - start).count() << std::endl; } /** @@ -463,6 +486,7 @@ DEQUANTIZE_LINEAR_APPLY_FLOAT8(Float8E5M2FNUZ) #endif + // formula is Y = (X - ZeroPoint) * Scale template Status DequantizeLinear::Compute(OpKernelContext* ctx) const { @@ -508,10 +532,25 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { static_cast(block_size_), input, scale, output, zero_point); } else { + //auto start = std::chrono::high_resolution_clock::now(); + if (process_block_size > 1000) { + + + GeckoDequantizeLinear(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + reinterpret_cast(input), + reinterpret_cast(scale), + reinterpret_cast(output), + reinterpret_cast(zero_point)); + } else { + //auto end = std::chrono::high_resolution_clock::now(); + //std::cout << std::chrono::duration_cast(end - start).count() << " micros" << std::endl; DequantizeLinearApply().op(static_cast(process_block_count), static_cast(broadcast_dim), static_cast(process_block_size), input, scale, output, zero_point); + } } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); diff --git a/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc b/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc new file mode 100644 index 0000000000000..3b8e079ec705a --- /dev/null +++ b/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" + +namespace onnxruntime { +namespace test { + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_1) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {1, 1}, {15}); + test.AddInput("T2", {1, 1}, {8}); + test.AddOutput("T3", {1, 1}, {120}); // Result is 15 * 8 + test.Run(); +} + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_2) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {1, 2}, {-7, 10}); + test.AddInput("T2", {2, 1}, {-8, -11}); + test.AddOutput("T3", {1, 1}, {8}); // Result is (-7 * -8) + (10 * -11) + test.Run(); +} + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_Empty_input) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {0, 2}, {}); + test.AddInput("T2", {2, 1}, {-8, -11}); + test.AddOutput("T3", {0, 1}, {}); // Empty input produces an empty output + test.Run(); +} + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_3) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {3, 2}, {-7, 10, 10, -113, 22, -36}); + test.AddInput("T2", {2, 4}, {-8, -11, 13, 14, -9, 12, 3, -6}); + test.AddOutput("T3", {3, 4}, + {-158, 97, -61, -2, // First row results + 989, -1426, 1693, 1682, // Second row results + 282, -518, 280, -372}); // Third row results + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 7f4616c964e33..8d8227acb48ab 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1329,8 +1329,11 @@ TEST(InferenceSessionTests, TestOptionalInputs) { "Invalid input name"); // missing required + printf("here"); + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(false, true, false, version, sess_env), (version == 3 ? "Invalid input name" : "Missing Input:")); + printf("here 2"); } } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 45e2475548df5..d23200a389100 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -3,13 +3,13 @@ 'use strict'; + // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. // // This file will only be used in build with flag `--use_jsep`. - /** * initialize JSEP for asyncify support. 
*/ @@ -109,7 +109,7 @@ let jsepInitAsync = () => { if (Module.jsepSessionState) { throw new Error('Session already started'); } - const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; + const state = Module.jsepSessionState = { sessionHandle: args[0], errors: [] }; // Run the acyncified function: OrtRun() or OrtRunWithBinding() const ret = await runAsyncFunc(...args); @@ -141,21 +141,21 @@ let jsepInitAsync = () => { // replace the original functions with asyncified versions Module['_OrtCreateSession'] = jsepWrapAsync( - Module['_OrtCreateSession'], - () => Module['_OrtCreateSession'], - v => Module['_OrtCreateSession'] = v); + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); Module['_OrtRun'] = runAsync(jsepWrapAsync( - Module['_OrtRun'], - () => Module['_OrtRun'], - v => Module['_OrtRun'] = v)); + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( - Module['_OrtRunWithBinding'], - () => Module['_OrtRunWithBinding'], - v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); Module['_OrtBindInput'] = jsepWrapAsync( - Module['_OrtBindInput'], - () => Module['_OrtBindInput'], - v => Module['_OrtBindInput'] = v); + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); // remove this function to make sure it is called only once. jsepInitAsync = undefined; @@ -170,16 +170,16 @@ Module['jsepInit'] = (name, params) => { if (name === 'webgpu') { [Module.jsepBackend, - Module.jsepAlloc, - Module.jsepFree, - Module.jsepCopy, - Module.jsepCopyAsync, - Module.jsepCreateKernel, - Module.jsepReleaseKernel, - Module.jsepRunKernel, - Module.jsepCaptureBegin, - Module.jsepCaptureEnd, - Module.jsepReplay] = params; + Module.jsepAlloc, + Module.jsepFree, + Module.jsepCopy, + Module.jsepCopyAsync, + Module.jsepCreateKernel, + Module.jsepReleaseKernel, + Module.jsepRunKernel, + Module.jsepCaptureBegin, + Module.jsepCaptureEnd, + Module.jsepReplay] = params; // expose webgpu backend functions const backend = Module.jsepBackend; @@ -211,11 +211,11 @@ Module['jsepInit'] = (name, params) => { // change the name. [Module.jsepBackend, - Module.jsepReserveTensorId, - Module.jsepReleaseTensorId, - Module['jsepEnsureTensor'], - Module.jsepUploadTensor, - Module['jsepDownloadTensor'], + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module['jsepEnsureTensor'], + Module.jsepUploadTensor, + Module['jsepDownloadTensor'], ] = params; // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. @@ -243,7 +243,7 @@ Module['jsepInit'] = (name, params) => { }; Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { return backend['registerMLConstant']( - externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); }; } }; diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js index 9b5f3ce545b78..47854471783ac 100644 --- a/onnxruntime/wasm/pre.js +++ b/onnxruntime/wasm/pre.js @@ -49,4 +49,33 @@ Module['unmountExternalData'] = () => { * @suppress {checkVars} */ var SharedArrayBuffer = globalThis.SharedArrayBuffer ?? 
-    new WebAssembly.Memory({'initial': 0, 'maximum': 0, 'shared': true}).buffer.constructor;
+    new WebAssembly.Memory({ 'initial': 0, 'maximum': 0, 'shared': true }).buffer.constructor;
+
+
+/**
+ * Custom call to instantiate the WebAssembly module, so we can use custom imports.
+ */
+Module["instantiateWasm"] = async (info, receiveInstance) => {
+  const wasmBinaryFile = findWasmBinary();
+  const bytes = await getBinaryPromise(wasmBinaryFile);
+  const module = await WebAssembly.compile(bytes);
+  let imports = getWasmImports();
+
+  const OPTIMIZED_GEMM = "mozIntGemm";
+  const optimizedGemmModule = WebAssembly[OPTIMIZED_GEMM];
+  const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), {
+    "": {
+      memory: wasmMemory
+    }
+  }).exports;
+
+  imports.wasm_gemm = optimizedGemmModuleExports;
+
+  try {
+    var instance = new WebAssembly.Instance(module, imports);
+    receiveInstance(instance);
+  } catch (error) {
+    console.error("Error creating WebAssembly instance:", error);
+    throw error;
+  }
+};
+
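Editor's note (illustration, not part of the patch): the instantiateWasm override above is what makes every __attribute__((import_module("wasm_gemm"), ...)) declaration in this patch resolvable; it instantiates Firefox's privileged mozIntGemm module against the shared wasmMemory and exposes its exports under the wasm_gemm import namespace. A hypothetical consolidated header such as the sketch below (the file name and guard are assumptions) could gather the import signatures that are currently declared separately in quantize_linear.cc and firefox_matmul_integer.h, so the C++ side stays in sync with the JS-provided exports.

// firefox_wasm_gemm_imports.h -- hypothetical aggregation header (sketch only).
// Signatures are copied from declarations that appear elsewhere in this patch.
#pragma once
#include <cstdint>

#if defined(__EMSCRIPTEN__)
// Blocked dequantization (quantize_linear.cc); pointers are passed as wasm32 addresses.
extern "C" void
    __attribute__((import_module("wasm_gemm"), import_name("onnx_dequantize_linear")))
    GeckoDequantizeLinear(float M, float K, float N,
                          uint32_t input, uint32_t scale, uint32_t output,
                          uint32_t zero_point);

// Raw int8 multiply (currently disabled behind #if 0 in firefox_matmul_integer.h).
extern "C" void
    __attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
    int8Multiply(const uint8_t* input_A, float zero_point_A,
                 const int8_t* input_B, const uint8_t* zero_point_B,
                 float rows_A, float width, float cols_B,
                 float* output);
#endif  // __EMSCRIPTEN__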