From 97ea8feaa6b98c4bd9bf6ae5d1d8623a29d870d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Mon, 2 Dec 2024 10:09:49 +0100 Subject: [PATCH] added a firefox matmul backend --- build.sh | 15 +- cmake/onnxruntime_webassembly.cmake | 1 + .../contrib_ops/cpu/cpu_contrib_kernels.cc | 6 +- .../quantization/firefox_matmul_integer.cc | 178 ++++++++++ .../cpu/quantization/firefox_matmul_integer.h | 309 ++++++++++++++++++ onnxruntime/core/framework/session_state.cc | 3 + .../core/graph/contrib_ops/contrib_defs.cc | 51 +++ onnxruntime/core/graph/contrib_ops/ms_opset.h | 2 + .../firefox_matmul_integer_test.cc | 50 +++ .../test/framework/inference_session_test.cc | 3 + onnxruntime/wasm/pre-jsep.js | 60 ++-- onnxruntime/wasm/pre.js | 91 +++++- 12 files changed, 730 insertions(+), 39 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h create mode 100644 onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc diff --git a/build.sh b/build.sh index bf799ac8b7211..0b293effe6330 100755 --- a/build.sh +++ b/build.sh @@ -1,21 +1,24 @@ #!/bin/bash # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +set -ex # Get directory this script is in -DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" OS=$(uname -s) if [ "$OS" = "Darwin" ]; then - DIR_OS="MacOS" + DIR_OS="MacOS" else - DIR_OS="Linux" + DIR_OS="Linux" fi if [[ "$*" == *"--ios"* ]]; then - DIR_OS="iOS" + DIR_OS="iOS" elif [[ "$*" == *"--android"* ]]; then - DIR_OS="Android" + DIR_OS="Android" fi -python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" +PYTHON="${PYTHON:-python3}" + +$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@" diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 66268cefac9ef..3a5575b163b35 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -382,6 +382,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index c742cd1e95bdd..f42fd54c3bc00 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -62,7 +62,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); - +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8); // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
@@ -285,6 +286,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,  // default entry to avoid the list become empty after ops-reducing
       BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8)>,
       // add more kernels here
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -364,7 +367,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-
 #ifdef ENABLE_ATEN
       BuildKernelCreateInfo,
 #endif
diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
new file mode 100644
index 0000000000000..5ae79111d5c78
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
@@ -0,0 +1,178 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <iostream>
+
+#include "firefox_matmul_integer.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/common.h"
+#include "core/util/math_cpuonly.h"
+#include "core/util/qmath.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    FirefoxMatMulInteger8,
+    kMSDomain,
+    1,
+    uint8_t,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    FirefoxMatMulInteger8);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    FirefoxMatMulInteger8,
+    kMSDomain,
+    1,
+    int8_t,
+    kCpuExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    FirefoxMatMulInteger8);
+
+
+
+/** Typical Call
+
+Input Tensor A shape: {1,171,1024}
+Input Tensor B shape: {1024,1024}
+A Zero Point shape: {}
+A Zero Point value: 123
+B Zero Point shape: {1024}
+B Zero Point is per-column: 1
+Computing helper with A and B shapes.
+Output Tensor Y shape: {1,171,1024}
+GEMM Shape - M: 171, N: 1024, K: 1024, AIsSigned: 0, BIsSigned: 1
+Batch size: 1
+
+*/
+Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
+  std::cout << "FirefoxMatMulInteger8::Compute started" << std::endl;
+  const auto* a = ctx->Input<Tensor>(IN_A);
+  const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
+
+  // Validate zero points
+  uint8_t a_offset = 0;
+  const auto* a_zero_point = ctx->Input<Tensor>(IN_A_ZERO_POINT);
+  if (a_zero_point != nullptr) {
+    ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
+                "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
+    a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
+  }
+
+  bool is_b_zp_per_column = false;
+  uint8_t b_default_offset = 0;
+  const uint8_t* b_offset_ptr = &b_default_offset;
+  const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
+  if (b_zero_point != nullptr) {
+    ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_),
+                "MatmulInteger : B zero point is not valid");
+    is_b_zp_per_column = !IsScalarOr1ElementVector(b_zero_point);
+    b_offset_ptr = static_cast<const uint8_t*>(b_zero_point->DataRaw());
+  }
+
+  MatMulComputeHelper helper;
+  const uint8_t* b_data;
+  bool b_is_signed;
+  if (nullptr != b) {
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape(), nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
+    b_data = static_cast<const uint8_t*>(b->DataRaw());
+    b_is_signed = b->IsDataType<int8_t>();
+  } else {
+    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_, nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
+    b_data = static_cast<const uint8_t*>(packed_b_.get());
+    b_is_signed = b_is_signed_;
+  }
+
+  Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
+  if (y->Shape().Size() == 0) {
+    return Status::OK();
+  }
+  const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
+  auto* y_data = y->MutableData<int32_t>();
+
+
+#ifdef __EMSCRIPTEN__
+
+  float alpha = 25;
+  float unquant_mult_forprep = (-1)*(alpha)*(alpha)/(127.0f);
+  //float quant_mult = 127/alpha;
+
+/*
+  // prepare the two matrices
+  // we need to import `intgemm/aligned.h`
+  // XXX speeds up because the memory start is a multiple of 8
+  AlignedVector<int8_t> A_prepared(a_data.size());
+  AlignedVector<int8_t> B_prepared(b_data.size());
+
+  int8PrepareA(a_data, A_prepared, quant_mult, A_rows, width);
+  int8PrepareB(b_data, B_prepared, quant_mult, width, B_cols);
+ */
+  // XXX I don't think I need to add a bias here
+  // Multiply
+  // take uint8 for a and int8 for b
+  std::cout << "int8MultiplyAndAddBias" << std::endl;
+
+  int8MultiplyAndAddBias(
+      // XXX this will work as long as a_data is aligned
+      // look at the memory address printf %p multiple of 8
+      reinterpret_cast<const int8_t*>(a_data),
+      1.0,  // scale_A,
+      a_offset,
+      reinterpret_cast<const int8_t*>(b_data),
+      1.0,  // scale_B,
+      0,    // b_offset, // XXX
+      0,    // input_bias_prepared,
+      unquant_mult_forprep,
+      1,    // A_rows,
+      static_cast<Index>(helper.K()),  // width,
+      1,    // B_cols,
+      reinterpret_cast<float*>(y_data)
+  );
+
+  std::cout << "int8MultiplyAndAddBias done" << std::endl;
+
+  return Status::OK();
+#endif
+
+  // MLAS fallback.
+  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
+  gemm_shape.M = static_cast<size_t>(helper.M());
+  gemm_shape.N = static_cast<size_t>(helper.N());
+  gemm_shape.K = static_cast<size_t>(helper.K());
+  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
+  gemm_shape.BIsSigned = b_is_signed;
+
+  const size_t batch_size = helper.OutputOffsets().size();
+
+  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);
+
+  for (size_t batch = 0; batch < batch_size; batch++) {
+    auto& gemm_params = gemm_data_vec[batch];
+    gemm_params.lda = gemm_shape.K;
+    gemm_params.ZeroPointA = a_offset;
+    gemm_params.ldb = gemm_shape.N;
+    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
+    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
+    gemm_params.ldc = gemm_shape.N;
+    gemm_params.BIsPacked = bool(packed_b_);
+    gemm_params.A = a_data + helper.LeftOffsets()[batch];
+    gemm_params.B = b_data + helper.RightOffsets()[batch];
+    gemm_params.C = y_data + helper.OutputOffsets()[batch];
+  }
+
+  std::cout << "Calling MlasGemmBatch."
<< std::endl; + MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); + std::cout << "Exiting FirefoxMatMulInteger8::Compute" << std::endl; + return Status::OK(); +} + + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h new file mode 100644 index 0000000000000..d20a1aa123397 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h @@ -0,0 +1,309 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/quantization/matmul_integer_base.h" +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" + +namespace onnxruntime { +namespace contrib { + +class FirefoxMatMulInteger8 final : public MatMulIntegerBase { + public: + FirefoxMatMulInteger8(const OpKernelInfo& info) : MatMulIntegerBase(info) {} + Status Compute(OpKernelContext* context) const override; + + enum InputTensors : int { + IN_A = 0, + IN_B = 1, + IN_A_ZERO_POINT = 2, + IN_B_ZERO_POINT = 3 + }; + + enum OutputTensors : int { OUT_Y = 0 }; + + protected: + int GetBIdx() const override { return IN_B; } +}; + +/** + * Headers for the gemmology functions + */ +#ifdef __EMSCRIPTEN__ +#include + + +/** Main interface for integer matrix multiplication followed by addition of bias for wasm. + * + * C = A * B + Bias + * + * Input matrix A: + * - is a 2-D matrix that typically represents activations as floating point values + * - no. of rows should be a multiple of 1 (i.e. no restriction) + * - no. of columns should be a multiple of 64 + * - is represented as array (contiguous memory locations) in row-major format + * + * Input matrix B: + * - is a 2-D matrix that typically represents fixed model parameters as floating point values + * - no. of rows should be: + * -- equal to no. of columns of Input matrix A + * -- a multiple of 64 + * - no. of columns should be a multiple of 8 + * - is represented as array (contiguous memory locations) in row-major format + * + * Please note that it is also possible to pass Input matrix B in 2 more forms: + * - One that is already a quantized and transposed version of Input matrix B + * - Other that is already a transposed version of Input matrix B + * + * Input Bias: + * - is an array (contiguous memory locations) that represents bias + * - size of the array should be equal to the no. of columns of Input matrix B + * + * Output matrix C: + * - is a 2-D matrix that represents the result (= A * B + Bias) + * - no. of rows will be equal to no. of rows of Input matrix A + * - no. of columns will be equal to no. of columns of Input matrix B (in untransposed form) + * - is represented as array (contiguous memory locations) in row-major format + * + * Please note that most of the functions in this interface might have architecture specific + * implementations. + * + * Conventions followed throughout this file: + * - Unless explicitly mentioned, Input matrix B always means an unquantized (i.e. float values) + * and non-transposed version + * - no. of rows of Input matrix A = `rows_A` + * - no. of columns of Input matrix A = no. of rows of Input matrix B = `width` + * - no. of columns of Input matrix B = `cols_B` + */ + +#include + +using Index = uint32_t; + +/** + * Prepare B for the Matrix Multiply function from Input matrix B. 
+ * + * Quantization is performed on the input. + * The final prepared B is in CPU-dependent format and can be used as an input to matrix multiply + * function (`int8MultiplyAndAddBias`). + * + * Please note that this interface might have architecture specific implementation. + * + * @param[in] input_B An array representing the Input matrix B in row-major format. + * Size of the array = `width` * `cols_B`. + * Shape of the matrix: (`width`, `cols_B`) + * @param[in] scale The scaling factor (for quantization) + * @param[in] zero_point The zero point (for quantization) + * @param[in] width No. of rows of Input matrix B. It should be a multiple of 64. + * @param[in] cols_B No. of columns of Input matrix B. It should be a multiple of 8. + * @param[out] output An array representing the prepared B matrix. + * Size of the array = `width` * `cols_B`. + */ +extern "C" void __attribute__((import_module("wasm_gemm"), import_name("int8_prepare_b"))) +int8PrepareB(const float* input_B, + float scale, + float zero_point, + Index width, + Index cols_B, + int8_t* output); + +/** + * Prepare B for the Matrix Multiply function from transposed version of Input matrix B. + * + * Quantization is performed on floating values of input. + * The final prepared B is in CPU-dependent format and can be used as an input to matrix multiply + * function (`int8MultiplyAndAddBias`). + * + * Please note that this interface might have architecture specific implementation. + * + * @param[in] input_B_transposed An array representing transposed version of Input matrix B. + * It is in column-major format. + * Size of the array = `width` * `cols_B`. + * Shape of the matrix: (`cols_B`, `width`) + * @param[in] scale The scaling factor (for quantization) + * @param[in] zero_point The zero point (for quantization) + * @param[in] width No. of rows of Input matrix B. It should be a multiple of 64. + * @param[in] cols_B No. of columns of Input matrix B. Should be a multiple of 8. + * @param[out] output An array representing the prepared B matrix. + * Size of the array = `width` * `cols_B`. + */ +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("int8_prepare_b_from_transposed"))) + int8PrepareBFromTransposed(const float* input_B_transposed, + float scale, + float zero_point, + Index width, + Index cols_B, + int8_t* output); + +/** + * Prepare B for the Matrix Multiply function from a quantized and transposed version of Input + * matrix B which is also in a CPU-independent format. + * + * The final prepared B is in CPU-dependent format and can be used as an input to matrix multiply + * function (`int8MultiplyAndAddBias`). + * + * This function is useful while using the quantized models that are stored in a CPU-independent + * format on the disk. + * + * @param[in] input_B_quant_transposed An array representing the quantized and transposed + * version of Input matrix B. It is in column-major format. + * Size of the array = `width` * `cols_B`. + * Shape of the matrix: (`cols_B`, `width`) + * @param[in] width No. of rows of Input matrix B. Should be multiple of 64 + * @param[in] cols_B No. of columns of Input matrix B. Should be multiple of 8 + * @param[out] output An array representing the prepared B matrix. + * Size of the array = `width` * `cols_B`. 
+ */ +extern "C" void __attribute__((import_module("wasm_gemm"), + import_name("int8_prepare_b_from_quantized_transposed"))) +int8PrepareBFromQuantizedTransposed(const int8_t* input_B_quant_transposed, + Index width, + Index cols_B, + int8_t* output); + +/** + * Prepare A for the Matrix Multiply function from Input matrix A. + * + * It performs quantization on floating values of input. + * The final prepared A might be architecture dependent. e.g. On some architectures like x86, it + * might be unsigned (achieved by adding 127 to quantized values) while on others like Arm, it might + * be signed. + * The final prepared A can be used as an input to matrix multiply function + * (`int8MultiplyAndAddBias`). + * + * Please note that this interface might have architecture specific implementation. + * + * @param[in] input_A An array representing the Input matrix A in row-major format. + * Size of the array = `rows_A` * `width`. + * Shape of the matrix: (`rows_A`, `width`) + * @param[in] scale The scaling factor (for quantization) + * @param[in] zero_point The zero point (for quantization) + * @param[in] rows_A No. of rows of Input matrix A. No restriction on its size. + * @param[in] width No. of columns of Input matrix A. It should be a multiple of 64. + * @param[out] output An array representing the prepared A matrix. + * Size of the array = `rows_A` * `width`. + */ +extern "C" void __attribute__((import_module("wasm_gemm"), import_name("int8_prepare_a"))) +int8PrepareA(const float* input_A, + float scale, + float zero_point, + Index rows_A, + Index width, + int8_t* output); + +/** + * Prepares bias for the Matrix Multiply function. + * + * It uses the prepared B (which must be obtained by using any of the int8PrepareB* functions) and + * a bias input to prepare the final bias. + * + * The final bias can be used as an input to matrix multiply function (`int8MultiplyAndAddBias`). + * + * @param[in] input_B_prepared An array representing the prepared B matrix. + * Size of the array = `width` * `cols_B`. + * @param[in] scale_A The scaling factor (for quantization) of A + * @param[in] zero_point_A The zero point (for quantization) of A + * @param[in] scale_B The scaling factor (for quantization) of B + * @param[in] zero_point_B The zero point (for quantization) of B + * factor that is prepared from `scale_A` and `scale_B`. + * @param[in] width No. of rows of Input matrix B (unquantized & non-transposed). + * It should be a multiple of 64. + * @param[in] cols_B No. of columns of Input matrix B (unquantized & non-transposed) + * It should be a multiple of 8. + * @param[in] input_bias An array representing the input bias. Size of array = `cols_B` + * @param[out] output An array representing the final prepared bias. + * Size of the array = `cols_B` + */ +extern "C" void __attribute__((import_module("wasm_gemm"), import_name("int8_prepare_bias"))) +int8PrepareBias(const int8_t* input_B_prepared, + float scale_A, + float zero_point_A, + float scale_B, + float zero_point_B, + Index width, + Index cols_B, + const float* input_bias, + float* output); + +/** + * Perform multiplication of 2 matrices followed by adding a bias. + * + * i.e Output = A_prepared * B_prepared + Bias_prepared + * + * The inputs A_prepared, B_prepared and Bias_prepared of this function must be + * obtained by using `int8PrepareA`, one of the `int8PrepareB*` and `int8PrepareBias` + * functions respectively. + * + * Please note that this interface might have architecture specific implementation. 
+ * + * @param[in] input_A_prepared An array representing the prepared A matrix. + * This must be obtained by using `int8PrepareA` function. + * Size of the array = `rows_A` * `width`. + * @param[in] scale_A The scaling factor (for quantization) of A + * @param[in] zero_point_A The zero point (for quantization) of A + * @param[in] input_B_prepared An array representing the prepared B matrix. + * This must be obtained by using one of `int8PrepareB*` + * functions. Size of the array = `width` * `cols_B`. + * @param[in] scale_B The scaling factor (for quantization) of B + * @param[in] zero_point_B The zero point (for quantization) of B + * @param[in] input_bias_prepared An array representing the prepared bias. + * This must be obtained by using `int8PrepareBias` function. + * Size of the array = `cols_B` + * @param[in] unquant_multiplier A value that will be multiplied to the final unquantization + * factor that is prepared from `scale_A` and `scale_B`. + * @param[in] rows_A No. of rows of Input matrix A. No restriction on its size. + * @param[in] width No. of columns of Input matrix A (same as no. of columns of + * Input matrix B). It should be a multiple of 64. + * @param[in] cols_B No. of columns of Input matrix B. Should be a multiple of 8. + * @param[out] output An array representing the result matrix in row-major format. + * Size of the array = `rows_A` * `cols_B`. + */ +extern "C" void + __attribute__((import_module("wasm_gemm"), import_name("int8_multiply_and_add_bias"))) + int8MultiplyAndAddBias(const int8_t* input_A_prepared, + float scale_A, + float zero_point_A, + const int8_t* input_B_prepared, + float scale_B, + float zero_point_B, + const float* input_bias_prepared, + float unquant_multiplier, + Index rows_A, + Index width, + Index cols_B, + float* output); + +/** + * Select a subset of columns of prepared B. + * + * Indices of the columns to be selected are specified by an array. + * + * @param[in] input_B_prepared An array representing the prepared B matrix. + * This must be obtained by using one of the `int8PrepareB*` + * functions Size of the array = `width` * `cols_B`. + * @param[in] width No. of rows of Input matrix B. It should be a multiple of 64. + * @param[in] cols_B No. of columns of Input matrix B. It should be a multiple of 8. + * @param[in] cols An array of column indices to be selected from prepared B. + * All indices of the array should be valid. i.e. + * 0 <= cols[N] < cols_B where N = 0, 1, 2 .... (`num_cols`-1) + * @param[in] num_cols Size of the `cols` array. It should be a multiple of 8. + * @param[out] output An array representing the selected columns of prepared B. + * Size of the array = `width` * `num_cols`. 
+ */ +extern "C" void __attribute__((import_module("wasm_gemm"), import_name("int8_select_columns_of_b"))) +int8SelectColumnsOfB(const int8_t* input_B_prepared, + Index width, + Index cols_B, + const Index* cols, + const Index num_cols, + int8_t* output); + + +#endif + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 0d0b22ff61e01..0571942a2b30a 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1100,7 +1100,10 @@ static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_ver NodePlacementMap& node_placements, NodePlacementSet& node_placement_provider_set) { for (const auto& node : graph.Nodes()) { + const auto& node_provider = node.GetExecutionProviderType(); + printf("%s(%s), provider: %s\n", node.OpType().c_str(), node.Name().c_str(), node_provider.c_str()); + if (node_provider.empty()) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Could not find an implementation for ", diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 09a4a77780916..02bfed2b9660f 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1980,6 +1980,56 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy- ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); })); + +constexpr const char* FirefoxMatMulInteger_doc = R"DOC( +Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html +)DOC"; + + + +ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger8, 1, + OpSchema() + .SetDoc(FirefoxMatMulInteger_doc) + .Input(0, "A", "N-dimensional matrix A", "T1") + .Input(1, "B", "N-dimensional matrix B", "T2") + .Input(2, "a_zero_point", + "Zero point tensor for input 'A'. It's optional and default value is 0. It could be a scalar or a 1-D " + "tensor, " + "which means a per-tensor or per-column quantization. If it's a 1-D tensor, its number " + "of elements should be equal to the number of columns of input 'A'.", + "T1", OpSchema::Optional) + .Input(3, "b_zero_point", + "Zero point tensor for input 'B'. It's optional and default value is 0. It could be a scalar or a 1-D " + "tensor, " + "which means a per-tensor or per-column quantization. If it's a 1-D tensor, its number " + "of elements should be equal to the number of columns of input 'B'.", + "T2", OpSchema::Optional) + .Output(0, "Y", "Matrix multiply results from A * B", "T3") + .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data types as 8-bit integer tensor") + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data types as 8-bit integer tensor") + .TypeConstraint("T3", + {"tensor(int32)", "tensor(uint32)"}, + "Constrain output Y data types as 32-bit integer tensor." 
+ "T3 must be tensor(uint16) when both T1 and T2 are tensor(uint8)," + "or must be tensor(int16) when either T1 or T2 is tensor(int8).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto a_type = ctx.getInputType(0); + auto b_type = ctx.getInputType(1); + auto y_type = ctx.getOutputType(0); + if (nullptr == a_type || nullptr == b_type || nullptr == y_type || + a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType || + b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) { + fail_type_inference( + "inputs are expected to have tensor type and output type should not be null."); + } + + // Right now we only support int16 + y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32); + + ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1); + })); + + /** * @brief Shape inference for MatMul with right hand side matrix quantized into int4 * @param ctx @@ -3780,6 +3830,7 @@ Having this op allows runtime to do operator re-ordering to reduce compute FLOPs #endif + #ifndef _OPSCHEMA_LIB_ // Register the NCHWc schemas if supported by the platform. if (MlasNchwcGetBlockSize() > 1) { diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index a9a89f756b071..5d98743f4078d 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -79,6 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4); #endif @@ -189,6 +190,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); #ifndef ORT_MINIMAL_BUILD fn(GetOpSchema()); #endif diff --git a/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc b/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc new file mode 100644 index 0000000000000..3b8e079ec705a --- /dev/null +++ b/onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" + +namespace onnxruntime { +namespace test { + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_1) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {1, 1}, {15}); + test.AddInput("T2", {1, 1}, {8}); + test.AddOutput("T3", {1, 1}, {120}); // Result is 15 * 8 + test.Run(); +} + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_2) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {1, 2}, {-7, 10}); + test.AddInput("T2", {2, 1}, {-8, -11}); + test.AddOutput("T3", {1, 1}, {8}); // Result is (-7 * -8) + (10 * -11) + test.Run(); +} + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_Empty_input) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {0, 2}, {}); + test.AddInput("T2", {2, 1}, {-8, -11}); + test.AddOutput("T3", {0, 1}, {}); // Empty input produces an empty output + test.Run(); +} + +TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_3) { + OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain); + test.AddInput("T1", {3, 2}, {-7, 10, 10, -113, 22, -36}); + test.AddInput("T2", {2, 4}, {-8, -11, 13, 14, -9, 12, 3, -6}); + test.AddOutput("T3", {3, 4}, + {-158, 97, -61, -2, // First row results + 989, -1426, 1693, 1682, // Second row results + 282, -518, 280, -372}); // Third row results + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 7f4616c964e33..8d8227acb48ab 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -1329,8 +1329,11 @@ TEST(InferenceSessionTests, TestOptionalInputs) { "Invalid input name"); // missing required + printf("here"); + ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(RunOptionalInputTest(false, true, false, version, sess_env), (version == 3 ? "Invalid input name" : "Missing Input:")); + printf("here 2"); } } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 45e2475548df5..d23200a389100 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -3,13 +3,13 @@ 'use strict'; + // // This file contains the pre-run code for the ORT WebAssembly module. The code in this file will be injected into the // final module using Emscripten's `--pre-js` option. // // This file will only be used in build with flag `--use_jsep`. - /** * initialize JSEP for asyncify support. 
*/ @@ -109,7 +109,7 @@ let jsepInitAsync = () => { if (Module.jsepSessionState) { throw new Error('Session already started'); } - const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; + const state = Module.jsepSessionState = { sessionHandle: args[0], errors: [] }; // Run the acyncified function: OrtRun() or OrtRunWithBinding() const ret = await runAsyncFunc(...args); @@ -141,21 +141,21 @@ let jsepInitAsync = () => { // replace the original functions with asyncified versions Module['_OrtCreateSession'] = jsepWrapAsync( - Module['_OrtCreateSession'], - () => Module['_OrtCreateSession'], - v => Module['_OrtCreateSession'] = v); + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); Module['_OrtRun'] = runAsync(jsepWrapAsync( - Module['_OrtRun'], - () => Module['_OrtRun'], - v => Module['_OrtRun'] = v)); + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( - Module['_OrtRunWithBinding'], - () => Module['_OrtRunWithBinding'], - v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); Module['_OrtBindInput'] = jsepWrapAsync( - Module['_OrtBindInput'], - () => Module['_OrtBindInput'], - v => Module['_OrtBindInput'] = v); + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); // remove this function to make sure it is called only once. jsepInitAsync = undefined; @@ -170,16 +170,16 @@ Module['jsepInit'] = (name, params) => { if (name === 'webgpu') { [Module.jsepBackend, - Module.jsepAlloc, - Module.jsepFree, - Module.jsepCopy, - Module.jsepCopyAsync, - Module.jsepCreateKernel, - Module.jsepReleaseKernel, - Module.jsepRunKernel, - Module.jsepCaptureBegin, - Module.jsepCaptureEnd, - Module.jsepReplay] = params; + Module.jsepAlloc, + Module.jsepFree, + Module.jsepCopy, + Module.jsepCopyAsync, + Module.jsepCreateKernel, + Module.jsepReleaseKernel, + Module.jsepRunKernel, + Module.jsepCaptureBegin, + Module.jsepCaptureEnd, + Module.jsepReplay] = params; // expose webgpu backend functions const backend = Module.jsepBackend; @@ -211,11 +211,11 @@ Module['jsepInit'] = (name, params) => { // change the name. [Module.jsepBackend, - Module.jsepReserveTensorId, - Module.jsepReleaseTensorId, - Module['jsepEnsureTensor'], - Module.jsepUploadTensor, - Module['jsepDownloadTensor'], + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module['jsepEnsureTensor'], + Module.jsepUploadTensor, + Module['jsepDownloadTensor'], ] = params; // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. @@ -243,7 +243,7 @@ Module['jsepInit'] = (name, params) => { }; Module['jsepRegisterMLConstant'] = (externalFilePath, dataOffset, dataLength, builder, desc) => { return backend['registerMLConstant']( - externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); }; } }; diff --git a/onnxruntime/wasm/pre.js b/onnxruntime/wasm/pre.js index 9b5f3ce545b78..2d786a844f259 100644 --- a/onnxruntime/wasm/pre.js +++ b/onnxruntime/wasm/pre.js @@ -49,4 +49,93 @@ Module['unmountExternalData'] = () => { * @suppress {checkVars} */ var SharedArrayBuffer = globalThis.SharedArrayBuffer ?? 
-    new WebAssembly.Memory({'initial': 0, 'maximum': 0, 'shared': true}).buffer.constructor;
+  new WebAssembly.Memory({ 'initial': 0, 'maximum': 0, 'shared': true }).buffer.constructor;
+
+
+
+function asmjsMangle(x) {
+  var unmangledSymbols = ["stackAlloc", "stackSave", "stackRestore"];
+  return x.indexOf("dynCall_") == 0 || unmangledSymbols.includes(x) ? x : "_" + x;
+}
+
+function exportAsmFunctions(asm) {
+  var global_object = this;
+  for (var __exportedFunc in asm) {
+    var jsname = asmjsMangle(__exportedFunc);
+    Module[jsname] = asm[__exportedFunc];
+    if (global_object) {
+      global_object[__exportedFunc] = asm[__exportedFunc];
+    }
+  }
+}
+
+
+function fallbackGemm(gemmToFallbackFunctionsMap) {
+  // The fallback gemm implementation
+  const FALLBACK_GEMM = "asm";
+
+  let fallbackGemmModuleExports = {};
+  for (let key in gemmToFallbackFunctionsMap) {
+    fallbackGemmModuleExports[key] = (...a) =>
+      Module[FALLBACK_GEMM][gemmToFallbackFunctionsMap[key]](...a);
+  }
+  return fallbackGemmModuleExports;
+}
+
+/**
+ * Custom call to instantiate the WebAssembly module, so we can use custom imports.
+ */
+Module["instantiateWasm"] = async (info, receiveInstance) => {
+  const wasmBinaryFile = findWasmBinary();
+  const bytes = await getBinaryPromise(wasmBinaryFile);
+  const module = await WebAssembly.compile(bytes);
+  let imports = getWasmImports();
+
+  // XXX mozIntGemm can't be used from web pages - we use a fallback if we are not privileged
+  const OPTIMIZED_GEMM = "mozIntGemm";
+
+  const optimizedGemmModule = WebAssembly[OPTIMIZED_GEMM];
+  if (!optimizedGemmModule) {
+    const GEMM_TO_FALLBACK_FUNCTIONS_MAP = {
+      int8_prepare_a: "int8PrepareAFallback",
+      int8_prepare_b: "int8PrepareBFallback",
+      int8_prepare_b_from_transposed: "int8PrepareBFromTransposedFallback",
+      int8_prepare_b_from_quantized_transposed:
+        "int8PrepareBFromQuantizedTransposedFallback",
+      int8_prepare_bias: "int8PrepareBiasFallback",
+      int8_multiply_and_add_bias: "int8MultiplyAndAddBiasFallback",
+      int8_select_columns_of_b: "int8SelectColumnsOfBFallback",
+    };
+    imports.wasm_gemm = fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
+  } else {
+    var INITIAL_MEMORY = 16777216;
+    var gemmWasmMemory = new WebAssembly.Memory({
+      "initial": INITIAL_MEMORY / 65536,
+      "maximum": 4294967296 / 65536,
+      "shared": false
+    });
+    const optimizedGemmModuleExports = new WebAssembly.Instance(optimizedGemmModule(), {
+      "": {
+        memory: gemmWasmMemory
+      }
+    }).exports;
+    imports.wasm_gemm = optimizedGemmModuleExports;
+  }
+  function mozReceiveInstance(instance) {
+    // XXX do we need any Mozilla-specific handling here?
+    //var exports = instance.exports;
+    //Module.asm = exports;
+    // wasmTable = Module.asm.__indirect_function_table; ???
+    //exportAsmFunctions(exports);
+    return receiveInstance(instance);
+  }
+  try {
+    var instance = new WebAssembly.Instance(module, imports);
+    mozReceiveInstance(instance);
+  } catch (error) {
+    console.error("Error creating WebAssembly instance:", error);
+    throw error;
+  }
+};
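Note on the wasm_gemm interface declared in firefox_matmul_integer.h: the header documents a prepare-then-multiply pipeline (int8PrepareA / int8PrepareB / int8PrepareBias feeding int8MultiplyAndAddBias), while FirefoxMatMulInteger8::Compute currently passes raw tensor data straight to int8MultiplyAndAddBias. The sketch below is only an illustration of the call order the header's documentation describes, for an Emscripten build where the wasm_gemm imports are provided; the function name, include path, shapes, and quantization scales are assumptions for illustration and are not part of the patch.

#ifdef __EMSCRIPTEN__
#include <vector>
#include "contrib_ops/cpu/quantization/firefox_matmul_integer.h"

// Illustrative sketch: chain the prepare/multiply imports the way the header documents.
// rows_A is unrestricted, width must be a multiple of 64, cols_B a multiple of 8.
void GemmF32ViaWasmGemm(const float* A, const float* B, const float* bias,
                        Index rows_A, Index width, Index cols_B, float* out) {
  const float scale_A = 127.0f / 4.0f;  // assumed quantization scales; a real caller
  const float scale_B = 127.0f / 4.0f;  // would derive these from the tensors' value ranges
  const float zero_A = 0.0f, zero_B = 0.0f;

  std::vector<int8_t> A_prep(size_t(rows_A) * width);   // prepared A, row-major
  std::vector<int8_t> B_prep(size_t(width) * cols_B);   // prepared B, CPU-dependent layout
  std::vector<float> bias_prep(cols_B);                 // prepared bias

  int8PrepareA(A, scale_A, zero_A, rows_A, width, A_prep.data());
  int8PrepareB(B, scale_B, zero_B, width, cols_B, B_prep.data());
  int8PrepareBias(B_prep.data(), scale_A, zero_A, scale_B, zero_B,
                  width, cols_B, bias, bias_prep.data());

  // Undo the combined quantization scales when writing the float result.
  const float unquant_mult = 1.0f / (scale_A * scale_B);
  int8MultiplyAndAddBias(A_prep.data(), scale_A, zero_A,
                         B_prep.data(), scale_B, zero_B,
                         bias_prep.data(), unquant_mult,
                         rows_A, width, cols_B, out);
}
#endif  // __EMSCRIPTEN__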