From e1bd2b8fb1beab9c6057ef9b9307b7a9bc9389da Mon Sep 17 00:00:00 2001 From: Alejandro Acosta Date: Wed, 12 Jun 2024 11:57:35 +0100 Subject: [PATCH] Add subgroup tile information to tiledMma. (#82) --- .../sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp | 7 +++-- include/cute/arch/mma_xe.hpp | 9 ++---- include/cute/atom/mma_traits_xe.hpp | 2 +- .../cutlass/gemm/collective/intel_pvc_mma.hpp | 28 +++++++++---------- .../cutlass/gemm/kernel/intel_pvc_gemm.hpp | 17 +++++------ 5 files changed, 29 insertions(+), 34 deletions(-) diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index e248d3341..542bafb34 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -353,10 +353,11 @@ int main(int argc, const char** argv) using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N; using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N; - using TileShape = Shape<_32, _64, _32>; + using TileShape = Shape<_1, _1, _1>; - using TiledMma = TiledMMA, - Layout>>; + using TiledMma = TiledMMA, + Layout>, + Tile<_32,_64,_32>>; using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; diff --git a/include/cute/arch/mma_xe.hpp b/include/cute/arch/mma_xe.hpp index 878c587fd..3d1bfb8f6 100644 --- a/include/cute/arch/mma_xe.hpp +++ b/include/cute/arch/mma_xe.hpp @@ -45,11 +45,7 @@ SYCL_DEVICE_OCL(float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, cute::i #undef SYCL_DEVICE_OCL namespace cute { -//MxNxK_A,B,C,D -//# of vector component of a x subgroup-size x function name -//float8 intel_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc); -//TODO: Is A really not transposed? Maybe better a macro than separate define for 1,2,4,8 -struct XE_8x16x16_BF16BF16F32F32_NN +struct XE_8x16x16_F32BF16BF16F32_TN { using DRegisters = intel::float8[1]; using ARegisters = intel::short8[1]; @@ -69,8 +65,7 @@ struct XE_8x16x16_BF16BF16F32F32_NN #endif } }; -//float intel_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) -struct XE_1x16x16_BF16BF16F32F32_NN +struct XE_1x16x16_F32BF16BF16F32_TN { using DRegisters = float[1]; using ARegisters = short[1]; diff --git a/include/cute/atom/mma_traits_xe.hpp b/include/cute/atom/mma_traits_xe.hpp index a5ef6dbec..1cbefc872 100644 --- a/include/cute/atom/mma_traits_xe.hpp +++ b/include/cute/atom/mma_traits_xe.hpp @@ -38,7 +38,7 @@ namespace cute { template <> -struct MMA_Traits +struct MMA_Traits { using ValTypeD = float; using ValTypeA = bfloat16_t; diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index 83d46afa6..d587fbcd9 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -99,24 +99,22 @@ struct CollectiveMma< using ArchTag = typename DispatchPolicy::ArchTag; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; - - static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A - static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B - static constexpr int DpasK = get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per sub_group for Matrix A - static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN; - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + using DpasShape = typename TiledMma::Shape_MNK; + using TileDpasShape = decltype(tile_shape(TiledMma())); - static constexpr int FragsM = get<0>(TileShape{}) / DpasM; // A frags per sub_group - static constexpr int FragsN = get<1>(TileShape{}) / DpasN; // B frags per sub_group - static constexpr int FragsK = get<2>(TileShape{}) / DpasK; + static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape()); + + static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group + static constexpr int FragsK = get<2>(TileDpasShape{}) / get<2>(DpasShape()); // Calculate the vector width based on the amount of registers // required per work item by dividing the total fragment size by // the sub_group size. - static constexpr int VecC = (DpasN * DpasM) / SubgroupSize; - static constexpr int VecA = (DpasM * DpasK) / SubgroupSize; - static constexpr int VecB = (DpasN * DpasK) / SubgroupSize; + static constexpr int VecC = (get<1>(DpasShape()) * get<0>(DpasShape())) / SubgroupSize; + static constexpr int VecA = (get<0>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize; + static constexpr int VecB = (get<1>(DpasShape()) * get<2>(DpasShape())) / SubgroupSize; // Host side kernel arguments struct Arguments { @@ -188,8 +186,8 @@ struct CollectiveMma< static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor(Shape(TileShape{}) * FragsK>, Int<1>>{}); - Tensor tBr = make_tensor(Shape(TileShape{}) / FragsN>, Int>{}); + Tensor tAr = make_tensor(Shape(TileDpasShape{}) * FragsK>, Int<1>>{}); + Tensor tBr = make_tensor(Shape(TileDpasShape{}) / FragsN>, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); @@ -202,7 +200,7 @@ struct CollectiveMma< // // Mainloop // - for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += DpasK * FragsK) + for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += get<2>(DpasShape()) * FragsK) { // Copy gmem to rmem for the first k_tile copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr); diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 536f537db..1a9185437 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -105,8 +105,9 @@ class GemmUniversal< static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group - static constexpr int DpasM = CollectiveMainloop::DpasM; - static constexpr int DpasN = CollectiveMainloop::DpasN; + using DpasShape = typename CollectiveMainloop::DpasShape; + using TileDpasShape = typename CollectiveMainloop::TileDpasShape; + static constexpr int FragsM = CollectiveMainloop::FragsM; static constexpr int FragsN = CollectiveMainloop::FragsN; @@ -177,8 +178,8 @@ class GemmUniversal< auto M = get<0>(params.problem_shape); auto N = get<1>(params.problem_shape); - const int sg_m = (M - 1) / get<0>(TileShape{}) + 1; // sub_groups required to process A fragments - const int sg_n = (N - 1) / get<1>(TileShape{}) + 1; // sub_groups required to process B fragments + const int sg_m = (M - 1) / get<0>(TileDpasShape{}) + 1; // sub_groups required to process A fragments + const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments return dim3( sg_m, @@ -217,18 +218,18 @@ class GemmUniversal< // Get the appropriate blocks for this sub_group -- potential for sub_group locality int thread_idx = int(ThreadIdxX()); - auto subgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K) + auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K) const int m_coord = BlockIdxX() * get<0>(subgroup_shape); const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0), make_shape(_1{}, K, L), - make_stride(Int{}, _1{})); + make_stride(Int{} * get<0>(DpasShape()), _1{})); Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor(make_coord(0, n_coord, 0), make_shape(K, Int{}, L), - make_stride(_1{}, Int{})); + make_stride(_1{}, get<1>(DpasShape()))); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord @@ -262,7 +263,7 @@ class GemmUniversal< Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0), make_shape(Int{}, Int{}, L), - make_stride(Int{}, Int{})); + make_stride(get<0>(DpasShape()), get<1>(DpasShape()))); copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord)); }