Fixing 2d Prefetch #187

Merged
examples/sycl/pvc/flash_attention_v2/pvc_flash_attn.cpp (14 changes: 8 additions, 6 deletions)
@@ -438,10 +438,6 @@ int main(int argc, const char** argv)
using ElementInputKV = bfloat16_t; // <- data type of elements in input matrix B
using ElementOutput = float; // <- data type of elements in output matrix D

using GmemTiledCopyQ = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyK = XE_2D_U16x32x32_LD_V;
using GmemTiledCopyV = XE_2D_U16x32x32_LD_V;

using LayoutQ = cutlass::layout::RowMajor;
using LayoutK = cutlass::layout::RowMajor;
using LayoutV = cutlass::layout::RowMajor;
@@ -476,7 +472,10 @@ int main(int argc, const char** argv)
XE_2D_U32x8x16_ST_N>;

if(options.is_causal) {
// Mainloop
using GmemTiledCopyQ = XE_2D_U16x32x16_LD_N;
using GmemTiledCopyK = XE_2D_U16x16x32_LD_V;
using GmemTiledCopyV = XE_2D_U16x16x32_LD_V;
// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMmaAttention<
GEMMDispatchPolicy,
TileShape,
@@ -503,7 +502,10 @@

runner.run(options, hw_info);
} else {
// Mainloop
using GmemTiledCopyQ = XE_2D_U16x32x32_LD_N;
using GmemTiledCopyK = XE_2D_U16x32x32_LD_V;
using GmemTiledCopyV = XE_2D_U16x32x32_LD_V;
// Mainloop
using CollectiveMainloop = cutlass::gemm::collective::CollectiveMmaAttention<
GEMMDispatchPolicy,
TileShape,
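For readers skimming the fragmented hunks above: the change moves the GmemTiledCopyQ/K/V aliases from file scope into the causal and non-causal branches, so the causal path can use narrower 2D block loads (32x16 for Q, 16x32 for K/V) while the non-causal path keeps the 32x32 loads. A minimal standalone sketch of this branch-local alias pattern is below; Load2D and run_attention are hypothetical stand-ins rather than CUTLASS types, and reading the XE_2D_U16xRxC atom names as R rows by C columns of 16-bit elements is an assumption on my part.

```cpp
// Standalone sketch of the branch-local copy-atom selection used in the diff above.
// Load2D is a hypothetical stand-in for the XE_2D_* copy atoms.
#include <cstdio>

template <int Rows, int Cols> struct Load2D {
  static constexpr int rows = Rows, cols = Cols;
};

template <class CopyQ, class CopyK, class CopyV>
void run_attention() {
  // Stand-in for building CollectiveMmaAttention with the chosen copy atoms.
  std::printf("Q block %dx%d, K/V block %dx%d\n",
              CopyQ::rows, CopyQ::cols, CopyK::rows, CopyK::cols);
}

int main(int argc, char**) {
  bool is_causal = argc > 1;                // stand-in for options.is_causal
  if (is_causal) {
    using GmemTiledCopyQ = Load2D<32, 16>;  // analogous to XE_2D_U16x32x16_LD_N
    using GmemTiledCopyK = Load2D<16, 32>;  // analogous to XE_2D_U16x16x32_LD_V
    using GmemTiledCopyV = Load2D<16, 32>;
    run_attention<GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>();
  } else {
    using GmemTiledCopyQ = Load2D<32, 32>;  // analogous to XE_2D_U16x32x32_LD_N
    using GmemTiledCopyK = Load2D<32, 32>;  // analogous to XE_2D_U16x32x32_LD_V
    using GmemTiledCopyV = Load2D<32, 32>;
    run_attention<GmemTiledCopyQ, GmemTiledCopyK, GmemTiledCopyV>();
  }
  return 0;
}
```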
@@ -296,7 +296,7 @@ class GemmUniversalAttention
//m, k
Tensor prefetch_iter_2d_a = params.mainloop.gmem_prefetch_q.get_pvc_tensor(
make_coord(seq_coord + (((sub_group_id % ATOM_N) / get<1>(PrefetchQThrShape{})) * get<0>(PrefetchQTileSize{})), // iteration 0/M/Height/vertical
((sub_group_id % ATOM_N) % get<1>(PrefetchQThrShape{})) * get<1>(PrefetchQTileSize{}), // Iteration 1/K/Width/Horizontal
0, // Iteration 1/K/Width/Horizontal
blk_l_coord),
make_shape(_1{}, _1{}, _1{}));
Tensor prefetch_iter_a = append_pvc_tensor<1>(prefetch_iter_2d_a, k_tile_count, BLK_K);
@@ -306,23 +306,21 @@
// the iteration over the K dimension of the B matrix (head_size) should be:
auto iter_over_head_count = head_size / BLK_N;
// k, n
Tensor prefetch_iter_2d_b = params.mainloop.gmem_prefetch_k.get_pvc_tensor(
Tensor prefetch_iter_b = params.mainloop.gmem_prefetch_k.get_pvc_tensor(
make_coord(sub_group_id * get<0>(PrefetchKTileSize{}), // iteration 0/K/Height/vertical
(sub_group_id % ATOM_N) * get<1>(PrefetchKTileSize{}), // iteration 1/N/W/Horizontal
0, // iteration 1/N/W/Horizontal
blk_l_coord), // batch
// ?, ?, k, N swap k and n here to match cutlass
make_shape(_1{}, _1{}, nblock_limit/*This is N*/));
// iter_over_head_count/* This is K*/), //(frag, iter_m, iter_n, iter_k)
append<4>(make_shape(_1{}, _1{}, nblock_limit/*This is N*/), iter_over_head_count/* This is K*/), //(frag, iter_m, iter_n, iter_k)
// K, ?, N (N should move along the N dimension, since get<0>(PrefetchKThrShape) loads 32 each and we want 128 of N)
// K should move along the block-load dimension, since we lay out 8x32 using the 8x1 subgroup shape,
// leading to a 64x32 load of (K,N) per prefetch (BLOCK_N shows the K dim)
// append<3>(make_shape(_, SG_N), BLK_N), seq<0, 1, 0>{}); // so 64 * iteration 0 (SG_N, that is K, which is vertical) and 32 * iteration 1 (N, which is horizontal)
append<3>(make_shape(_, SG_N), BLK_N), seq<0, 1, 0>{}); // so 64 * iteration 0 (SG_N, that is K, which is vertical) and 32 * iteration 1 (N, which is horizontal)

// V is a transposed matrix, so here the sequence length is consumed; because it is transposed, the consumed dimension looks like the B matrix.
// Hence, the head size is the fast-moving (horizontal) dimension and the sequence length is vertical.
// The prefetch only moves along the sequence length. Here we call the sequence length K since it gets consumed, and the head size N since it stays.

Tensor prefetch_iter_b = append_pvc_tensor<0>(prefetch_iter_2d_b, iter_over_head_count, BLK_N);

Tensor prefetch_iter_2d_v = params.mainloop.gmem_prefetch_v.get_pvc_tensor(
make_coord((sub_group_id / ATOM_N) * get<0>(PrefetchVTileSize{}), // iteration 0/K/Height/vertical / sequence length
head_size_coord, // iteration 1/N/W/Horizontal / head size
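The other notable change is in the K prefetch tensor: instead of building a rank-3 iteration shape and extending it afterwards with append_pvc_tensor<0>, the iteration over the head dimension (iter_over_head_count = head_size / BLK_N) is folded directly into the shape with append<4>(...). A minimal CUTE-only sketch of that shape manipulation, using illustrative values and assuming only the cute headers (not the PVC-specific prefetch code), might look like this:

```cpp
// Sketch: extending the rank-3 prefetch iteration shape with a fourth mode for
// the head-dimension iteration count, as in the append<4>(...) call above.
// Values are illustrative, not the kernel's real tile sizes.
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  int nblock_limit         = 8;   // stand-in for the sequence-length block count (N iterations)
  int iter_over_head_count = 4;   // stand-in for head_size / BLK_N (K iterations)

  auto shape3 = make_shape(_1{}, _1{}, nblock_limit);     // (frag, iter_m, iter_n)
  auto shape4 = append<4>(shape3, iter_over_head_count);  // (frag, iter_m, iter_n, iter_k)

  static_assert(decltype(rank(shape4))::value == 4, "prefetch iteration shape is rank-4");
  print(shape4); print("\n");                             // prints (_1,_1,8,4)
  return 0;
}
```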