Commit c90b933: savepoint

tarekziade committed Jan 7, 2025
1 parent 26dd1ee
Showing 7 changed files with 227 additions and 81 deletions.
117 changes: 56 additions & 61 deletions onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc
@@ -11,6 +1,7 @@
#include "core/util/qmath.h"
#include <cassert>

#include <chrono>
#include <algorithm>

namespace onnxruntime {
@@ -20,30 +21,20 @@ namespace {


using Index = uint32_t;
extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
int8Multiply(const uint8_t* input_A,
float zero_point_A,
const int8_t* input_B,
//const uint8_t* zero_point_B,
Index rows_A,
Index width,
Index cols_B,
float* output);

extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("f32_multiply")))
f32Multiply(
const uint8_t* a_data,
float zero_point_A,
const int8_t* input_B,
const uint8_t* zero_point_B,
Index rows_A,
Index width,
Index cols_B,
const float* b_scale_data,
float is_b_scale_per_column,
float* output
__attribute__((import_module("wasm_gemm"), import_name("onnx_matmul_integer_to_float")))
GeckoMatmulIntegerToFloat(
const uint8_t* a_data,
float zero_point_A,
const int8_t* input_B,
const uint8_t* zero_point_B,
uint32_t rows_A,
uint32_t width,
uint32_t cols_B,
const float* b_scale_data,
float is_b_scale_per_column,
float* output
);
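// Illustrative sketch (an assumption, not part of this change): roughly what the
// imported onnx_matmul_integer_to_float routine is expected to compute, judging
// from the arguments ComputeCommon passes below. The real implementation lives on
// the Gecko/wasm host side; row-major layouts are assumed, zero_point_B may be
// null (treated as 0), its per-column indexing here is a guess, and any a_scale
// factor is assumed to be applied elsewhere (e.g. folded into b_scale_data or via
// ScaleOutput). The helper name is made up.
static void ReferenceMatmulIntegerToFloat(const uint8_t* a, float a_zp,
                                          const int8_t* b, const uint8_t* b_zp,
                                          uint32_t rows_a, uint32_t width,
                                          uint32_t cols_b, const float* b_scale,
                                          bool b_scale_per_column, float* out) {
  for (uint32_t m = 0; m < rows_a; ++m) {
    for (uint32_t n = 0; n < cols_b; ++n) {
      const float scale = b_scale_per_column ? b_scale[n] : b_scale[0];
      const float zp_b = b_zp ? static_cast<float>(b_zp[b_scale_per_column ? n : 0]) : 0.0f;
      float acc = 0.0f;
      for (uint32_t k = 0; k < width; ++k) {
        acc += (static_cast<float>(a[m * width + k]) - a_zp) *
               (static_cast<float>(b[k * cols_b + n]) - zp_b);
      }
      out[m * cols_b + n] = acc * scale;
    }
  }
}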


@@ -234,19 +225,9 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
params.ldc = gemm_shape.N;
}

// rowsA = M
// width = K
// colsB = N
size_t rowsA = static_cast<size_t>(helper.M());
size_t width = static_cast<size_t>(helper.K());
size_t colsB = static_cast<size_t>(helper.N());
const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());

#if 0
size_t total_elements = rowsA * colsB;
size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
#if 0
std::vector<float> y_data_2(rowsA * colsB, 0.0f);
std::cout << "Calling MatMulFull with the following parameters:\n";
if (rowsA > 1) {
std::cout << "rowsA: " << rowsA << ", width: " << width << ", colsB: " << colsB << "\n";
std::cout << "a_zp: " << static_cast<int>(a_zp) << "\n";
std::cout << "is_b_scale_per_column: " << is_b_scale_per_column << "\n";
@@ -276,13 +257,22 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
}
std::cout << "multiplier_per_tensor: " << multiplier_per_tensor << std::endl;
std::cout << "b_scale_data[0]: " << b_scale_data[0] << std::endl;
}
#endif
//auto start = std::chrono::steady_clock::now();
//std::cout << "Calling f32Multiply\n";
// should split in parts and call ctx.ParallelFor just on the rows part

//MatMulFull(a_data, b_data, y_data, rowsA, width, colsB, a_zp, b_zp_ptr, b_scale_data, is_b_scale_per_column);

std::cout << "Calling f32Multiply\n";
// rowsA = M
// width = K
// colsB = N
size_t rowsA = static_cast<size_t>(helper.M());
size_t width = static_cast<size_t>(helper.K());
size_t colsB = static_cast<size_t>(helper.N());

const int8_t* b_data = static_cast<const int8_t*>(b_tensor->DataRaw());

f32Multiply(a_data,
GeckoMatmulIntegerToFloat(a_data,
a_zp,
b_data,
b_zp_ptr,
@@ -291,15 +281,28 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
colsB,
b_scale_data,
is_b_scale_per_column,
y_data);

std::cout << "Done calling f32Multiply\n";


#if 0
y_data
);
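/*
  Sketch of the row-splitting idea from the "should split in parts" comment above
  (an assumption, not what this change does): partition M into row blocks and run
  the wasm matmul per block on the operator thread pool. Only the A and output
  pointers advance per block; B, the zero points and the scales are shared. It
  assumes concurrency::ThreadPool::TrySimpleParallelFor from
  core/platform/threadpool.h is usable here; the block size is illustrative.

  const size_t kRowBlock = 64;
  const size_t num_blocks = (rowsA + kRowBlock - 1) / kRowBlock;
  concurrency::ThreadPool::TrySimpleParallelFor(
      ctx->GetOperatorThreadPool(), static_cast<std::ptrdiff_t>(num_blocks),
      [&](std::ptrdiff_t blk) {
        const size_t row_begin = static_cast<size_t>(blk) * kRowBlock;
        const size_t rows = std::min(kRowBlock, rowsA - row_begin);
        GeckoMatmulIntegerToFloat(a_data + row_begin * width,
                                  a_zp,
                                  b_data,
                                  b_zp_ptr,
                                  static_cast<uint32_t>(rows),
                                  static_cast<uint32_t>(width),
                                  static_cast<uint32_t>(colsB),
                                  b_scale_data,
                                  is_b_scale_per_column,
                                  y_data + row_begin * colsB);
      });
*/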

// MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
/*
auto end = std::chrono::steady_clock::now();
auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
std::cout << "Done calling f32Multiply. Duration: " << duration << " nano\n";
std::cout << "Calling MlasGemmBatch\n";
auto start2 = std::chrono::steady_clock::now();
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), num_gemms, ctx->GetOperatorThreadPool());
auto end2 = std::chrono::steady_clock::now();
auto duration2 = std::chrono::duration_cast<std::chrono::nanoseconds>(end2 - start2).count();
std::cout << "Done calling MlasGemmBatch. Duration: " << duration2 << " nano\n";
*/
/*
// Compare y_data and y_data_2
size_t total_elements = rowsA * colsB;
size_t display_limit = std::min(total_elements, static_cast<size_t>(100));
bool mismatch_found = false;
for (size_t i = 0; i < total_elements; ++i) {
if (std::fabs(y_data[i] - y_data_2[i]) > 1e-6) { // Tolerance for floating-point comparison
Expand All @@ -322,23 +325,10 @@ Status MatMulIntegerToFloatBase::ComputeCommon(OpKernelContext* ctx,
std::cerr << "Mismatch found between y_data and y_data_2!" << std::endl;
assert(false && "Validation failed: y_data and y_data_2 are not equal.");
}
#endif
*/
return Status::OK();
}

/*
int8Multiply(
reinterpret_cast<const uint8_t*>(a_data),
a_zp,
b_data,
//reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
rowsA,
width,
colsB,
reinterpret_cast<float*>(y_data)
);
*/

class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {
public:
DynamicQuantizeMatMul(const OpKernelInfo& info) : MatMulIntegerToFloatBase(info) {}
@@ -405,6 +395,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
ParQuantizeLinearStd(a_data, a_data_quant, narrow<size_t>(num_of_elements), a_scale, a_zero_point, ctx->GetOperatorThreadPool());

bool is_b_scale_supported = IsBQuantParamSupported(b_scale_tensor->Shape(), b ? b->Shape() : b_shape_);

//std::cout << "dynamic quantize matmul calling ComputeCommon" << std::endl;


ORT_RETURN_IF_ERROR(ComputeCommon(
ctx,
a_data_quant,
Expand All @@ -418,7 +412,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const {
ctx->Input<Tensor>(IN_BIAS)));

if (!is_b_scale_supported) {
std::cout << "dynamic quantize matmul: b scale is not supported\n";
//std::cout << "dynamic quantize matmul: b scale is not supported\n";
ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
}

@@ -460,6 +454,7 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
a_zero_point = *(static_cast<const uint8_t*>(a_zero_point_tensor->DataRaw()));
}

//std::cout << "matmul integer float calling ComputeCommon" << std::endl;
const Tensor* b_zp_tensor = ctx->Input<Tensor>(IN_B_ZERO_POINT);
ORT_RETURN_IF_ERROR(ComputeCommon(
ctx,
Expand All @@ -474,11 +469,11 @@ Status MatMulIntegerToFloat::Compute(OpKernelContext* ctx) const {
ctx->Input<Tensor>(IN_BIAS)));

if (!is_a_scale_scalar) {
std::cout << "dynamic quantize matmul: a scale is not scalar\n";
//std::cout << "dynamic quantize matmul: a scale is not scalar\n";
ScaleOutput(*a_scale_tensor, *ctx->Output<Tensor>(0));
}
if (!is_b_scale_supported) {
std::cout << "dynamic quantize matmul: b scale is not supported\n";
//std::cout << "dynamic quantize matmul: b scale is not supported\n";
ScaleOutput(*b_scale_tensor, *ctx->Output<Tensor>(0));
}

@@ -149,16 +149,18 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
}
#endif
//auto start_matmul = Clock::now();
/*
int8Multiply(
reinterpret_cast<const uint8_t*>(a->DataRaw()),
a_offset,
reinterpret_cast<const int8_t*>(b->DataRaw()),
//reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
reinterpret_cast<const uint8_t*>(b_zero_point->DataRaw()),
M,
K,
N,
reinterpret_cast<float*>(y_data)
);
*/
//auto end_matmul = Clock::now();
//auto matmul_time = std::chrono::duration_cast<Microseconds>(end_matmul - start_matmul).count();

@@ -202,6 +204,7 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
std::cout << "MlasGemmBatch: " << mblas_time << "\n";

#endif
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());
return Status::OK();
}

@@ -39,18 +39,18 @@ class FirefoxMatMulInteger8 final : public MatMulIntegerBase {

#include <cstdint>

using Index = uint32_t;
#if 0
extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("int8_multiply")))
int8Multiply(const uint8_t* input_A,
float zero_point_A,
const int8_t* input_B,
//const uint8_t* zero_point_B,
Index rows_A,
Index width,
Index cols_B,
const uint8_t* zero_point_B,
float rows_A,
float width,
float cols_B,
float* output);
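// Note: in this disabled block rows_A, width and cols_B are declared as float,
// whereas the active wasm imports elsewhere use uint32_t / Index for the same
// dimensions; the block is under #if 0, so it has no effect.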

#endif

#endif

16 changes: 16 additions & 0 deletions onnxruntime/core/mlas/lib/qgemm.cpp
@@ -17,6 +17,7 @@ Module Name:

#include "mlasi.h"
#include "qgemm.h"
#include <iostream>

//
// Define the parameters to execute segments of a QGEMM operation on worker
@@ -144,6 +145,8 @@ MlasGemmBatch(

const double Complexity = double(M) * double(N) * double(K) * double(BatchN);

//std::cout << "Complexity: " << Complexity << std::endl;

ptrdiff_t TargetThreadCount;

if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) {
@@ -194,10 +197,16 @@ MlasGemmBatch(
WorkBlock.ThreadCountN = 1;
}
TargetThreadCount = ThreadsPerGemm * BatchN;
//std::cout << "ThreadsPerGemm: " << ThreadsPerGemm << std::endl;
//std::cout << "TargetThreadCount: " << TargetThreadCount << std::endl;
//std::cout << "MaximumThreadCount: " << MaximumThreadCount << std::endl;



MlasTrySimpleParallel(ThreadPool, TargetThreadCount, [&](ptrdiff_t tid) {
const auto gemm_i = tid / ThreadsPerGemm;
const auto blk_i = tid % ThreadsPerGemm;
//std::cout << "gemm_i: " << gemm_i << " blk_i: " << blk_i << std::endl;
MlasGemmQuantThreaded(&WorkBlock, &Shape, &DataParams[gemm_i], blk_i);
});
}
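// For example, with BatchN = 2 and ThreadsPerGemm = 4 the dispatch above schedules
// 8 tasks; tid 5 maps to gemm_i = 5 / 4 = 1 and blk_i = 5 % 4 = 1.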
@@ -277,6 +286,13 @@ MlasSymmQgemmBatch(
const size_t ThreadCountM = MlasDivRoundup(M, StrideM);
const size_t ThreadCountN = MlasDivRoundup(N, StrideN);
ThreadsPerGemm = ThreadCountM * ThreadCountN;

/*
std::cout << "ThreadsPerGemm" << ThreadsPerGemm << std::endl;
std::cout << "TargetThreadCount " <<TargetThreadCount << std::endl;
std::cout << "ThreadCountM" << ThreadCountM << std::endl;
std::cout << "ThreadCountN" << ThreadCountN << std::endl;
*/

MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * BatchN, [&](ptrdiff_t tid) {
auto uarch = MLAS_CPUIDINFO::GetCPUIDInfo().IsCurrentCoreArmv8NarrowLd();
9 changes: 9 additions & 0 deletions onnxruntime/core/mlas/lib/qgemm.h
@@ -36,6 +36,15 @@ Module Name:
#include <string>
#include <cstdlib>

extern "C" void
__attribute__((import_module("wasm_gemm"), import_name("mlas_gemm_u8x8")))

Check warning on line 40 in onnxruntime/core/mlas/lib/qgemm.h (GitHub Actions / Vcpkg): unknown attribute 'import_module' ignored; unknown attribute 'import_name' ignored [-Wunknown-attributes]
xMlasGemmU8X8MultiplyAccumulateRowWasmSimd(
const float* A,
const float* B,
const float* C
);
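// One way to avoid the Vcpkg warnings noted above (a sketch, not part of this
// change): only emit the wasm import attributes when targeting wasm, since other
// toolchains do not recognize import_module/import_name. The macro name is
// illustrative.
//
// #if defined(__wasm__) || defined(__EMSCRIPTEN__)
// #define WASM_GEMM_IMPORT(name) __attribute__((import_module("wasm_gemm"), import_name(name)))
// #else
// #define WASM_GEMM_IMPORT(name)
// #endif
//
// extern "C" void WASM_GEMM_IMPORT("mlas_gemm_u8x8")
// xMlasGemmU8X8MultiplyAccumulateRowWasmSimd(const float* A, const float* B, const float* C);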


//
// Define the default striding parameters used for the quantized integer
// matrix/matrix multiply operation.