Skip to content

Commit

Permalink
added a firefox matmul backend
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Dec 2, 2024
1 parent 755e2d3 commit 51b6ae1
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 18 deletions.
15 changes: 9 additions & 6 deletions build.sh
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
#!/bin/bash
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
set -ex

# Get directory this script is in
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OS=$(uname -s)

if [ "$OS" = "Darwin" ]; then
DIR_OS="MacOS"
DIR_OS="MacOS"
else
DIR_OS="Linux"
DIR_OS="Linux"
fi

if [[ "$*" == *"--ios"* ]]; then
DIR_OS="iOS"
DIR_OS="iOS"
elif [[ "$*" == *"--android"* ]]; then
DIR_OS="Android"
DIR_OS="Android"
fi

python3 $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
PYTHON="${PYTHON:-python3}"

$PYTHON $DIR/tools/ci_build/build.py --build_dir $DIR/build/$DIR_OS "$@"
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGram
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FirefoxMatMulInteger);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FirefoxMatMulInteger8);
// ******** Start: Quantization ******************* //
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ Status FirefoxMatMulInteger8<int8_t, int8_t, int16_t>::Compute(OpKernelContext*
static_cast<int>(helper.K()));
}

printf("I was called\n");
return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1987,7 +1987,7 @@ Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-



ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(FirefoxMatMulInteger8, 1,
OpSchema()
.SetDoc(FirefoxMatMulInteger_doc)
.Input(0, "A", "N-dimensional matrix A", "T1")
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/graph/contrib_ops/ms_opset.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Irfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4);
#endif
Expand Down Expand Up @@ -190,7 +190,11 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, IsAllFinite)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, LongformerAttention)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulInteger16)>());
<<<<<<< HEAD

Check failure on line 193 in onnxruntime/core/graph/contrib_ops/ms_opset.h

View workflow job for this annotation

GitHub Actions / Vcpkg

version control conflict marker in file
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger)>());

Check failure on line 194 in onnxruntime/core/graph/contrib_ops/ms_opset.h

View workflow job for this annotation

GitHub Actions / Vcpkg

unknown type name 'FirefoxMatMulInteger_Microsoft_ver1'; did you mean 'FirefoxMatMulInteger8_Microsoft_ver1'?
=======
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, FirefoxMatMulInteger8)>());
>>>>>>> 045a17021c (added a firefox matmul backend)
#ifndef ORT_MINIMAL_BUILD
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, MatMulFpQ4)>());
#endif
Expand Down
16 changes: 8 additions & 8 deletions onnxruntime/test/contrib_ops/firefox_matmul_integer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,32 @@
namespace onnxruntime {
namespace test {

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_1) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_1) {
OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {1, 1}, {15});
test.AddInput<int8_t>("T2", {1, 1}, {8});
test.AddOutput<int32_t>("T3", {1, 1}, {120}); // Result is 15 * 8
test.Run();
}

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_2) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_2) {
OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {1, 2}, {-7, 10});
test.AddInput<int8_t>("T2", {2, 1}, {-8, -11});
test.AddOutput<int32_t>("T3", {1, 1}, {8}); // Result is (-7 * -8) + (10 * -11)
test.Run();
}

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_Empty_input) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_Empty_input) {
OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {0, 2}, {});
test.AddInput<int8_t>("T2", {2, 1}, {-8, -11});
test.AddOutput<int32_t>("T3", {0, 1}, {}); // Empty input produces an empty output
test.Run();
}

TEST(FirefoxMatMulIntegerOpTest, FirefoxMatMulInteger_3) {
OpTester test("FirefoxMatMulInteger", 1, onnxruntime::kMSDomain);
TEST(FirefoxMatMulInteger8OpTest, FirefoxMatMulInteger8_3) {
OpTester test("FirefoxMatMulInteger8", 1, onnxruntime::kMSDomain);
test.AddInput<int8_t>("T1", {3, 2}, {-7, 10, 10, -113, 22, -36});
test.AddInput<int8_t>("T2", {2, 4}, {-8, -11, 13, 14, -9, 12, 3, -6});
test.AddOutput<int32_t>("T3", {3, 4},
Expand Down

0 comments on commit 51b6ae1

Please sign in to comment.