Skip to content

Commit

Permalink
added a firefox matmul backend
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Dec 6, 2024
1 parent 49a80df commit c0b1091
Show file tree
Hide file tree
Showing 11 changed files with 361 additions and 37 deletions.
15 changes: 9 additions & 6 deletions build.sh
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set -ex

# Get directory this script is in
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OS=$(uname -s)

if [ "$OS" = "Darwin" ]; then
DIR_OS="MacOS"
DIR_OS="MacOS"
else
DIR_OS="Linux"
DIR_OS="Linux"
fi

if [[ "$*" == *"--ios"* ]]; then
DIR_OS="iOS"
DIR_OS="iOS"
elif [[ "$*" == *"--android"* ]]; then
DIR_OS="Android"
DIR_OS="Android"
fi

python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
# Interpreter used to launch the build driver; defaults to python3 and can be
# overridden through the PYTHON environment variable.
PYTHON="${PYTHON:-python3}"

# $PYTHON is intentionally left unquoted so it may carry extra arguments
# (e.g. PYTHON="python3 -u"). The paths are quoted so that a checkout located
# under a directory whose name contains spaces still works.
$PYTHON "$DIR/tools/ci_build/build.py" --build_dir "$DIR/build/$DIR_OS" "$@"
1 change: 1 addition & 0 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ jsepDownload:_pp_")
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
"SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
"SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']"
"SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
endif()
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8);
// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
Expand Down Expand Up @@ -285,6 +286,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, FirefoxMatMulInteger8)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, FirefoxMatMulInteger8)>,
// add more kernels here
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,
Expand Down Expand Up @@ -364,7 +367,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,

#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
#endif
Expand Down
160 changes: 160 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "firefox_matmul_integer.h"

#include <iostream>
#include <vector>

#ifdef __EMSCRIPTEN__
#include <emscripten/emscripten.h>
#endif

#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"
#include "core/util/math_cpuonly.h"
#include "core/util/qmath.h"

namespace onnxruntime {
namespace contrib {

// Registers the uint8_t-typed FirefoxMatMulInteger8 kernel (Microsoft domain,
// version 1) for the CPU execution provider.
// Type constraints: T1 (input A) is uint8, T2 (input B) may be uint8 or int8,
// and T3 (output Y) is int32.
ONNX_OPERATOR_TYPED_KERNEL_EX(
FirefoxMatMulInteger8,
kMSDomain,
1,
uint8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
.TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(), DataTypeImpl::GetTensorType<int8_t>()})
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
FirefoxMatMulInteger8);

// Registers the int8_t-typed FirefoxMatMulInteger8 kernel (Microsoft domain,
// version 1) for the CPU execution provider.
// Type constraints: T1 (input A) and T2 (input B) are both int8, and
// T3 (output Y) is int32.
ONNX_OPERATOR_TYPED_KERNEL_EX(
FirefoxMatMulInteger8,
kMSDomain,
1,
int8_t,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
FirefoxMatMulInteger8);


#ifdef __EMSCRIPTEN__
// Declaration of a function that is expected to be supplied by the WebAssembly
// host as an import named "int8MultiplyAndAddBias" (clang `import_name`
// attribute). It is declared without parameters or a return value, so no
// tensor data is exchanged through it yet.
//
// The declaration needs nothing from <emscripten/emscripten.h>; that header is
// included at file scope rather than here so its declarations are not injected
// into namespace onnxruntime::contrib.
extern "C" void __attribute__((import_name("int8MultiplyAndAddBias")))
int8MultiplyAndAddBias();

#endif

/// Computes Y = A * B for 8-bit integer inputs and writes an int32 result.
///
/// Reads A (and B unless a pre-packed B buffer is held in `packed_b_`), applies
/// the optional zero points, builds one quantized GEMM parameter set per batch
/// entry reported by MatMulComputeHelper and dispatches them to MlasGemmBatch.
///
/// @param ctx Kernel context supplying the inputs, the output and the thread pool.
/// @return Status::OK() on success, or the error returned by
///         MatMulComputeHelper::Compute when the shapes are incompatible.
///         Invalid or missing inputs are rejected through ORT_ENFORCE.
///
/// NOTE(review): the std::cout tracing below runs (and flushes) on every kernel
/// invocation. It looks like temporary debugging output and should probably be
/// removed or routed through the logging facilities before this ships. The
/// messages are kept unchanged here.
Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
  const auto* a = ctx->Input<Tensor>(IN_A);
  // A is dereferenced unconditionally below; fail with a clear message rather
  // than crashing on a null tensor.
  ORT_ENFORCE(a != nullptr, "MatmulInteger : input A is missing");

  // When a packed copy of B is available, the B input is not fetched.
  const auto* b = packed_b_ ? nullptr : ctx->Input<Tensor>(IN_B);
  // Without either source for B, the GEMM would be handed a null pointer.
  ORT_ENFORCE(packed_b_ || b != nullptr, "MatmulInteger : input B is missing");

  std::cout << "Input Tensor A shape: " << a->Shape().ToString() << std::endl;
  if (b) {
    std::cout << "Input Tensor B shape: " << b->Shape().ToString() << std::endl;
  } else {
    std::cout << "Using packed B." << std::endl;
  }

  // Validate zero points
  uint8_t a_offset = 0;
  const auto* a_zero_point = ctx->Input<Tensor>(IN_A_ZERO_POINT);
  if (a_zero_point != nullptr) {
    std::cout << "A Zero Point shape: " << a_zero_point->Shape().ToString() << std::endl;
    ORT_ENFORCE(IsScalarOr1ElementVector(a_zero_point),
                "MatmulInteger : input1 zero point must be a scalar or 1D tensor of size 1");
    // The raw byte is read as uint8_t regardless of whether A is signed.
    a_offset = *(static_cast<const uint8_t*>(a_zero_point->DataRaw()));
    std::cout << "A Zero Point value: " << static_cast<int>(a_offset) << std::endl;
  }

  bool is_b_zp_per_column = false;
  uint8_t b_default_offset = 0;
  // Points at the default zero until a B zero-point tensor is supplied.
  const uint8_t* b_offset_ptr = &b_default_offset;
  const auto* b_zero_point = ctx->Input<Tensor>(IN_B_ZERO_POINT);
  if (b_zero_point != nullptr) {
    std::cout << "B Zero Point shape: " << b_zero_point->Shape().ToString() << std::endl;
    ORT_ENFORCE(IsBQuantParamSupported(b_zero_point->Shape(), b ? b->Shape() : b_shape_),
                "MatmulInteger : B zero point is not valid");
    is_b_zp_per_column = !IsScalarOr1ElementVector(b_zero_point);
    b_offset_ptr = static_cast<const uint8_t*>(b_zero_point->DataRaw());
    std::cout << "B Zero Point is per-column: " << is_b_zp_per_column << std::endl;
  }

  MatMulComputeHelper helper;
  const uint8_t* b_data = nullptr;
  bool b_is_signed = false;
  if (nullptr != b) {
    std::cout << "Computing helper with A and B shapes." << std::endl;
    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b->Shape(), nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
    b_data = static_cast<const uint8_t*>(b->DataRaw());
    b_is_signed = b->IsDataType<int8_t>();
  } else {
    std::cout << "Computing helper with A shape and packed B." << std::endl;
    ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_, nullptr, b_zero_point ? &b_zero_point->Shape() : nullptr));
    b_data = static_cast<const uint8_t*>(packed_b_.get());
    b_is_signed = b_is_signed_;
  }

  Tensor* y = ctx->Output(OUT_Y, helper.OutputShape());
  std::cout << "Output Tensor Y shape: " << y->Shape().ToString() << std::endl;

  // Nothing to compute for an empty output.
  if (y->Shape().Size() == 0) {
    std::cout << "Output Tensor is empty. Exiting early." << std::endl;
    return Status::OK();
  }

  const uint8_t* a_data = static_cast<const uint8_t*>(a->DataRaw());
  auto* y_data = y->MutableData<int32_t>();

  MLAS_GEMM_QUANT_SHAPE_PARAMS gemm_shape;
  gemm_shape.M = static_cast<size_t>(helper.M());
  gemm_shape.N = static_cast<size_t>(helper.N());
  gemm_shape.K = static_cast<size_t>(helper.K());
  gemm_shape.AIsSigned = a->IsDataType<int8_t>();
  gemm_shape.BIsSigned = b_is_signed;

  std::cout << "GEMM Shape - M: " << gemm_shape.M << ", N: " << gemm_shape.N
            << ", K: " << gemm_shape.K << ", AIsSigned: " << gemm_shape.AIsSigned
            << ", BIsSigned: " << gemm_shape.BIsSigned << std::endl;

  const size_t batch_size = helper.OutputOffsets().size();
  std::cout << "Batch size: " << batch_size << std::endl;

  // One parameter set per batch entry; each points into A, B and Y at the
  // offsets computed by the helper.
  std::vector<MLAS_GEMM_QUANT_DATA_PARAMS> gemm_data_vec(batch_size);

  for (size_t batch = 0; batch < batch_size; batch++) {
    auto& gemm_params = gemm_data_vec[batch];
    gemm_params.lda = gemm_shape.K;
    gemm_params.ZeroPointA = a_offset;
    gemm_params.ldb = gemm_shape.N;
    gemm_params.ZeroPointB = b_offset_ptr + helper.RightZeroPointOffsets()[batch];
    gemm_params.PerColumnZeroPoints = is_b_zp_per_column;
    gemm_params.ldc = gemm_shape.N;
    gemm_params.BIsPacked = static_cast<bool>(packed_b_);
    gemm_params.A = a_data + helper.LeftOffsets()[batch];
    gemm_params.B = b_data + helper.RightOffsets()[batch];
    gemm_params.C = y_data + helper.OutputOffsets()[batch];

    std::cout << "Batch " << batch << " - A offset: " << helper.LeftOffsets()[batch]
              << ", B offset: " << helper.RightOffsets()[batch]
              << ", C offset: " << helper.OutputOffsets()[batch] << std::endl;
  }

  std::cout << "Calling MlasGemmBatch." << std::endl;

#ifdef __EMSCRIPTEN__
  // Invokes the host-imported function. It takes no arguments, and the result
  // is still produced by MlasGemmBatch below.
  onnxruntime::contrib::int8MultiplyAndAddBias();
#endif

  MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());

  std::cout << "Exiting FirefoxMatMulInteger8::Compute" << std::endl;
  return Status::OK();
}


} // namespace contrib
} // namespace onnxruntime
33 changes: 33 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/cpu/quantization/matmul_integer_base.h"
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib {

/// CPU kernel class for the FirefoxMatMulInteger8 contrib operator.
///
/// Derives from MatMulIntegerBase and overrides Compute() and GetBIdx().
class FirefoxMatMulInteger8 final : public MatMulIntegerBase {
 public:
  /// Constructs the kernel from its OpKernelInfo, forwarding it to the base.
  /// `explicit` prevents an accidental implicit conversion from OpKernelInfo.
  explicit FirefoxMatMulInteger8(const OpKernelInfo& info) : MatMulIntegerBase(info) {}

  /// Runs the quantized matrix multiplication for one invocation of the node.
  Status Compute(OpKernelContext* context) const override;

  /// Positional indices of the operator inputs.
  enum InputTensors : int {
    IN_A = 0,
    IN_B = 1,
    IN_A_ZERO_POINT = 2,
    IN_B_ZERO_POINT = 3
  };

  /// Positional indices of the operator outputs.
  enum OutputTensors : int { OUT_Y = 0 };

 protected:
  /// Reports which input index holds B.
  /// NOTE(review): presumably consumed by MatMulIntegerBase to pre-pack B;
  /// confirm against the base class.
  int GetBIdx() const override { return IN_B; }
};

} // namespace contrib
} // namespace onnxruntime
3 changes: 3 additions & 0 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,10 @@ static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_ver
NodePlacementMap& node_placements,
NodePlacementSet& node_placement_provider_set) {
for (const auto& node : graph.Nodes()) {

const auto& node_provider = node.GetExecutionProviderType();
printf("%s(%s), provider: %s\n", node.OpType().c_str(), node.Name().c_str(), node_provider.c_str());

if (node_provider.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Could not find an implementation for ",
Expand Down
51 changes: 51 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1980,6 +1980,56 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-
ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
}));


// Operator documentation string attached to the FirefoxMatMulInteger8 schema.
constexpr const char* FirefoxMatMulInteger_doc = R"DOC(
Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html
)DOC";

// Schema for FirefoxMatMulInteger8 (Microsoft domain, version 1): an 8-bit
// integer matrix multiplication with optional zero points for A and B and a
// 32-bit integer output.
ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger8, 1,
                            OpSchema()
                                .SetDoc(FirefoxMatMulInteger_doc)
                                .Input(0, "A", "N-dimensional matrix A", "T1")
                                .Input(1, "B", "N-dimensional matrix B", "T2")
                                .Input(2, "a_zero_point",
                                       "Zero point tensor for input 'A'. It's optional and default value is 0. It could be a scalar or a 1-D "
                                       "tensor, "
                                       "which means a per-tensor or per-column quantization. If it's a 1-D tensor, its number "
                                       "of elements should be equal to the number of columns of input 'A'.",
                                       "T1", OpSchema::Optional)
                                .Input(3, "b_zero_point",
                                       "Zero point tensor for input 'B'. It's optional and default value is 0. It could be a scalar or a 1-D "
                                       "tensor, "
                                       "which means a per-tensor or per-column quantization. If it's a 1-D tensor, its number "
                                       "of elements should be equal to the number of columns of input 'B'.",
                                       "T2", OpSchema::Optional)
                                .Output(0, "Y", "Matrix multiply results from A * B", "T3")
                                .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data types as 8-bit integer tensor")
                                .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data types as 8-bit integer tensor")
                                // The description previously referred to 16-bit output types, which
                                // contradicted both this constraint list and the inference function below.
                                .TypeConstraint("T3",
                                                {"tensor(int32)", "tensor(uint32)"},
                                                "Constrain output Y data types as 32-bit integer tensor. "
                                                "Type inference currently always sets Y to tensor(int32).")
                                .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
                                  auto a_type = ctx.getInputType(0);
                                  auto b_type = ctx.getInputType(1);
                                  auto y_type = ctx.getOutputType(0);
                                  // Both inputs must be tensors and the output type slot must exist.
                                  if (nullptr == a_type || nullptr == b_type || nullptr == y_type ||
                                      a_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType ||
                                      b_type->value_case() != ONNX_NAMESPACE::TypeProto::kTensorType) {
                                    fail_type_inference(
                                        "inputs are expected to have tensor type and output type should not be null.");
                                  }

                                  // Right now we only support int32
                                  y_type->mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto::INT32);

                                  // Output shape follows matmul rules applied to inputs 0 and 1.
                                  ONNX_NAMESPACE::defs::math::utils::MatMulShapeInference(ctx, 0, 1);
                                }));


/**
* @brief Shape inference for MatMul with right hand side matrix quantized into int4
* @param ctx
Expand Down Expand Up @@ -3780,6 +3830,7 @@ Having this op allows runtime to do operator re-ordering to reduce compute FLOPs

#endif


#ifndef _OPSCHEMA_LIB_
// Register the NCHWc schemas if supported by the platform.
if (MlasNchwcGetBlockSize() > 1) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/graph/contrib_ops/ms_opset.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4);
#endif
Expand Down Expand Up @@ -189,6 +190,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8)>());
#ifndef ORT_MINIMAL_BUILD
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4)>());
#endif
Expand Down
Loading

0 comments on commit c0b1091

Please sign in to comment.