
Commit e9bcb40

Update to Cutlass 3.5.1
aacostadiaz committed Aug 6, 2024
1 parent 2aae80e commit e9bcb40
Showing 15 changed files with 31 additions and 25 deletions.
20 changes: 10 additions & 10 deletions examples/cute/tutorial/sgemm_sm80_sycl.cpp
@@ -150,12 +150,12 @@ gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
// Allocate the accumulators -- same size as the projected data
Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N)

- CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K)
- CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K)
- CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N)
- CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M
- CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N
- CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K
+ CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K)
+ CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K)
+ CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N)
+ CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M
+ CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N
+ CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K

// Clear the accumulators
clear(tCrC);
@@ -390,10 +390,10 @@ gemm_tn(int m, int n, int k,
auto bP = Int<3>{}; // Pipeline

// Define the smem layouts (static)
- auto sA_atom = make_layout(make_shape ( bM, bK),
-                            make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
- auto sB_atom = make_layout(make_shape ( bN, bK),
-                            make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
+ auto sA_atom = make_layout(make_shape ( bM, bK),
+                            make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
+ [[maybe_unused]] auto sB_atom = make_layout(make_shape ( bN, bK),
+                                             make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP));
auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx
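
Note on the assert changes above: the pipelined smem tensors in this kernel carry a trailing pipeline mode (bP), so the smem partitions tCsA/tCsB have shape (MMA,MMA_M,MMA_K,PIPE) while the register fragments are (MMA,MMA_M,MMA_K); take<0,3> keeps only the first three modes before comparing. A minimal sketch of the idea, using hypothetical static sizes that are not taken from the kernel:

    #include <cute/tensor.hpp>
    using namespace cute;

    // (MMA,MMA_M,MMA_K,PIPE): shape a pipelined smem partition might report
    auto smem_shape = make_shape(Int<4>{}, Int<2>{}, Int<8>{}, Int<3>{});
    // (MMA,MMA_M,MMA_K): shape of the matching register fragment
    auto frag_shape = make_shape(Int<4>{}, Int<2>{}, Int<8>{});
    // take<0,3> keeps modes [0,3), dropping the pipeline mode, so the shapes agree
    CUTE_STATIC_ASSERT_V((frag_shape == take<0,3>(smem_shape)));
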
2 changes: 2 additions & 0 deletions include/cutlass/bfloat16.h
@@ -103,6 +103,7 @@ struct alignas(2) bfloat16_t {
/// Default constructor
bfloat16_t() = default;

+ #if !defined(CUTLASS_ENABLE_SYCL)
/// Reinterpret cast from CUDA's __nv_bfloat16 type
CUTLASS_HOST_DEVICE
explicit bfloat16_t(__nv_bfloat16 const & x) {
@@ -113,6 +114,7 @@ struct alignas(2) bfloat16_t {
std::memcpy(&storage, &raw.x, sizeof(storage));
#endif
}
+ #endif

/// Floating-point conversion - round toward nearest
CUTLASS_HOST_DEVICE
2 changes: 1 addition & 1 deletion include/cutlass/complex.h
@@ -777,7 +777,7 @@ struct conjugate<complex<T>> {
}
};

- #if ! defined(__CUDACC_RTC__)
+ #if ! defined(__CUDACC_RTC__) && !defined(CUTLASS_ENABLE_SYCL)
template <>
struct conjugate<cuFloatComplex> {
CUTLASS_HOST_DEVICE
10 changes: 6 additions & 4 deletions include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp
@@ -285,13 +285,15 @@ class CollectiveEpilogue<
problem_shape_mnkl,
TileShapeMNK{},
tile_coord_mnkl,
- residue_mn,
- SubgroupTileShape{},
tiled_mma,
+ SubgroupTileShape{}, // Epilogue tile
params.xe_load_c,
- thread_idx,
+ cD,
+ residue_mn,
cD,
- trC
+ residue_mn,
+ trC,
+ thread_idx,
};
auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks<RefSrc>(cst_args);

@@ -495,7 +495,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -455,7 +455,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
@@ -373,7 +373,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -541,7 +541,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
@@ -846,7 +846,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto mma_thread_slice = tiled_mma.get_thread_slice(thread_idx);
2 changes: 1 addition & 1 deletion include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp
@@ -411,7 +411,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -446,7 +446,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
@@ -446,7 +446,7 @@ struct CollectiveMma<
Layout warp_group_thread_layout = make_layout(Int<MmaWarpGroups>{},
Int<NumThreadsPerWarpGroup>{});

- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);
+ int warp_group_idx = shfl_sync(0xFFFFFFFF, thread_idx / NumThreadsPerWarpGroup, 0);

TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx));
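
The __shfl_sync -> shfl_sync rename in the hunks above repeats across the CollectiveMma mainloops: the raw CUDA warp intrinsic is replaced by a free function, presumably so a SYCL build can supply its own definition. That shim is not part of this diff; the sketch below only illustrates the shape such a wrapper could take, and both the name and the SYCL strategy are assumptions:

    // Sketch of a portable broadcast shim (assumption; not code from this commit).
    template <class T>
    CUTLASS_DEVICE
    T shfl_sync(unsigned mask, T value, int src_lane) {
    #if defined(__CUDA_ARCH__)
      return __shfl_sync(mask, value, src_lane);   // CUDA warp shuffle/broadcast
    #else
      // A SYCL build would broadcast `value` from `src_lane` across the sub-group
      // (e.g. with sycl::group_broadcast); elided here to keep the sketch neutral.
      return value;                                // placeholder only
    #endif
    }

At these call sites the shuffled value (thread_idx / NumThreadsPerWarpGroup) is already uniform within a warp, so the broadcast mainly lets the compiler treat the index as warp-uniform rather than changing it.
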
2 changes: 2 additions & 0 deletions include/cutlass/gemm/device/gemm_universal_adapter.h
@@ -356,6 +356,7 @@ class GemmUniversalAdapter<
Status launch_result{ Status::kSuccess };
// Use extended launch API only for mainloops that use it
if constexpr (GemmKernel::ArchTag::kMinComputeCapability >= 90) {
+ #if !defined(CUTLASS_ENABLE_SYCL)
constexpr bool is_static_1x1x1 = cute::is_static_v<typename GemmKernel::DispatchPolicy::ClusterShape> and
cute::size(typename GemmKernel::DispatchPolicy::ClusterShape{}) == 1;
dim3 cluster(cute::size<0>(typename GemmKernel::DispatchPolicy::ClusterShape{}),
@@ -400,6 +401,7 @@
}
}
}
+ #endif
}
else {
launch_result = Status::kSuccess;
4 changes: 2 additions & 2 deletions include/cutlass/pipeline/sm90_pipeline.hpp
@@ -1147,7 +1147,7 @@ pipeline_init_wait(int cluster_size) {
cute::cluster_wait();
}
else {
- __syncthreads();
+ syncthreads();
}
}

@@ -1160,7 +1160,7 @@ pipeline_init_arrive_relaxed(int cluster_size) {
cute::cluster_arrive_relaxed();
}
else {
- __syncthreads();
+ syncthreads();
}
}

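In the same spirit, the pipeline-init helpers above now call syncthreads() instead of the CUDA intrinsic __syncthreads(), presumably so the block-wide barrier can be redirected on non-CUDA backends. A comparable sketch (again an assumption, not code from this commit):

    // Sketch of a portable thread-block barrier shim (assumption).
    CUTLASS_DEVICE
    void syncthreads() {
    #if defined(__CUDA_ARCH__)
      __syncthreads();   // CUDA thread-block barrier
    #else
      // A SYCL build would issue a work-group barrier here (e.g. sycl::group_barrier).
    #endif
    }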
