From 7c31b4cada0a8ef380af5e1e0e0381a7fde257ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tarek=20Ziad=C3=A9?= Date: Thu, 12 Dec 2024 12:43:40 +0100 Subject: [PATCH] comparing both stuff --- .../quantization/firefox_matmul_integer.cc | 55 +++++++++++++++++-- 1 file changed, 51 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc index 773e5ecb1209b..ffe92feb840e7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc @@ -108,24 +108,71 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const { std::vector int32_output(helper.M() * helper.N(), 0); #ifdef __EMSCRIPTEN__ + uint8_t zero_point_b = *(b_offset_ptr + helper.RightZeroPointOffsets()[0]); + + // Output all inputs before the call + std::cout << "Matrix A:\n"; + for (size_t i = 0; i < static_cast(helper.M()); ++i) { + for (size_t j = 0; j < static_cast(helper.K()); ++j) { + std::cout << static_cast(a_data[i * helper.K() + j]) << " "; + } + std::cout << "\n"; + } + + std::cout << "Matrix B:\n"; + for (size_t i = 0; i < static_cast(helper.K()); ++i) { + for (size_t j = 0; j < static_cast(helper.N()); ++j) { + std::cout << static_cast(b_data[i * helper.N() + j]) << " "; + } + std::cout << "\n"; + } + + std::cout << "A Zero point: " << static_cast(a_offset) << "\n"; + std::cout << "B zero_point: " << static_cast(zero_point_b) << "\n"; + std::cout << "rows A: " << helper.M() << ", width: " << helper.K() << ", Cols B: " << helper.N() << "\n"; + std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n"; + std::cout << "B is signed: " << (b_is_signed ? "true" : "false") << "\n"; + + // Gemmology call int8Multiply(reinterpret_cast(a_data), a_offset, reinterpret_cast(b_data), - 0, // b_zero_point + zero_point_b, static_cast(helper.M()), // rows A static_cast(helper.K()), // width static_cast(helper.N()), // col B reinterpret_cast(int32_output.data())); #endif + // Original MatmulInteger call MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool()); // Compare the outputs std::cout << "Comparing Outputs:\n"; - for (size_t i = 0; i < int32_output.size(); ++i) { - std::cout << "Index " << i << ": int8Multiply = " << int32_output[i] - << ", MlasGemmBatch = " << static_cast(y_data[i]) << "\n"; + std::cout << "Gemmology:\n"; + for (size_t i = 0; i < static_cast(helper.M()); ++i) { + for (size_t j = 0; j < static_cast(helper.N()); ++j) { + std::cout << static_cast(int32_output[i * helper.N() + j]) << " "; + } + std::cout << "\n"; } + std::cout << "MBLas:\n"; + for (size_t i = 0; i < static_cast(helper.M()); ++i) { + for (size_t j = 0; j < static_cast(helper.N()); ++j) { + std::cout << static_cast(y_data[i * helper.N() + j]) << " "; + } + std::cout << "\n"; + } + +for (size_t i = 0; i < static_cast(helper.M()); ++i) { + for (size_t j = 0; j < static_cast(helper.N()); ++j) { + size_t index = i * helper.N() + j; + if (int32_output[index] != static_cast(y_data[index])) { + ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", int32_output[index], + ", MlasGemmBatch = ", static_cast(y_data[index])); + } + } +} return Status::OK(); }