Commit 5560a90
added a firefox matmul backend
tarekziade committed Jan 9, 2025
1 parent 49a80df commit 5560a90
Showing 17 changed files with 797 additions and 53 deletions.
15 changes: 9 additions & 6 deletions build.sh
@@ -1,21 +1,24 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set -ex

# Get directory this script is in
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OS=$(uname -s)

if [ "$OS" = "Darwin" ]; then
DIR_OS="MacOS"
DIR_OS="MacOS"
else
DIR_OS="Linux"
DIR_OS="Linux"
fi

if [[ "$*" == *"--ios"* ]]; then
DIR_OS="iOS"
DIR_OS="iOS"
elif [[ "$*" == *"--android"* ]]; then
DIR_OS="Android"
DIR_OS="Android"
fi

-python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
+PYTHON="${PYTHON:-python3}"
+
+$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
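Note: the last hunk makes the Python interpreter overridable. PYTHON defaults to python3, so an invocation such as PYTHON=python3.12 ./build.sh --android (an illustrative example, not taken from the commit) builds with a specific interpreter.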
1 change: 1 addition & 0 deletions cmake/onnxruntime_webassembly.cmake
@@ -382,6 +382,7 @@ jsepDownload:_pp_")
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
"SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
"SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']"
"SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
endif()
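Relaxing ERROR_ON_UNDEFINED_SYMBOLS lets the wasm module link even though the new wasm_gemm import (GeckoMatmulIntegerToFloat, declared in dynamic_quantize_matmul.cc below) has no definition at link time; the host is expected to supply it when the module is instantiated.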
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -62,7 +62,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);

+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8);
// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
@@ -285,6 +286,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,

+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8)>,
// add more kernels here
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,
@@ -364,7 +367,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,

#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
#endif
184 changes: 183 additions & 1 deletion onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
@@ -9,13 +9,35 @@
#include "core/providers/cpu/quantization/matmul_integer_base.h"
#include "core/util/math_cpuonly.h"
#include "core/util/qmath.h"
#include <cassert>

#include <chrono>
#include <algorithm>

namespace onnxruntime {
namespace contrib {

namespace {


using Index = uint32_t;

extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
GeckoMatmulIntegerToFloat(
const uint8_t* a_data,
float zero_point_A,
const int8_t* input_B,
const uint8_t* zero_point_B,
uint32_t rows_A,
uint32_t width,
uint32_t cols_B,
const float* b_scale_data,
float is_b_scale_per_column,
float* output
);
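// Resolution note: import_module("wasm_gemm") means this symbol is not linked
// into the wasm binary. It becomes a wasm import that the embedder (Gecko, in
// Firefox) must supply under the "wasm_gemm" namespace when instantiating the
// module, which is why the build now links with ERROR_ON_UNDEFINED_SYMBOLS=0.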


void ScaleOutput(const Tensor& scale, Tensor& output) {
  ProcessBroadcastSpanFuncs funcs{
      [](BroadcastHelper& per_iter_bh) {
@@ -51,12 +73,65 @@ class MatMulIntegerToFloatBase : public MatMulIntegerBase {
                       float a_scale,
                       uint8_t a_zp,
                       bool a_is_signed,
                       const Tensor* b_tensor,
                       const Tensor* b_scale,
                       const Tensor* b_zp,
                       const Tensor* bias_tensor) const;
};

void MatMulFull(const uint8_t* inputMatrixA,
                const int8_t* inputMatrixB,
                float* output,
                size_t rowsA,
                size_t width,
                size_t colsB,
                uint8_t zeroPointA,
                const uint8_t* zeroPointB,
                const float* b_scale_data,
                bool is_b_scale_per_column) {
  float matrixScale = is_b_scale_per_column ? 0.0f : b_scale_data[0];
  int32_t matrixZeroPointB = is_b_scale_per_column ? 0 : static_cast<int32_t>(zeroPointB[0]);

  for (size_t rowIndex = 0; rowIndex < rowsA; ++rowIndex) {
    const uint8_t* aRow = inputMatrixA + rowIndex * width;  // Start of row in A
    for (size_t colIndex = 0; colIndex < colsB; ++colIndex) {
      int32_t tempResult = 0;

      for (size_t k = 0; k < width; ++k) {
        // Row-major access for A
        uint8_t aValue = aRow[k];

        // B is stored row-major as K x N; this walks down column colIndex
        int8_t bValue = inputMatrixB[k * colsB + colIndex];

        // Adjust for zero-point offsets
        int32_t adjustedA = static_cast<int32_t>(aValue) - static_cast<int32_t>(zeroPointA);
        int32_t adjustedB = static_cast<int32_t>(bValue);

        if (is_b_scale_per_column) {
          adjustedB -= static_cast<int32_t>(zeroPointB[colIndex]);
        } else {
          adjustedB -= matrixZeroPointB;
        }

        // Accumulate product
        tempResult += adjustedA * adjustedB;
      }

      float scaledResult = static_cast<float>(tempResult);
      if (is_b_scale_per_column) {
        scaledResult *= b_scale_data[colIndex];
      } else {
        scaledResult *= matrixScale;
      }

      // Store the scaled result in the output (row-major)
      output[rowIndex * colsB + colIndex] = scaledResult;
    }
  }
}
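// Reference semantics of MatMulFull (and of the wasm_gemm import above):
//   output[i][j] = scale_j * sum_k (A[i][k] - zpA) * (B[k][j] - zpB_j)
// where scale_j and zpB_j are per-column when is_b_scale_per_column is set,
// and per-tensor otherwise.
//
// Worked example (illustrative values, not taken from the commit):
// rowsA = 1, width = 1, colsB = 1, A = {130}, zeroPointA = 128,
// B = {-3}, zeroPointB = {0}, b_scale_data = {0.05f}, per-column scales:
//   adjustedA  = 130 - 128  = 2
//   adjustedB  = -3 - 0     = -3
//   tempResult = 2 * -3     = -6
//   output[0]  = -6 * 0.05f = -0.3f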

Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
                                               const uint8_t* a_data,
                                               const TensorShape& a_shape,
@@ -150,8 +225,107 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
    params.ldc = gemm_shape.N;
  }

#if 0
  // Debug dump (disabled). Note it reads rowsA/width/colsB, which are declared
  // further down, so it would have to move if it were ever re-enabled.
  std::vector<float> y_data_2(rowsA * colsB, 0.0f);
  if (rowsA > 1) {
    std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n";
    std::cout << "a_zp: " << static_cast<int>(a_zp) << "\n";
    std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n";
    std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << "\n";
    std::cout << "b_scale_data sample: [";
    for (size_t i = 0; i < 25; ++i) {
      if (i > 0) std::cout << ", ";
      std::cout << b_scale_data[i];
    }
    std::cout << "]\n";
    std::cout << "b_zero point sample: [";
    for (size_t i = 0; i < 25; ++i) {
      if (i > 0) std::cout << ", ";
      std::cout << static_cast<int>(b_zp_ptr[i]);
    }
    std::cout << "]\n";

    if (bias_data != nullptr) {
      size_t bias_size = static_cast<size_t>(bias_tensor->Shape().Size());  // Total size of bias_data
      size_t display_limit = std::min(bias_size, static_cast<size_t>(100));
      std::cout << "First " << display_limit << " elements of bias_data: [";
      for (size_t i = 0; i < display_limit; ++i) {
        if (i > 0) std::cout << ", ";
        std::cout << bias_data[i];
      }
      std::cout << "]" << std::endl;
    }
    std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl;
    std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
  }
#endif
  // auto start = std::chrono::steady_clock::now();
  // std::cout << "Calling f32Multiply\n";
  // TODO: split the work into row blocks and call ctx->ParallelFor on them.

  // rowsA = M, width = K, colsB = N
  size_t rowsA = static_cast<size_t>(helper.M());
  size_t width = static_cast<size_t>(helper.K());
  size_t colsB = static_cast<size_t>(helper.N());

  const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());

  GeckoMatmulIntegerToFloat(a_data,
                            a_zp,
                            b_data,
                            b_zp_ptr,
                            rowsA,
                            width,
                            colsB,
                            b_scale_data,
                            is_b_scale_per_column,
                            y_data);

  // MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
  /*
  auto end = std::chrono::steady_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
  std::cout << "Done calling f32Multiply. Duration: " << duration << " nano\n";
  std::cout << "Calling MlasGemmBatch\n";
  auto start2 = std::chrono::steady_clock::now();
  MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
  auto end2 = std::chrono::steady_clock::now();
  auto duration2 = std::chrono::duration_cast<std::chrono::nanoseconds>(end2 - start2).count();
  std::cout << "Done calling MlasGemmBatch. Duration: " << duration2 << " nano\n";
  */
  /*
  // Compare y_data and y_data_2
  size_t total_elements = rowsA * colsB;
  size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
  bool mismatch_found = false;
  for (size_t i = 0; i < total_elements; ++i) {
    if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) {  // Tolerance for floating-point comparison
      std::cerr << "Mismatch at index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << std::endl;
      mismatch_found = true;
      break;
    }
  }
  if (mismatch_found) {
    std::cerr << "Displaying the first 100 elements of y_data and y_data_2:" << std::endl;
    std::cerr << "[";
    for (size_t i = 0; i < display_limit; ++i) {
      std::cerr << "(Index " << i << ": y_data=" << y_data[i] << ", y_data_2=" << y_data_2[i] << ")";
      if (i != display_limit - 1) {
        std::cerr << ", ";
      }
    }
    std::cerr << "]" << std::endl;
    std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl;
    assert(false && "Validation failed: y_data and y_data_2 are not equal.");
  }
  */
  return Status::OK();
}

@@ -221,6 +395,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
  ParQuantizeLinearStd(a_data, a_data_quant, narrow<size_t>(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool());

  bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_);

  // std::cout << "dynamic quantize matmul calling ComputeCommon" << std::endl;

  ORT_RETURN_IF_ERROR(ComputeCommon(
      ctx,
      a_data_quant,
@@ -234,6 +412,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
      ctx->Input<Tensor>(IN_BIAS)));

  if (!is_b_scale_supported) {
    // std::cout << "dynamic quantize matmul: b scale is not supported\n";
    ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
  }

@@ -275,6 +454,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
    a_zero_point = *(static_cast<const uint8_t*>(a_zero_point_tensor->DataRaw()));
  }

  // std::cout << "matmul integer float calling ComputeCommon" << std::endl;
  const Tensor* b_zp_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
  ORT_RETURN_IF_ERROR(ComputeCommon(
      ctx,
@@ -289,9 +469,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
      ctx->Input<Tensor>(IN_BIAS)));

  if (!is_a_scale_scalar) {
    // std::cout << "matmul integer to float: a scale is not scalar\n";
    ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
  }
  if (!is_b_scale_supported) {
    // std::cout << "matmul integer to float: b scale is not supported\n";
    ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
  }
