diff --git a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp index 731edfa15..204214da1 100644 --- a/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp +++ b/examples/sycl/pvc/pvc_bfloat_dpas_gemm_cute.cpp @@ -31,6 +31,8 @@ #include "cutlass/gemm/device/gemm.h" #include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/intel_pvc_epilogue.hpp" +#include "cutlass/epilogue/fusion/intel_pvc_callbacks.hpp" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/collective/collective_mma.hpp" @@ -49,7 +51,7 @@ template static void fill_matrix(std::vector &vector) { std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast( (rand() / double(RAND_MAX)) ); + return static_cast( (rand() / double(RAND_MAX)) ); }); } @@ -360,26 +362,30 @@ int main(int argc, const char** argv) Layout>, Tile<_32,_64,_32>>; // Subgroup level-tile - using DispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; - using EpilogueOp = cutlass::epilogue::thread::LinearCombination< - ElementOutput, // <- data type of output matrix - 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized - // memory access. For a byte, it's 16 - // elements. This becomes the vector width of - // math instructions in the epilogue too - ElementAccumulator, // <- data type of accumulator - ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + using EpilogueOp = cutlass::epilogue::fusion::LinearCombination; - using CollectiveEpilogue = cutlass::epilogue::collective::DefaultEpilogue< + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, + TileShape, + ElementAccumulator, cutlass::gemm::TagToStrideC_t, + ElementOutput, cutlass::gemm::TagToStrideC_t, - EpilogueOp, - cutlass::gemm::EpilogueDefault>; + FusionCallBacks, + XE_2D_U32x8x16x1x1_LD_N, + void, void, + XE_2D_U32x8x16x1x1_ST_N, + void, void>; // Mainloop using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< - DispatchPolicy, + GEMMDispatchPolicy, TileShape, ElementInputA, cutlass::gemm::TagToStrideA_t, diff --git a/include/cute/arch/copy_xe.hpp b/include/cute/arch/copy_xe.hpp index 2646dcae1..3bfc5c853 100644 --- a/include/cute/arch/copy_xe.hpp +++ b/include/cute/arch/copy_xe.hpp @@ -287,7 +287,7 @@ struct XE_2D_U16x16x16x1x1_V struct XE_2D_U32x8x16x1x1_ST_N { template - CUTE_HOST_DEVICE static void copy(void *baseoffset, int width, int height, + CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width, int height, int pitch, intel::coord_t coord, const T *src) { #if defined(SYCL_INTEL_TARGET) static_assert(sizeof(T) == 4, "Expected T to have size 4"); diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index d61f59f72..00ddd37d6 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -56,7 +56,11 @@ class CollectiveEpilogue { #include "default_epilogue.hpp" #include "default_epilogue_array.hpp" #include "epilogue_tensor_broadcast.hpp" +#if defined (SYCL_INTEL_TARGET) +#include "intel_pvc_epilogue.hpp" +#else #include "sm70_epilogue_vectorized.hpp" #include "sm90_epilogue_tma_warpspecialized.hpp" #include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp new file mode 100644 index 000000000..4d3330865 --- /dev/null +++ b/include/cutlass/epilogue/collective/intel_pvc_epilogue.hpp @@ -0,0 +1,342 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/detail/layout.hpp" + + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class CtaTileMNK_, + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2R_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpR2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_ +> +class CollectiveEpilogue< + IntelPVCEpilogue, + CtaTileMNK_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2R_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpR2G_, + SmemLayoutAtomD_, + CopyOpR2S_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = IntelPVCEpilogue; + using CtaTileMNK = CtaTileMNK_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using ElementAccumulator = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2R = CopyOpG2R_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpR2G = CopyOpR2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + + using ThreadEpilogueOp = typename fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2R; + using GmemTiledCopyD = CopyOpR2G; + using ElementOutput = typename FusionCallbacks::ElementOutput; + using ElementCompute = typename FusionCallbacks::ElementCompute; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + static_assert(std::is_same_v, "Copy operation to shared memory is not supported"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + +public: + + using EmptyType = cute::tuple<>; + using SmemCStorage = EmptyType; + using SmemDStorage = EmptyType; + + struct TensorStorageImpl: cute::tuple { + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + }; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using XE_Copy_C = decltype(make_xe_2d_copy( + make_tensor(static_cast(nullptr), + repeat_like(StrideC{}, int32_t(0)), StrideC{}))); + using XE_Copy_D = decltype(make_xe_2d_copy( + make_tensor(static_cast(nullptr), + repeat_like(StrideD{}, int32_t(0)), StrideD{}))); + + typename FusionCallbacks::Params thread{}; + XE_Copy_C xe_load_c; + XE_Copy_D xe_store_d; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + typename Params::XE_Copy_C xe_load_c = {}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); + xe_load_c = make_xe_2d_copy(tensor_c); + } + + typename Params::XE_Copy_D xe_store_d = {}; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); + xe_store_d = make_xe_2d_copy(tensor_d); + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + xe_load_c, + xe_store_d + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage const& shared_storage_) + : params(params_), fusion_callbacks(params_.thread, shared_storage_.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class Accumulator, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator() ( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + Accumulator accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem) { + + (void) tiled_mma; + (void) residue_mnk; + (void) smem; + using namespace cute; + + using MmaAtomShape = typename TiledMma::AtomShape_MNK; + using SubgroupTileShape = decltype(tile_shape(TiledMma())); + + static constexpr int FragsM = get<0>(SubgroupTileShape{}) / get<0>(MmaAtomShape()); // A frags per sub_group + static constexpr int FragsN = get<1>(SubgroupTileShape{}) / get<1>(MmaAtomShape()); // B frags per sub_group + + static constexpr int FragmentSize = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + Tensor trC = make_tensor(Shape>{}); + Tensor trD = make_tensor(Shape>{}); + Tensor tOuti = params.xe_store_d.get_pvc_tensor( + make_coord(m_coord, n_coord, 0), + make_shape(Int{}, Int{}, L), + make_stride(Int(MmaAtomShape{})>{}, Int(MmaAtomShape{})>{})); + + Tensor rw_coord = tOuti(_,_,_,l_coord); + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); + Tensor cD = local_tile(mD_crd, take<0,2>(SubgroupTileShape{}), make_coord(m_coord, n_coord)); + // Get the fusion callbacks + constexpr bool RefSrc = true; + auto residue_mn = make_coord(M, N); + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + TileShapeMNK{}, + tile_coord_mnkl, + residue_mn, + SubgroupTileShape{}, + params.xe_load_c, + thread_idx, + cD, + cD, + trC + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + + cst_callbacks.begin(); + + auto acc_frag = recast>(accumulators); + auto trD_frag = recast>(trD); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < FragsN; epi_n++) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < FragsM; epi_m++) { + + if (is_C_load_needed) { + copy(params.xe_load_c, rw_coord(_, epi_m, epi_n), trC); + } + + cst_callbacks.previsit(epi_m, epi_n, 0, is_C_load_needed); + + auto acc_frag_mn = acc_frag(_, epi_m, epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) { + trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); + } + + copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n)); + } + } + + cst_callbacks.end(); + + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index 409ff74dd..e49f94c02 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -156,6 +156,12 @@ struct Sm90TmaWarpSpecializedBiasElementwise { constexpr static int FragmentSize = FragmentSize_; }; +#if defined (SYCL_INTEL_TARGET) +struct IntelPVCEpilogue { + static constexpr int SubgroupSize = 16; +}; +#endif + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp b/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp new file mode 100644 index 000000000..c0e662b77 --- /dev/null +++ b/include/cutlass/epilogue/fusion/intel_pvc_callbacks.hpp @@ -0,0 +1,106 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the Intel PVC epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementOutput_, + class ElementCompute_, + class ElementSource_, + class ElementScalar_, + FloatRoundStyle RoundStyle_, + class CtaTileShapeMNK_, + class EpilogueTile_ +> +struct FusionCallbacks< + epilogue::IntelPVCEpilogue, + fusion::LinearCombination, + CtaTileShapeMNK_, + EpilogueTile_ +> : Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_> { + + using Impl = Sm90LinearCombination::type, ElementCompute_, ElementSource_, ElementScalar_, RoundStyle_>; + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 6c729c10d..7d41952ea 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -910,7 +910,7 @@ using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = Sm90EVT, // activation(Z) Sm90EVT, // Aux = Z // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias - Sm90ScaledLinCombPerRowBias, + Sm90ScaledLinCombPerRowBias > > >, diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index e3160fa13..d0a7b2e90 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -526,7 +526,7 @@ struct Sm90TreeVisitor< if (lane_idx == i) { copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux); } - __syncwarp(); + syncwarp(); } } } diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index c8d941b62..51f619fbd 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -781,7 +781,7 @@ struct Sm90RowReduction { // if constexpr (not IsAtomic && FinalReduction) { // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); + threadfence(); sync_fn(); // Collective thread 0 increments atomic tile counter and copies value to smem int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); @@ -1255,7 +1255,7 @@ struct Sm90ColReduction { // if constexpr (not IsAtomic && FinalReduction) { // Ensure gmem writes are visible to other threads before incrementing counter - __threadfence(); + threadfence(); sync_fn(); // Collective thread 0 increments atomic tile counter and copies value to smem int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); diff --git a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp index 5e8fce8b4..6e7aee895 100644 --- a/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp +++ b/include/cutlass/gemm/kernel/intel_pvc_gemm.hpp @@ -111,6 +111,12 @@ class GemmUniversal< static constexpr int VecC = CollectiveMainloop::VecC; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + // Device side arguments struct Arguments { GemmUniversalMode mode{}; @@ -188,7 +194,7 @@ class GemmUniversal< void operator()(Params const& params, char* smem_buf) { - (void)smem_buf; + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); // Preconditions CUTE_STATIC_ASSERT(is_static::value); @@ -214,6 +220,7 @@ class GemmUniversal< const int m_coord = BlockIdxX() * get<0>(subgroup_shape); const int n_coord = BlockIdxY() * get<1>(workgroup_shape) + thread_idx / SubgroupSize * get<1>(subgroup_shape); const int l_coord = BlockIdxZ(); + const auto tile_coord = make_coord(m_coord, n_coord, _, l_coord); Tensor tAi = params.mainloop.gmem_tiled_copy_a.get_pvc_tensor( make_coord(m_coord, 0, 0), @@ -253,13 +260,18 @@ class GemmUniversal< smem_buf, params.mainloop ); - auto gmem_tiled_copy_c = make_xe_2d_copy(make_tensor(params.epilogue.ptr_D, make_shape(M, N, L), params.epilogue.dD)); - - Tensor tCi = gmem_tiled_copy_c.get_pvc_tensor(make_coord(m_coord, n_coord, 0), - make_shape(Int{}, Int{}, L), - make_stride(get<0>(MmaAtomShape()), get<1>(MmaAtomShape()))); - copy(gmem_tiled_copy_c, accumulators, tCi(_,_,_,l_coord)); + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + epilogue( + problem_shape_MNKL, + subgroup_shape, + tile_coord, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); } };