diff --git a/include/cute/atom/copy_traits_xe.hpp b/include/cute/atom/copy_traits_xe.hpp index b60ef8811c..63eae73dd4 100644 --- a/include/cute/atom/copy_traits_xe.hpp +++ b/include/cute/atom/copy_traits_xe.hpp @@ -37,12 +37,42 @@ namespace cute { + +template +CUTE_HOST_DEVICE constexpr +auto get_shape_WHD(cute::Stride, IntT, IntT> , cute::Shape shape_MKL) { + return shape_MKL; +} + +template +CUTE_HOST_DEVICE constexpr +auto get_shape_WHD(cute::Stride, IntT> , cute::Shape shape_MKL) { + return Shape(get<1>(shape_MKL), get<0>(shape_MKL), get<2>(shape_MKL)); +} + +template +CUTE_HOST_DEVICE constexpr +auto get_coordinates(cute::Stride, IntT, IntT> , + Tensor>, SLayout> const &src) { + auto [x, y, z] = src.data().coord_; + return make_coord(x, y, z); +} + +template +CUTE_HOST_DEVICE constexpr +auto get_coordinates(cute::Stride, IntT> , + Tensor>, SLayout> const &src) { + auto [x, y, z] = src.data().coord_; + return make_coord(y, x, z); +} + template struct XE_2D_LD_Unpack { GTensor tensor; using Copy_Traits = Copy_Traits; + template CUTE_HOST_DEVICE friend constexpr void @@ -50,11 +80,12 @@ struct XE_2D_LD_Unpack Tensor>, SLayout> const &src, Tensor &dst) { - static_assert(is_rmem::value); - int H = size<0>(traits.tensor); - int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType); - auto [y, x, z] = src.data().coord_; - CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data()); + static_assert(is_rmem::value); + auto shape_whd = get_shape_WHD(traits.tensor.stride(), traits.tensor.shape()); + int W = size<0>(shape_whd) * sizeof(typename Copy_Traits::CopyInternalType); + int H = size<1>(shape_whd); + auto [x, y, z] = get_coordinates(traits.tensor.stride(), src); + CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data()); } template @@ -105,15 +136,13 @@ struct Copy_Traits : XE_2D_LD_Unpack { // Logical thread id to thread idx - using ThrID = Layout<_16>; + using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; + using SrcLayout = Layout>>; // one coordinate // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout, Shape<_16, _2>>>, - Stride<_16, Stride, Stride<_1, _256>>>>; + using DstLayout = Layout>>; // Reference map from (thr,val) to bit - using RefLayout = DstLayout; + using RefLayout = SrcLayout; using CopyInternalType = ushort; }; @@ -188,14 +217,13 @@ struct Copy_Traits : XE_2D_LD_Unpack { // Logical thread id to thread idx - using ThrID = Layout<_16>; + using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout, Stride<_0, _1>>; + using SrcLayout = Layout>>; // expected 4 coordinates // Map from (dst-thr,dst-val) to bit - using DstLayout = - Layout>, Stride<_32, Stride<_512, _1>>>; + using DstLayout = Layout>>; // Reference map from (thr,val) to bit - using RefLayout = DstLayout; + using RefLayout = SrcLayout; // 32 bits register file using CopyInternalType = uint; }; diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index be4eb1551c..5183363e90 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -129,7 +129,9 @@ bool block(int bid) { #if defined(CUTLASS_ENABLE_SYCL) - return (syclcompat::get_nd_item<3>().get_group_linear_id()==bid); + using namespace syclcompat; + return (work_group_id::x() + work_group_id::y() * work_group_range::x() + + work_group_id::z() * work_group_range::y() * work_group_range::x() == bid); #elif defined(__CUDA_ARCH__) return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == bid; #else @@ -142,7 +144,9 @@ bool thread(int tid, int bid) { #if defined(CUTLASS_ENABLE_SYCL) - return (syclcompat::get_nd_item<3>().get_global_linear_id()==bid); + using namespace syclcompat; + return (local_id::x() + local_id::y() * local_range::x() + + local_id::z() * local_range::x() * local_range::y() == tid) && block(bid); #elif defined(__CUDA_ARCH__) return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == tid) && block(bid); #else diff --git a/include/cutlass/gemm/collective/intel_pvc_mma.hpp b/include/cutlass/gemm/collective/intel_pvc_mma.hpp index c552ee8616..f69ae7bdf0 100644 --- a/include/cutlass/gemm/collective/intel_pvc_mma.hpp +++ b/include/cutlass/gemm/collective/intel_pvc_mma.hpp @@ -149,7 +149,7 @@ struct CollectiveMma< auto [M,N,K,L] = problem_shape_MNKL; Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA)); - Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K,N,L), args.dB)); + Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(N,K,L), args.dB)); typename Params::XE_Copy_A copyA = make_xe_2d_copy(tensorA); typename Params::XE_Copy_B copyB = make_xe_2d_copy(tensorB); @@ -187,14 +187,14 @@ struct CollectiveMma< static_assert(is_rmem::value, "C tensor must be rmem resident."); // Tensor to hold input data - Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, Int<1>>{}); - Tensor tBr = make_tensor( - Shape(SubgroupTileShape{}) / FragsN>, Int>{}); + Tensor tAr = make_tensor(Shape(SubgroupTileShape{}) * FragsK>, _1>{}); + Tensor tBr = make_tensor(Shape(SubgroupTileShape{}) / 2>, Int>{}); Tensor tAr_view = make_tensor(static_cast(tAr).data(), Shape, Int, Int>{}); Tensor tBr_view = make_tensor(static_cast(tBr).data(), - Shape, Int, Int>{}); + Shape, Int, Int>{}, + Stride<_1, Int(SubgroupTileShape{}) / 2>, Int>{}); // Instantiate the M MA object TiledMma tiled_mma; @@ -206,11 +206,9 @@ struct CollectiveMma< { // Copy gmem to rmem for the first k_tile copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr); - copy(mainloop.gmem_tiled_copy_b, gB(_,k/2,_), tBr); + copy(mainloop.gmem_tiled_copy_b, gB(_,_,k/2), tBr); - for (int kl = 0; kl < FragsK; kl++) { - cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), src_accum); - } + cute::gemm(tiled_mma, accum, tAr_view, tBr_view, src_accum); } } }; diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 5c9b6d019e..5e8fce8b4e 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -221,9 +221,9 @@ class GemmUniversal< make_stride(Int{} * get<0>(MmaAtomShape()),_1{})); Tensor tBi = params.mainloop.gmem_tiled_copy_b.get_pvc_tensor( - make_coord(0, n_coord, 0), - make_shape(K, Int{}, L), - make_stride(_1{}, get<1>(MmaAtomShape()))); + make_coord(n_coord, 0, 0), + make_shape(Int{}, K / 2, L), + make_stride(get<1>(MmaAtomShape()), _1{})); // Compute tile residues for predication auto m_max_coord = M - get<0>(subgroup_shape) * m_coord; // M - SUB_M * m_coord