Skip to content

Commit

Permalink
comparing both stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
tarekziade committed Dec 12, 2024
1 parent 2ecbcdf commit 7c31b4c
Showing 1 changed file with 51 additions and 4 deletions.
55 changes: 51 additions & 4 deletions onnxruntime/contrib_ops/cpu/quantization/firefox_matmul_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,24 +108,71 @@ Status FirefoxMatMulInteger8::Compute(OpKernelContext* ctx) const {
std::vector<int32_t> int32_output(helper.M() * helper.N(), 0);

#ifdef __EMSCRIPTEN__
uint8_t zero_point_b = *(b_offset_ptr + helper.RightZeroPointOffsets()[0]);

// Output all inputs before the call
std::cout << "Matrix A:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.K()); ++j) {
std::cout << static_cast<int>(a_data[i * helper.K() + j]) << " ";
}
std::cout << "\n";
}

std::cout << "Matrix B:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.K()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << static_cast<int>(b_data[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}

std::cout << "A Zero point: " << static_cast<int>(a_offset) << "\n";
std::cout << "B zero_point: " << static_cast<int>(zero_point_b) << "\n";
std::cout << "rows A: " << helper.M() << ", width: " << helper.K() << ", Cols B: " << helper.N() << "\n";
std::cout << "B is packed: " << (packed_b_ ? "true" : "false") << "\n";
std::cout << "B is signed: " << (b_is_signed ? "true" : "false") << "\n";

// Gemmology call
int8Multiply(reinterpret_cast<const int8_t*>(a_data),
a_offset,
reinterpret_cast<const int8_t*>(b_data),
0, // b_zero_point
zero_point_b,
static_cast<size_t>(helper.M()), // rows A
static_cast<size_t>(helper.K()), // width
static_cast<size_t>(helper.N()), // col B
reinterpret_cast<float*>(int32_output.data()));
#endif

// Original MatmulInteger call
MlasGemmBatch(gemm_shape, gemm_data_vec.data(), batch_size, ctx->GetOperatorThreadPool());

// Compare the outputs
std::cout << "Comparing Outputs:\n";
for (size_t i = 0; i < int32_output.size(); ++i) {
std::cout << "Index " << i << ": int8Multiply = " << int32_output[i]
<< ", MlasGemmBatch = " << static_cast<float>(y_data[i]) << "\n";
std::cout << "Gemmology:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << static_cast<int>(int32_output[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}
std::cout << "MBLas:\n";
for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
std::cout << static_cast<int>(y_data[i * helper.N() + j]) << " ";
}
std::cout << "\n";
}

for (size_t i = 0; i < static_cast<size_t>(helper.M()); ++i) {
for (size_t j = 0; j < static_cast<size_t>(helper.N()); ++j) {
size_t index = i * helper.N() + j;
if (int32_output[index] != static_cast<float>(y_data[index])) {
ORT_ENFORCE(false, "Mismatch at Row ", i, ", Col ", j, ": int8Multiply = ", int32_output[index],
", MlasGemmBatch = ", static_cast<float>(y_data[index]));
}
}
}

return Status::OK();
}
Expand Down

0 comments on commit 7c31b4c

Please sign in to comment.