From 86a312c962ec1ff7befe340668c532938993eb54 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Mon, 22 Jul 2024 13:54:36 -0700 Subject: [PATCH 01/10] Add three failing tests --- test/error/tiled_matmul_wrong_layout.cpp | 148 +++++++++++++++++++++++ test/error/tiled_matmul_wrong_modulo.cpp | 148 +++++++++++++++++++++++ test/error/tiled_matmul_wrong_tiling.cpp | 105 ++++++++++++++++ 3 files changed, 401 insertions(+) create mode 100644 test/error/tiled_matmul_wrong_layout.cpp create mode 100644 test/error/tiled_matmul_wrong_modulo.cpp create mode 100644 test/error/tiled_matmul_wrong_tiling.cpp diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp new file mode 100644 index 000000000000..8fc754f33570 --- /dev/null +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -0,0 +1,148 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 8; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +bool equal_eps(float lhs, float rhs, float eps) { + return std::abs(lhs - rhs) < eps; +} + +struct make_uint_t { + template + Type operator()(Args &&...args) const { + return UInt(static_cast(args)...); + } +}; + +struct make_int_t { + template + Type operator()(Args &&...args) const { + return Int(static_cast(args)...); + } +}; + +template +void print_mat(const Buffer &buf, int rows, int cols) { + using cast_T = std::conditional_t, int, T>; + for (int j = 0; j != rows; ++j) { + for (int i = 0; i != cols; ++i) { + std::cout << static_cast(buf(i, j)) << " "; + } + std::cout << std::endl; + } +} + +template +void print_mat_rhs(const Buffer &buf, int rows, int cols) { + using cast_T = std::conditional_t, int, T>; + for (int j = 0; j != (rows / (4 / sizeof(T))); ++j) { + for (int k = 0; k != (4 / sizeof(T)); ++k) { + for (int i = 0; i != cols; ++i) { + std::cout << static_cast(buf(k, i, j)) << " "; + } + + std::cout << std::endl; + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(8, col, acc / 4); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 4, x, r / 4)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + // Should err with AMX mapping failure since B buffer has a + // different layout than expected by AMX + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + + if (get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + bool should_continue = true; + for (int j = 0; j < row && should_continue; ++j) { + for (int i = 0; i < col && should_continue; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 4, i, k / 4)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + // matmul(2, 2, 16, 2, 2, 8); + // matmul(4, 4, 8, 4, 4, 8); + // matmul(32, 32, 32, 8, 8, 8); + matmul(32, 32, 32, 8, 8, 4); +} \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_modulo.cpp b/test/error/tiled_matmul_wrong_modulo.cpp new file mode 100644 index 000000000000..d17e050f2be1 --- /dev/null +++ b/test/error/tiled_matmul_wrong_modulo.cpp @@ -0,0 +1,148 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 4; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 4; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +bool equal_eps(float lhs, float rhs, float eps) { + return std::abs(lhs - rhs) < eps; +} + +struct make_uint_t { + template + Type operator()(Args &&...args) const { + return UInt(static_cast(args)...); + } +}; + +struct make_int_t { + template + Type operator()(Args &&...args) const { + return Int(static_cast(args)...); + } +}; + +template +void print_mat(const Buffer &buf, int rows, int cols) { + using cast_T = std::conditional_t, int, T>; + for (int j = 0; j != rows; ++j) { + for (int i = 0; i != cols; ++i) { + std::cout << static_cast(buf(i, j)) << " "; + } + std::cout << std::endl; + } +} + +template +void print_mat_rhs(const Buffer &buf, int rows, int cols) { + using cast_T = std::conditional_t, int, T>; + for (int j = 0; j != (rows / (4 / sizeof(T))); ++j) { + for (int k = 0; k != (4 / sizeof(T)); ++k) { + for (int i = 0; i != cols; ++i) { + std::cout << static_cast(buf(k, i, j)) << " "; + } + + std::cout << std::endl; + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(4, col, acc / 4); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + // Mod is 3 instead of 4 + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 3, x, r / 4)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + // Should err with AMX mapping failure since B buffer is not swizzled correctly + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + + if (get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + bool should_continue = true; + for (int j = 0; j < row && should_continue; ++j) { + for (int i = 0; i < col && should_continue; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 3, i, k / 4)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + // matmul(2, 2, 16, 2, 2, 8); + // matmul(4, 4, 8, 4, 4, 8); + matmul(32, 32, 32, 8, 8, 8); + // matmul(32, 32, 32, 8, 8, 4); +} \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_tiling.cpp b/test/error/tiled_matmul_wrong_tiling.cpp new file mode 100644 index 000000000000..6efec3f8a9c0 --- /dev/null +++ b/test/error/tiled_matmul_wrong_tiling.cpp @@ -0,0 +1,105 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc / 8; iy++) { + for (int ix = 0; ix < col; ix++) { + for (int ik = 0; ik < 8; ++ik) { + buf(ik, ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(8, col, acc / 8); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 8, x, r / 8)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + // Should err with AMX mapping failure since the tiling is set to 8, + // which is not what AMX expects + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + + if (get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + bool should_continue = true; + for (int j = 0; j < row && should_continue; ++j) { + for (int i = 0; i < col && should_continue; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 8, i, k / 8)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + // matmul(2, 2, 16, 2, 2, 8); + // matmul(4, 4, 8, 4, 4, 8); + matmul(32, 32, 32, 8, 8, 8); + // matmul(32, 32, 32, 8, 8, 4); +} \ No newline at end of file From bcf99f0283b38da24f4edeb3f77bc9c1ea41690a Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Mon, 22 Jul 2024 16:08:22 -0700 Subject: [PATCH 02/10] add validate flag --- test/error/tiled_matmul_wrong_layout.cpp | 72 ++++++------------------ test/error/tiled_matmul_wrong_modulo.cpp | 70 ++++++----------------- test/error/tiled_matmul_wrong_tiling.cpp | 31 +++++----- 3 files changed, 51 insertions(+), 122 deletions(-) diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp index 8fc754f33570..a2713f7cba50 100644 --- a/test/error/tiled_matmul_wrong_layout.cpp +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -24,51 +24,8 @@ void fill_buffer_b(Buffer &buf, int col, int acc) { } } -bool equal_eps(float lhs, float rhs, float eps) { - return std::abs(lhs - rhs) < eps; -} - -struct make_uint_t { - template - Type operator()(Args &&...args) const { - return UInt(static_cast(args)...); - } -}; - -struct make_int_t { - template - Type operator()(Args &&...args) const { - return Int(static_cast(args)...); - } -}; - -template -void print_mat(const Buffer &buf, int rows, int cols) { - using cast_T = std::conditional_t, int, T>; - for (int j = 0; j != rows; ++j) { - for (int i = 0; i != cols; ++i) { - std::cout << static_cast(buf(i, j)) << " "; - } - std::cout << std::endl; - } -} - -template -void print_mat_rhs(const Buffer &buf, int rows, int cols) { - using cast_T = std::conditional_t, int, T>; - for (int j = 0; j != (rows / (4 / sizeof(T))); ++j) { - for (int k = 0; k != (4 / sizeof(T)); ++k) { - for (int i = 0; i != cols; ++i) { - std::cout << static_cast(buf(k, i, j)) << " "; - } - - std::cout << std::endl; - } - } -} - template -bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { Target target("x86-64-linux-avx512_sapphirerapids"); Buffer A_buf(acc, row); Buffer B_buf(8, col, acc / 4); @@ -108,11 +65,11 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Func result = mm.in(); - // Should err with AMX mapping failure since B buffer has a - // different layout than expected by AMX - result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); - - if (get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + if (!validate) { + // Should err with AMX mapping failure since B buffer has a + // different layout than expected by AMX + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { std::cerr << "Validating compiled program\n"; fill_buffer_a(A_buf, row, acc); @@ -129,8 +86,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } if (val != out(i, j)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n" - << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; return false; } } @@ -141,8 +98,13 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } int main(int argc, char **argv) { - // matmul(2, 2, 16, 2, 2, 8); - // matmul(4, 4, 8, 4, 4, 8); - // matmul(32, 32, 32, 8, 8, 8); - matmul(32, 32, 32, 8, 8, 4); + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 4, validate); } \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_modulo.cpp b/test/error/tiled_matmul_wrong_modulo.cpp index d17e050f2be1..cbf7521ba4c2 100644 --- a/test/error/tiled_matmul_wrong_modulo.cpp +++ b/test/error/tiled_matmul_wrong_modulo.cpp @@ -24,51 +24,8 @@ void fill_buffer_b(Buffer &buf, int col, int acc) { } } -bool equal_eps(float lhs, float rhs, float eps) { - return std::abs(lhs - rhs) < eps; -} - -struct make_uint_t { - template - Type operator()(Args &&...args) const { - return UInt(static_cast(args)...); - } -}; - -struct make_int_t { - template - Type operator()(Args &&...args) const { - return Int(static_cast(args)...); - } -}; - -template -void print_mat(const Buffer &buf, int rows, int cols) { - using cast_T = std::conditional_t, int, T>; - for (int j = 0; j != rows; ++j) { - for (int i = 0; i != cols; ++i) { - std::cout << static_cast(buf(i, j)) << " "; - } - std::cout << std::endl; - } -} - -template -void print_mat_rhs(const Buffer &buf, int rows, int cols) { - using cast_T = std::conditional_t, int, T>; - for (int j = 0; j != (rows / (4 / sizeof(T))); ++j) { - for (int k = 0; k != (4 / sizeof(T)); ++k) { - for (int i = 0; i != cols; ++i) { - std::cout << static_cast(buf(k, i, j)) << " "; - } - - std::cout << std::endl; - } - } -} - template -bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { Target target("x86-64-linux-avx512_sapphirerapids"); Buffer A_buf(acc, row); Buffer B_buf(4, col, acc / 4); @@ -109,10 +66,10 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Func result = mm.in(); - // Should err with AMX mapping failure since B buffer is not swizzled correctly - result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); - - if (get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + if (!validate) { + // Should err with AMX mapping failure since B buffer is not swizzled correctly + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { std::cerr << "Validating compiled program\n"; fill_buffer_a(A_buf, row, acc); @@ -129,8 +86,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } if (val != out(i, j)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n" - << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; return false; } } @@ -141,8 +98,13 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } int main(int argc, char **argv) { - // matmul(2, 2, 16, 2, 2, 8); - // matmul(4, 4, 8, 4, 4, 8); - matmul(32, 32, 32, 8, 8, 8); - // matmul(32, 32, 32, 8, 8, 4); + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 8, validate); } \ No newline at end of file diff --git a/test/error/tiled_matmul_wrong_tiling.cpp b/test/error/tiled_matmul_wrong_tiling.cpp index 6efec3f8a9c0..30af906f7473 100644 --- a/test/error/tiled_matmul_wrong_tiling.cpp +++ b/test/error/tiled_matmul_wrong_tiling.cpp @@ -25,7 +25,7 @@ void fill_buffer_b(Buffer &buf, int col, int acc) { } template -bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { Target target("x86-64-linux-avx512_sapphirerapids"); Buffer A_buf(acc, row); Buffer B_buf(8, col, acc / 8); @@ -65,13 +65,13 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Func result = mm.in(); - // Should err with AMX mapping failure since the tiling is set to 8, - // which is not what AMX expects - result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); - - if (get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + if (!validate) { + // Should err with AMX mapping failure since the tiling is set to 8, + // which is not what AMX expects + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { std::cerr << "Validating compiled program\n"; - + fill_buffer_a(A_buf, row, acc); fill_buffer_b(B_buf, col, acc); Buffer out(col, row); @@ -86,8 +86,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } if (val != out(i, j)) { std::cerr << "Invalid result at " << i << ", " << j << "\n" - << out(i, j) << " != " << val << "\n" - << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; return false; } } @@ -98,8 +98,13 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } int main(int argc, char **argv) { - // matmul(2, 2, 16, 2, 2, 8); - // matmul(4, 4, 8, 4, 4, 8); - matmul(32, 32, 32, 8, 8, 8); - // matmul(32, 32, 32, 8, 8, 4); + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 8, validate); } \ No newline at end of file From f7d42cc36843924ef767c0b164d1b325e557d44d Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Mon, 22 Jul 2024 16:14:30 -0700 Subject: [PATCH 03/10] remove should-continue --- test/error/tiled_matmul_wrong_layout.cpp | 5 ++--- test/error/tiled_matmul_wrong_modulo.cpp | 5 ++--- test/error/tiled_matmul_wrong_tiling.cpp | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp index a2713f7cba50..e13f802d0bb0 100644 --- a/test/error/tiled_matmul_wrong_layout.cpp +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -77,9 +77,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool Buffer out(col, row); result.realize(out); - bool should_continue = true; - for (int j = 0; j < row && should_continue; ++j) { - for (int i = 0; i < col && should_continue; ++i) { + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { int32_t val = 0; for (int k = 0; k < acc; ++k) { val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 4, i, k / 4)); diff --git a/test/error/tiled_matmul_wrong_modulo.cpp b/test/error/tiled_matmul_wrong_modulo.cpp index cbf7521ba4c2..7045d1f5185b 100644 --- a/test/error/tiled_matmul_wrong_modulo.cpp +++ b/test/error/tiled_matmul_wrong_modulo.cpp @@ -77,9 +77,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool Buffer out(col, row); result.realize(out); - bool should_continue = true; - for (int j = 0; j < row && should_continue; ++j) { - for (int i = 0; i < col && should_continue; ++i) { + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { int32_t val = 0; for (int k = 0; k < acc; ++k) { val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 3, i, k / 4)); diff --git a/test/error/tiled_matmul_wrong_tiling.cpp b/test/error/tiled_matmul_wrong_tiling.cpp index 30af906f7473..0329cd19b06a 100644 --- a/test/error/tiled_matmul_wrong_tiling.cpp +++ b/test/error/tiled_matmul_wrong_tiling.cpp @@ -77,9 +77,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool Buffer out(col, row); result.realize(out); - bool should_continue = true; - for (int j = 0; j < row && should_continue; ++j) { - for (int i = 0; i < col && should_continue; ++i) { + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { int32_t val = 0; for (int k = 0; k < acc; ++k) { val += static_cast(A_buf(k, j)) * static_cast(B_buf(k % 8, i, k / 8)); From 942e4d5aacf72780bc57de178d143b3a67646590 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Mon, 22 Jul 2024 16:20:50 -0700 Subject: [PATCH 04/10] minor --- test/error/tiled_matmul_wrong_layout.cpp | 1 + test/error/tiled_matmul_wrong_tiling.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp index e13f802d0bb0..f32da120f94c 100644 --- a/test/error/tiled_matmul_wrong_layout.cpp +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -28,6 +28,7 @@ template bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { Target target("x86-64-linux-avx512_sapphirerapids"); Buffer A_buf(acc, row); + // Each tile in B is padded with another 4 bytes. Buffer B_buf(8, col, acc / 4); Var x("x"), y("y"); diff --git a/test/error/tiled_matmul_wrong_tiling.cpp b/test/error/tiled_matmul_wrong_tiling.cpp index 0329cd19b06a..58f84fbdd68a 100644 --- a/test/error/tiled_matmul_wrong_tiling.cpp +++ b/test/error/tiled_matmul_wrong_tiling.cpp @@ -35,6 +35,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool Func mm("matmul"); mm(x, y) = cast(0); + // Tiling is set to 8 mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(r % 8, x, r / 8)); Var rxi("rxi"), ryi("ryi"); From 5ab925763cafa98e7d87f2d7580fa08be9e711c4 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Mon, 12 Aug 2024 17:39:01 -0700 Subject: [PATCH 05/10] add _wrong_pattern and update cmakelist --- test/error/CMakeLists.txt | 4 + test/error/tiled_matmul_wrong_pattern.cpp | 108 ++++++++++++++++++++++ 2 files changed, 112 insertions(+) create mode 100644 test/error/tiled_matmul_wrong_pattern.cpp diff --git a/test/error/CMakeLists.txt b/test/error/CMakeLists.txt index 52a2a01cd65e..afef43c870dd 100644 --- a/test/error/CMakeLists.txt +++ b/test/error/CMakeLists.txt @@ -104,6 +104,10 @@ tests(GROUPS error split_same_var_names.cpp store_at_without_compute_at.cpp thread_id_outside_block_id.cpp + tiled_matmul_wrong_layout.cpp + tiled_matmul_wrong_modulo.cpp + tiled_matmul_wrong_pattern.cpp + tiled_matmul_wrong_tiling.cpp too_many_args.cpp tuple_arg_select_undef.cpp tuple_output_bounds_check.cpp diff --git a/test/error/tiled_matmul_wrong_pattern.cpp b/test/error/tiled_matmul_wrong_pattern.cpp new file mode 100644 index 000000000000..c9ff2b95b5a5 --- /dev/null +++ b/test/error/tiled_matmul_wrong_pattern.cpp @@ -0,0 +1,108 @@ +#include "Halide.h" +#include "halide_test_dirs.h" +#include + +using namespace Halide; + +template +void fill_buffer_a(Buffer &buf, int row, int acc) { + for (int iy = 0; iy < row; iy++) { + for (int ix = 0; ix < acc; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +void fill_buffer_b(Buffer &buf, int col, int acc) { + for (int iy = 0; iy < acc; iy++) { + for (int ix = 0; ix < col; ix++) { + buf(ix, iy) = rand() % 256 + std::numeric_limits::min(); + } + } +} + +template +bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool validate) { + Target target("x86-64-linux-avx512_sapphirerapids"); + Buffer A_buf(acc, row); + Buffer B_buf(col, acc); + + Var x("x"), y("y"); + RDom r(0, acc); + + Func mm("matmul"); + mm(x, y) = cast(0); + mm(x, y) += cast(A_buf(r, y)) * cast(B_buf(x, r)); + + Var rxi("rxi"), ryi("ryi"); + RVar rri("rri"), rro("rro"); + + mm.compute_at(mm.in(), x) + .store_in(MemoryType::AMXTile) + .update() + .tile(x, y, rxi, ryi, tile_x, tile_y, TailStrategy::GuardWithIf) + .split(r, rro, rri, tile_r) + .reorder(rri, rxi, ryi, rro, x, y) + .atomic() + .vectorize(rri) + .vectorize(rxi) + .vectorize(ryi); + + Var ixi("ixi"), iyi("iyi"); + mm.compute_at(mm.in(), x) + .tile(x, y, ixi, iyi, tile_x, tile_y) + .vectorize(ixi) + .vectorize(iyi); + + Var mmxi("mmxi"), mmyi("mmyi"); + mm.in() + .tile(x, y, mmxi, mmyi, tile_x, tile_y) + .vectorize(mmxi) + .vectorize(mmyi); + + Func result = mm.in(); + + if (!validate) { + // Should err with AMX mapping failure since B buffer has a + // different layout than expected by AMX + result.compile_to_lowered_stmt("/dev/null", {A_buf, B_buf}, Halide::Text, target); + } else { + std::cerr << "Validating compiled program\n"; + + fill_buffer_a(A_buf, row, acc); + fill_buffer_b(B_buf, col, acc); + Buffer out(col, row); + result.realize(out); + + bool should_continue = true; + for (int j = 0; j < row && should_continue; ++j) { + for (int i = 0; i < col && should_continue; ++i) { + int32_t val = 0; + for (int k = 0; k < acc; ++k) { + val += static_cast(A_buf(k, j)) * static_cast(B_buf(i, k)); + } + if (val != out(i, j)) { + std::cerr << "Invalid result at " << i << ", " << j << "\n" + << out(i, j) << " != " << val << "\n" + << "Matrix dims: " << row << "x" << col << "x" << acc << "\nTile dims: " << tile_x << "x" << tile_y << "x" << tile_r << "\n"; + return false; + } + } + } + } + + return true; +} + +int main(int argc, char **argv) { + bool validate = false; + if (argc == 2 && argv[1] == std::string("--validate")) { + validate = true; + } + if (validate && !get_jit_target_from_environment().has_feature(Target::AVX512_SapphireRapids)) { + std::cerr << "Skipping test since target does not support AMX\n"; + return 0; + } + matmul(32, 32, 32, 8, 8, 4, validate); +} \ No newline at end of file From 2004522e4bf39739409a63f26315e02439cb969e Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Mon, 12 Aug 2024 17:40:28 -0700 Subject: [PATCH 06/10] nits --- test/error/tiled_matmul_wrong_pattern.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/error/tiled_matmul_wrong_pattern.cpp b/test/error/tiled_matmul_wrong_pattern.cpp index c9ff2b95b5a5..d40c17c23419 100644 --- a/test/error/tiled_matmul_wrong_pattern.cpp +++ b/test/error/tiled_matmul_wrong_pattern.cpp @@ -75,9 +75,8 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r, bool Buffer out(col, row); result.realize(out); - bool should_continue = true; - for (int j = 0; j < row && should_continue; ++j) { - for (int i = 0; i < col && should_continue; ++i) { + for (int j = 0; j < row; ++j) { + for (int i = 0; i < col; ++i) { int32_t val = 0; for (int k = 0; k < acc; ++k) { val += static_cast(A_buf(k, j)) * static_cast(B_buf(i, k)); From cce2f5d5d7c51cacab17de511cc6d44613ab23f7 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Tue, 13 Aug 2024 01:22:03 -0700 Subject: [PATCH 07/10] Fix tile extraction --- src/ExtractTileOperations.cpp | 95 +++++++++++++------------------ test/correctness/tiled_matmul.cpp | 17 +++++- test/performance/tiled_matmul.cpp | 14 +++-- 3 files changed, 64 insertions(+), 62 deletions(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 8fdcea73f34b..78c6c09711c1 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -81,35 +81,7 @@ const auto wild_i32x = Variable::make(Int(32, 0), "*"); Tile<1> get_1d_tile_index(const Expr &e) { if (const auto *r1 = e.as()) { - - const auto stride_var = Variable::make(Int(32), "stride"); - const auto v1 = Variable::make(Int(32), "v1"); - const auto v2 = Variable::make(Int(32), "v2"); - const auto v3 = Variable::make(Int(32), "v3"); - - Expr patterns[] = { - ((v1 * stride_var) + v2) * v3, - v3 * ((v1 * stride_var) + v2), - (v2 + (v1 * stride_var)) * v3, - v3 * (v2 + (v1 * stride_var)), - }; - - std::map matches; - for (const auto &pattern : patterns) { - if (expr_match(pattern, r1->base, matches)) { - auto stride = std::move(matches["stride"]); - // stride must be a constant in order to not be confused with v1 - if (stride.as()) { - return {true, r1->base, {std::move(stride)}, {r1->lanes}}; - } - - // if stride wasn't a constant then v1 could possibly be the stride if constant - auto v1_expr = std::move(matches["v1"]); - if (v1_expr.as()) { - return {true, r1->base, {std::move(v1_expr)}, {r1->lanes}}; - } - } - } + return {true, r1->base, {r1->stride}, {r1->lanes}}; } return {}; @@ -218,7 +190,7 @@ Tile<3> get_3d_tile_index(const Expr &e) { * The pattern which is getting matched looks roughly like * `broadcast(ramp(0, 1, r), x*y) / broadcast(4, x*y*r) + optional(broadcast(base, x*y*r)) * broadcast(8, x*y*r) + * broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + - * broadcast(ramp(broadcast(_, r), broadcast(4, r), x) , y)` + * broadcast(ramp(broadcast(_, r), broadcast(4, r), y) , x)` */ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { const auto *sub = e.as(); @@ -239,38 +211,38 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { // The right hand side of the add expression is used for retrieving the dimensions of the matrix. // obtain the x, y, r dimensions // this expr looks like below, the shape of `add_lhs->a` can be seen further down below - // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), x) , y) + // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) + broadcast(ramp(broadcast(base, r), broadcast(4, r), y) , x) const Add *dim_expr = add_lhs->b.as(); if (!dim_expr) { return {}; } - // broadcast(ramp(broadcast(_, r), broadcast(4, r), x), y) + // broadcast(ramp(broadcast(_, r), broadcast(4, r), y), x) const Broadcast *base_stride_bc = dim_expr->b.as(); if (!base_stride_bc) { return {}; } - int tile_y = base_stride_bc->lanes; + int tile_x = base_stride_bc->lanes; // broadcast(ramp(0, 1, r), x*y) % broadcast(4, x*y*r) - const Mod *mod = dim_expr->a.as(); - - if (!mod) { + std::vector results{}; + const Expr mod_pattern = Mod::make(wild_i32x, Broadcast::make(4 / element_width, 0)); + if (!expr_match(mod_pattern, dim_expr->a, results)) { return {}; } // broadcast(ramp(0, 1, r), x*y) - const Broadcast *bc_ramp = mod->a.as(); + const Broadcast *bc_ramp = results[0].as(); if (!bc_ramp) { return {}; } int tile_xy = bc_ramp->lanes; - int tile_x = tile_xy / tile_y; + int tile_y = tile_xy / tile_x; // ramp(0, 1, r) const Ramp *r_ramp = bc_ramp->value.as(); @@ -282,21 +254,13 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { int tile_r = r_ramp->lanes; // get the base and stride - // ramp(broadcast(_, r), broadcast(4, r), x) - const Ramp *base_stride_ramp = base_stride_bc->value.as(); - - if (!base_stride_ramp) { + // ramp(broadcast(_, r), broadcast(4, r), y) + const Expr base_stride_ramp_pattern = Ramp::make(Broadcast::make(wild_i32, tile_r), Broadcast::make(4 / element_width, tile_r), tile_y); + if (!expr_match(base_stride_ramp_pattern, base_stride_bc->value, results)) { return {}; } - // broadcast(_, r) - const Broadcast *base_bc = base_stride_ramp->base.as(); - - if (!base_bc) { - return {}; - } - - Expr base = base_bc->value; + Expr base = results[0]; Expr stride; bool found_stride = false; @@ -308,7 +272,6 @@ Tile<3> get_3d_rhs_tile_index(const Expr &e, int element_width) { // this stride pattern can occur if `tile_r` is the same size as `acc` auto stride_pattern = Broadcast::make(Ramp::make(0, 1, tile_r), tile_x * tile_y) / Broadcast::make((4 / element_width), tile_x * tile_y * tile_r) * Broadcast::make(wild_i32, tile_x * tile_y * tile_r); - std::vector results{}; if (expr_match(stride_pattern, add_lhs->a, results)) { found_stride = true; stride = std::move(results[0]); @@ -353,19 +316,41 @@ BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x, return {true, rhs_tile3.base, rhs_tile3.stride[0] * element_width}; } else { + // 1D: degenerate as dot product. There are two cases: + // * tile_r is 4, so effectively there is only one row in the loaded tile + // * rhs.stride.1 == 4 && tile_y = 1, where the loaded RHS has shape (K/4)x4 + // and is contiguous in the memory if (rhs_tile1.extent[0] != tile_y * tile_r) { return {}; } + if (!(rhs_tile1.stride[0].as() && rhs_tile1.stride[0].as()->value == 1)) { + return {}; + } + + if (tile_r == 4 / element_width) { + return {true, rhs_tile1.base, 0}; + } - // times 4 because of the rhs layout, each vector used by AMX is 4 bytes in size. - // For the 4 gets divided by the element width which means each vector has 4 elements in u8/i8 and - // 2 elements for bf16. - return {true, rhs_tile1.base, rhs_tile1.stride[0] * (4 / element_width)}; + if (tile_y == 1) { + // 4 elements in u8/i8 and 2 elements for bf16. + return {true, rhs_tile1.base, 4 / element_width}; + } + + return {}; } } else { + // The only case where there is a ramp of ramp is when tile_y = 1 and so RHS has size (K/4)x4 + // (and rhs.stride.1 != 4, o.w. it degenerates to 1D) if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { return {}; } + if (!(rhs_tile2.stride[1].as() && rhs_tile2.stride[1].as()->value == 1)) { + return {}; + } + + if (tile_y != 1) { + return {}; + } return {true, rhs_tile2.base, rhs_tile2.stride[0]}; } diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index f17b3786366a..02fc1c1e60dd 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -1,4 +1,6 @@ #include "Halide.h" + +#include #include using namespace Halide; @@ -89,6 +91,7 @@ void print_mat_rhs(const Buffer &buf, int rows, int cols) { template bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { + Target target("x86-64-linux-avx512_sapphirerapids"); Buffer A_buf(acc, row); Buffer B_buf(4, col, acc / 4); @@ -134,6 +137,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Buffer out(col, row); result.realize(out); + result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A_buf, B_buf}, target); // uncomment to check the matrices // std::cout << "Matrix A\n"; @@ -164,6 +168,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { + Target target("x86-64-linux-avx512_sapphirerapids"); Var x("x"), y("y"); Buffer A(acc, row); Buffer B(2, col, acc / 2); @@ -209,7 +214,7 @@ bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) Buffer out(col, row); // Uncomment to check the asm - // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); result.realize(out); @@ -248,7 +253,15 @@ auto matmul_su = &matmul; auto matmul_uu = &matmul; bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) { - return fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(32, 32, 32, 8, 8, 8 / element_width) && fn(32, 32, 32, 8, 8, 4 / element_width); + return true + && fn(2, 2, 16, 2, 2, 8 / element_width) + && fn(4, 4, 8, 4, 4, 8 / element_width) + && fn(8, 8, 4, 8, 8, 4 / element_width) + && fn(32, 32, 32, 8, 8, 8 / element_width) + && fn(32, 32, 32, 8, 8, 4 / element_width) + && fn(32, 32, 32, 6, 8, 4 / element_width) + && fn(32, 32, 32, 6, 8, 8 / element_width) + ; } int main(int argc, char **argv) { diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 03bd243ef554..36e3ccd78ec7 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -85,6 +85,8 @@ bool matmul(Halide::Target target) { // This means that the rows must always be divisible by 4 (or 2 for bf16). ImageParam B(rhs(8), 3, "rhs"); + B.dim(1).set_stride(4); + RDom r(0, acc); Func mm("matmul"); @@ -141,12 +143,12 @@ bool matmul(Halide::Target target) { // Uncomment to check the asm // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); - // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); - auto time = Tools::benchmark(20, 20, [&]() { - result.realize(out); - }); - std::cout << "Exec time: " << time << "\n"; + // auto time = Tools::benchmark(20, 20, [&]() { + // result.realize(out); + // }); + // std::cout << "Exec time: " << time << "\n"; std::cout << "Success!\n"; return true; } @@ -172,6 +174,8 @@ bool matmul_bf16(Halide::Target target) { ImageParam A(BFloat(16), 2, "lhs"); ImageParam B(BFloat(16), 3, "rhs"); + B.dim(1).set_stride(2); + RDom r(0, acc, "acc"); Func mm("matmul"); From a3b267cf43fbab7c94c129373140c3f533a9e9b9 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Tue, 13 Aug 2024 01:39:19 -0700 Subject: [PATCH 08/10] tweaks --- src/ExtractTileOperations.cpp | 2 +- test/correctness/tiled_matmul.cpp | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ExtractTileOperations.cpp b/src/ExtractTileOperations.cpp index 78c6c09711c1..823b2e9d0fbe 100644 --- a/src/ExtractTileOperations.cpp +++ b/src/ExtractTileOperations.cpp @@ -340,7 +340,7 @@ BaseStride get_rhs_tile_index(const Expr &index, int element_width, int tile_x, } } else { // The only case where there is a ramp of ramp is when tile_y = 1 and so RHS has size (K/4)x4 - // (and rhs.stride.1 != 4, o.w. it degenerates to 1D) + // (and rhs.stride.1 != 4, for o.w. it degenerates to 1D) if (tile_y != rhs_tile2.extent[0] || tile_r != rhs_tile2.extent[1]) { return {}; } diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index 02fc1c1e60dd..b2c5f34becb1 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -254,6 +254,9 @@ auto matmul_uu = &matmul; bool run_tests(bool (*fn)(int, int, int, int, int, int), int element_width) { return true + // TODO: tile_x and tile_y is not supported because they degenerate to a pattern that the matcher for LHS fails to recognize + // && fn(2, 2, 16, 1, 2, 4 / element_width) + // && fn(2, 2, 16, 2, 2, 4 / element_width) && fn(2, 2, 16, 2, 2, 8 / element_width) && fn(4, 4, 8, 4, 4, 8 / element_width) && fn(8, 8, 4, 8, 8, 4 / element_width) From 8ae0f433fca5784f989e00436ec3a4057150aaac Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Tue, 13 Aug 2024 01:44:53 -0700 Subject: [PATCH 09/10] tweak --- test/error/tiled_matmul_wrong_layout.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp index f32da120f94c..6f788c50a269 100644 --- a/test/error/tiled_matmul_wrong_layout.cpp +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -106,5 +106,9 @@ int main(int argc, char **argv) { std::cerr << "Skipping test since target does not support AMX\n"; return 0; } + // Note theoretically we should be able to compile this if tile_y is 8, in which case + // each row of a tile becomes contiguous in memory again. + // However, we cannot do this because the matcher for LHS cannot handle the case + // when tile_x or tile_y is 1. matmul(32, 32, 32, 8, 8, 4, validate); } \ No newline at end of file From a4393df3f04309ef452f7643265fbe3fdb7edb42 Mon Sep 17 00:00:00 2001 From: Yihong Zhang Date: Tue, 13 Aug 2024 01:48:31 -0700 Subject: [PATCH 10/10] tweaks --- test/correctness/tiled_matmul.cpp | 6 ++---- test/error/tiled_matmul_wrong_layout.cpp | 2 +- test/performance/tiled_matmul.cpp | 10 +++++----- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/test/correctness/tiled_matmul.cpp b/test/correctness/tiled_matmul.cpp index b2c5f34becb1..c7b31883d09f 100644 --- a/test/correctness/tiled_matmul.cpp +++ b/test/correctness/tiled_matmul.cpp @@ -91,7 +91,6 @@ void print_mat_rhs(const Buffer &buf, int rows, int cols) { template bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { - Target target("x86-64-linux-avx512_sapphirerapids"); Buffer A_buf(acc, row); Buffer B_buf(4, col, acc / 4); @@ -137,7 +136,7 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { Buffer out(col, row); result.realize(out); - result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A_buf, B_buf}, target); + // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A_buf, B_buf}, target); // uncomment to check the matrices // std::cout << "Matrix A\n"; @@ -168,7 +167,6 @@ bool matmul(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { } bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) { - Target target("x86-64-linux-avx512_sapphirerapids"); Var x("x"), y("y"); Buffer A(acc, row); Buffer B(2, col, acc / 2); @@ -214,7 +212,7 @@ bool matmul_bf16(int row, int col, int acc, int tile_x, int tile_y, int tile_r) Buffer out(col, row); // Uncomment to check the asm - result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); + // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul_bf16.ll", {A, B}, target); // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); result.realize(out); diff --git a/test/error/tiled_matmul_wrong_layout.cpp b/test/error/tiled_matmul_wrong_layout.cpp index 6f788c50a269..b13b1818e58c 100644 --- a/test/error/tiled_matmul_wrong_layout.cpp +++ b/test/error/tiled_matmul_wrong_layout.cpp @@ -106,7 +106,7 @@ int main(int argc, char **argv) { std::cerr << "Skipping test since target does not support AMX\n"; return 0; } - // Note theoretically we should be able to compile this if tile_y is 8, in which case + // Note theoretically we should be able to compile this if tile_x is set to 1, in which case // each row of a tile becomes contiguous in memory again. // However, we cannot do this because the matcher for LHS cannot handle the case // when tile_x or tile_y is 1. diff --git a/test/performance/tiled_matmul.cpp b/test/performance/tiled_matmul.cpp index 36e3ccd78ec7..b8094b3cdc07 100644 --- a/test/performance/tiled_matmul.cpp +++ b/test/performance/tiled_matmul.cpp @@ -143,12 +143,12 @@ bool matmul(Halide::Target target) { // Uncomment to check the asm // result.compile_to_llvm_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.ll", {A, B}, target); - result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); + // result.compile_to_assembly(Internal::get_test_tmp_dir() + "tiled_matmul.s", {A, B}, target); - // auto time = Tools::benchmark(20, 20, [&]() { - // result.realize(out); - // }); - // std::cout << "Exec time: " << time << "\n"; + auto time = Tools::benchmark(20, 20, [&]() { + result.realize(out); + }); + std::cout << "Exec time: " << time << "\n"; std::cout << "Success!\n"; return true; }