Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workgroup level TileShape #84

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,11 +353,12 @@ int main(int argc, const char** argv)
using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;

using TileShape = Shape<_1, _1, _1>;
// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>;
Tile<_32,_64,_32>>; // Subgroup level-tile

using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;

Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/gemm/collective/intel_pvc_mma.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct CollectiveMma<
using DpasShape = typename TiledMma::Shape_MNK;
using TileDpasShape = decltype(tile_shape(TiledMma()));

static constexpr uint32_t MaxThreadsPerBlock = get<0>(DpasShape()) * get<1>(DpasShape());
static constexpr uint32_t MaxThreadsPerBlock = cute::size(TileShape{}) / cute::size(TileDpasShape{}) * SubgroupSize;

static constexpr int FragsM = get<0>(TileDpasShape{}) / get<0>(DpasShape()); // A frags per sub_group
static constexpr int FragsN = get<1>(TileDpasShape{}) / get<1>(DpasShape()); // B frags per sub_group
Expand Down
15 changes: 6 additions & 9 deletions include/cutlass/gemm/kernel/intel_pvc_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,10 @@ class GemmUniversal<

static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size
static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock;
static constexpr uint32_t MinBlocksPerMultiprocessor = CollectiveMainloop::MinBlocksPerMultiprocessor;

static constexpr int num_sg = MaxThreadsPerBlock / SubgroupSize; // number of sub_groups per work group

using DpasShape = typename CollectiveMainloop::DpasShape;
mehdi-goli marked this conversation as resolved.
Show resolved Hide resolved
using TileDpasShape = typename CollectiveMainloop::TileDpasShape;


static constexpr int FragsM = CollectiveMainloop::FragsM;
static constexpr int FragsN = CollectiveMainloop::FragsN;

Expand Down Expand Up @@ -182,9 +178,9 @@ class GemmUniversal<
const int sg_n = (N - 1) / get<1>(TileDpasShape{}) + 1; // sub_groups required to process B fragments

return dim3(
sg_m,
cute::ceil_div(sg_n, num_sg),
batch_count
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
batch_count
);
}

Expand Down Expand Up @@ -218,9 +214,10 @@ class GemmUniversal<

// Get the appropriate blocks for this sub_group -- potential for sub_group locality
int thread_idx = int(ThreadIdxX());
auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
constexpr auto workgroup_shape = TileShape{}; // (SUB_M,SUB_N,SUB_K)
constexpr auto subgroup_shape = TileDpasShape{}; // (SUB_M,SUB_N,SUB_K)
const int m_coord = BlockIdxX() * get<0>(subgroup_shape);
const int n_coord = (BlockIdxY() * num_sg + thread_idx / SubgroupSize) * get<1>(subgroup_shape);
const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape);
const int l_coord = BlockIdxZ();

Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor(make_coord(m_coord, 0, 0),
Expand Down
Loading