From f40b19a2baeb7cb11f723fcdc1b2277876dcb528 Mon Sep 17 00:00:00 2001 From: Finlay Marno Date: Wed, 30 Oct 2024 13:49:43 +0000 Subject: [PATCH 1/9] Create a LinCombDeEltAct example and impl with XeAuxLoad --- examples/sycl/pvc/CMakeLists.txt | 5 + ...pvc_gemm_with_epilogue_lincombdeeltact.cpp | 452 ++++++++++++++++++ .../collective/builders/xe_builder.inl | 96 +++- .../epilogue/collective/xe_epilogue.hpp | 12 +- .../cutlass/epilogue/fusion/xe_callbacks.hpp | 105 ++++ .../cutlass/epilogue/fusion/xe_visitor.hpp | 200 ++++++++ ...gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp | 60 +++ .../util/reference/device/tensor_compare.h | 66 +++ 8 files changed, 981 insertions(+), 15 deletions(-) create mode 100644 examples/sycl/pvc/pvc_gemm_with_epilogue_lincombdeeltact.cpp create mode 100644 include/cutlass/epilogue/fusion/xe_visitor.hpp diff --git a/examples/sycl/pvc/CMakeLists.txt b/examples/sycl/pvc/CMakeLists.txt index 5736847e88..abc138172d 100644 --- a/examples/sycl/pvc/CMakeLists.txt +++ b/examples/sycl/pvc/CMakeLists.txt @@ -42,6 +42,11 @@ cutlass_example_add_executable( pvc_gemm_with_epilogue_gelu.cpp ) +cutlass_example_add_executable( + pvc_gemm_with_epilogue_lincombdeeltact + pvc_gemm_with_epilogue_lincombdeeltact.cpp +) + cutlass_example_add_executable( pvc_collective_builder pvc_collective_builder.cpp diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_lincombdeeltact.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_lincombdeeltact.cpp new file mode 100644 index 0000000000..585e24568b --- /dev/null +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_lincombdeeltact.cpp @@ -0,0 +1,452 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/xe_epilogue.hpp" +#include "cutlass/epilogue/fusion/xe_callbacks.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_relu.h" +#include "cutlass/tensor_view.h" +#include "cutlass/coord.h" + +#include "common.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), iterations(100), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + // validate the arguments + bool m_valid = m > 0 and m % 8 == 0; + bool n_valid = n > 0 and n % 16 == 0; + bool k_valid = k > 0 and k % 16 == 0; + bool l_valid = l > 0; + if (!(m_valid and n_valid and k_valid and l_valid)) { + std::cout << "invalid arguments. Must be a multiple of (8, 16, 16)\n"; + std::exit(1); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Example\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Iterations\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sum_vals { + CUTLASS_HOST_DEVICE + T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct diff_vals { + CUTLASS_HOST_DEVICE + T operator()(T a, T b) const { + return a - b; + } +}; + +template +void print_device_tensor(DeviceAllocation device_alloc, std::size_t M, std::size_t N, std::size_t L, const char * message){ + syclcompat::wait(); + std::vector host_data(M*N*L); + device_alloc.copy_to_host(host_data.data()); + auto tensor_aux_host = make_tensor(host_data.data(), cute::make_shape(M,N,L)); + print(message); + print('\n'); + print_tensor(tensor_aux_host); + print('\n'); +} + + + +template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_Aux; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + cutlass::TensorRef ref_Aux(block_Aux.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + syclcompat::wait(); + cutlass::reference::device::BlockElementwiseOp( + block_ref_D.get(), 
block_ref_D.get(), block_Aux.get(), block_D.size()); + syclcompat::wait(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); + + if (0 and !passed){ + print_device_tensor(block_D, M, N, L, "Actual"); + print_device_tensor(block_ref_D, M, N, L, "Reference"); + cutlass::reference::device::BlockElementwiseOp(block_ref_D.get(), block_ref_D.get(), block_D.get(), block_D.size()); + syclcompat::wait(); + print_device_tensor(block_ref_D, M, N, L, "difference"); + } + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + block_Aux.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + initialize_block(block_Aux, seed + 2020); + } + + void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + using EpilogueArguments = typename Gemm::GemmKernel::EpilogueArguments; + // TODO need to inject the aux argument here + EpilogueArguments epilogue_arguments{ + {options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}; + epilogue_arguments.thread.aux_ptr = block_Aux.get(); + epilogue_arguments.thread.dAux = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, // mode + problem_size, // problem_shape + {block_A.get(), stride_A, block_B.get(), stride_B}, // mainloop + epilogue_arguments, // epilogue + hw_info // hw_info + // scheduler + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + gemm_op.can_implement(arguments); + + gemm_op.initialize(arguments, workspace.get()); + + // Run the GEMM + gemm_op.run(); + + syclcompat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if (passed && options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + gemm_op.run(); + } + syclcompat::wait(); + + float cute_time = timer.seconds() / options.iterations; + double tflops = (2.0 * options.m * options.n * options.k * options.l) * 1e-12; + std::cout << "Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + printf("Cutlass GEMM Performance: [%4.3f]TFlop/s (%6.4f)ms\n", tflops / cute_time, cute_time*1000); + } + + return; + } + +}; + +int main(int argc, const char** argv) +{ + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = float; // <- data type of epilogue operations + using ElementAux = float; // <- data type of epilogue operations + using ElementInputA = bfloat16_t; // <- data type of elements in input matrix A + using ElementInputB = bfloat16_t; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + + using GmemTiledCopyA = XE_2D_U16x32x32_LD_N; + using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; + + // Workgroup-level tile + using TileShape = Shape<_256, _256, _32>; + + using TiledMma = TiledMMA, + Layout>, + Tile<_64,_64,_32>>; // Subgroup level-tile + + constexpr int PipelineStages = 3; + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVC; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue; + + using CopyOpG2R = XE_2D_U32x8x16_LD_N; + using EpilogueOp = cutlass::epilogue::fusion::LinCombDeEltAct< + LayoutC, + sum_vals, + ElementOutput, + ElementComputeEpilogue>; + static_assert(std::is_same_v); + + using EpilogueTile = decltype(take<0,2>(TileShape{})); + + // issue is that this isn't specializing + using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< + EpilogueDispatchPolicy, + EpilogueOp, + TileShape, + EpilogueTile, + CopyOpG2R + >; + using CollectiveEpilogue = cutlass::epilogue::collective::CollectiveEpilogue< + EpilogueDispatchPolicy, // IntelPVCEpilogue + TileShape, // CtaTileMNK + ElementAccumulator, // ElementC + cutlass::gemm::TagToStrideC_t, // StrideC + ElementOutput, // ElementD + cutlass::gemm::TagToStrideC_t, // StrideD + FusionCallBacks, // FusionCallBacks + CopyOpG2R, // CopyOpG2R + void,// SmemLayoutAtomC + void, // CopyOpS2R + XE_2D_U32x8x16_ST_N, // CopyOpR2G + void, // SmemLayoutAtomD + void>; // 
CopyOpR2S + +// Mainloop + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + GEMMDispatchPolicy, + TileShape, + ElementInputA, + cutlass::gemm::TagToStrideA_t, + ElementInputB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, void, void, cute::identity, // A + GmemTiledCopyB, void, void, cute::identity // B + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + runner.run(options, hw_info); + + return 0; +} diff --git a/include/cutlass/epilogue/collective/builders/xe_builder.inl b/include/cutlass/epilogue/collective/builders/xe_builder.inl index c51c9fba7b..5992369ec9 100644 --- a/include/cutlass/epilogue/collective/builders/xe_builder.inl +++ b/include/cutlass/epilogue/collective/builders/xe_builder.inl @@ -39,6 +39,86 @@ namespace cutlass::epilogue::collective { +namespace detail { + template + struct FusionOpInfo { + static_assert(cutlass::detail::dependent_false, "Could not find a builder specialization."); + }; + + template < + class ElementD, + class ElementCompute, + class ElementC + > + struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + + template < + class DispatchPolicy, + class TileShape_MNK, + class EpilogueTile, + class> + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinearCombination, + TileShape_MNK, + EpilogueTile + >; + }; + + template < + template class ActivationFn, + class ElementD, + class ElementCompute, + class ElementC + > + struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + template < + class DispatchPolicy, + class TileShape_MNK, + class EpilogueTile, + class> + + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinCombEltAct, + TileShape_MNK, + EpilogueTile + >; + }; + + template < + class GmemLayoutTagC, + template class ActivationFn, + class ElementD, + class ElementCompute, + class ElementC + > + struct FusionOpInfo> { + constexpr static bool HasBuilder = true; + + template < + class DispatchPolicy, + class TileShape_MNK, + class EpilogueTile, + class CopyOpG2R> + using FusionCallbacks = cutlass::epilogue::fusion::FusionCallbacks< + DispatchPolicy, + cutlass::epilogue::fusion::LinCombDeEltAct, + TileShape_MNK, + EpilogueTile, + CopyOpG2R + >; + }; +} + // Intel epilogue builder template < @@ -74,15 +154,7 @@ template < cute::is_same_v && cute::is_same_v && cute::is_same_v && - // Only linear combination is supported at the moment - (cute::is_same_v> || - cute::is_same_v> || - cute::is_same_v>) + detail::FusionOpInfo::HasBuilder > >{ #ifdef SYCL_NVIDIA_TARGET @@ -106,9 +178,7 @@ template < using SmemLayoutAtomD_ = void; using CopyOpR2S_ = void; - using FusionCallBacks = cutlass::epilogue::fusion::FusionCallbacks< - DispatchPolicy, FusionOpOrCallbacks, TileShape_MNK, - decltype(tile_shape(TiledMma()))>; + using FusionCallbacks = typename detail::FusionOpInfo::template FusionCallbacks; using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< DispatchPolicy, @@ -117,7 +187,7 @@ template < cutlass::gemm::TagToStrideC_t, ElementD, cutlass::gemm::TagToStrideC_t, - FusionCallBacks, + FusionCallbacks, CopyOpG2R, SmemLayoutAtomC_, CopyOpS2R_, diff --git a/include/cutlass/epilogue/collective/xe_epilogue.hpp b/include/cutlass/epilogue/collective/xe_epilogue.hpp index a5d00a0c03..13d61024bc 100644 
--- a/include/cutlass/epilogue/collective/xe_epilogue.hpp +++ b/include/cutlass/epilogue/collective/xe_epilogue.hpp @@ -265,6 +265,10 @@ class CollectiveEpilogue< (void) smem; using namespace cute; + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + using MmaAtomShape = typename TiledMma::AtomShape_MNK; static constexpr auto BLK_M = get<0>(CtaTileMNK{}); static constexpr auto BLK_N = get<1>(CtaTileMNK{}); @@ -313,7 +317,7 @@ class CollectiveEpilogue< tiled_mma, SubgroupTileShape{}, // Epilogue tile params.xe_load_c, - cD, + rw_coord, residue_mn, cD, residue_mn, @@ -327,6 +331,10 @@ class CollectiveEpilogue< auto acc_frag = recast>(accumulators); auto trD_frag = recast>(trD); + constexpr int values_loaded = FragsM*FragsN*FragmentSize*SubgroupSize*ATOM_M*ATOM_N*ATOM_K; + constexpr int MN = get<0>(CtaTileMNK{}) * get<1>(CtaTileMNK{}); + static_assert(values_loaded == MN, "the total elements loaded by all threads should be the same as MxN" ); + CUTLASS_PRAGMA_UNROLL for (int epi_n = 0; epi_n < FragsN; epi_n++) { CUTLASS_PRAGMA_UNROLL @@ -341,7 +349,7 @@ class CollectiveEpilogue< auto acc_frag_mn = acc_frag(_, epi_m, epi_n); CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < FragmentSize; ++epi_v) { + for (int epi_v = 0; epi_v < size(trD_frag); ++epi_v) { trD_frag(epi_v) = cst_callbacks.visit(acc_frag_mn(epi_v), epi_v, epi_m, epi_n); } copy(params.xe_store_d, trD, rw_coord(_, epi_m, epi_n)); diff --git a/include/cutlass/epilogue/fusion/xe_callbacks.hpp b/include/cutlass/epilogue/fusion/xe_callbacks.hpp index f81d8fae99..6020cd95c3 100644 --- a/include/cutlass/epilogue/fusion/xe_callbacks.hpp +++ b/include/cutlass/epilogue/fusion/xe_callbacks.hpp @@ -43,6 +43,7 @@ #include "cutlass/epilogue/fusion/callbacks.hpp" #include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/xe_visitor.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" #include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" @@ -167,6 +168,110 @@ struct FusionCallbacks< using Impl::Impl; }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class StrideAux, + class CopyOpG2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using XeLinCombDeEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc), aux) + Sm90LinearCombination, // beta * C + (alpha * acc) + XeAuxLoad // aux + >; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = activation(dY, Z) +// +template < + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput_, + class ElementCompute_, + class ElementAux, + class ElementSource, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class CopyOpG2R +> +struct FusionCallbacks< + epilogue::IntelPVCEpilogue, + fusion::LinCombDeEltAct< + 
GmemLayoutTagAux, ActivationFn, ElementOutput_, ElementCompute_, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + CopyOpG2R +> : XeLinCombDeEltAct< + CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, + ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + > { + + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + + using Impl = + XeLinCombDeEltAct< + CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t, CopyOpG2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + using Operation = + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + + } // namespace cutlass::epilogue::fusion ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/xe_visitor.hpp b/include/cutlass/epilogue/fusion/xe_visitor.hpp new file mode 100644 index 0000000000..b8f0526190 --- /dev/null +++ b/include/cutlass/epilogue/fusion/xe_visitor.hpp @@ -0,0 +1,200 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree operations for the PVC epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +using namespace cutlass; +using namespace cutlass::epilogue::fusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Element, + class StrideMNL, + class CopyOpG2R, + bool EnableNullptr = true +> +struct XeAuxLoad { + using SharedStorage = Element; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + using Trait_Aux = Copy_Traits; + using SubgroupSize = Int; + using XE_Copy_Aux = decltype(make_tiled_copy(Copy_Atom{} + .with(static_cast(nullptr), int32_t(0), int32_t(0), int32_t(0)), + Layout>{}, + make_layout(make_shape(get<0>(typename Trait_Aux::Shape_MN{}), + get<1>(typename Trait_Aux::Shape_MN{}) / SubgroupSize{})))); + struct Params { + XE_Copy_Aux xe_load_aux; + Element null_default = Element(0); + bool use_default = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + // TODO(codeplay): This assumes a packed aux matrix + auto dAux = append<3>(args.dAux, 0); + auto N_AUX = get<0>(dAux); // dAux is a stride and N_AUX is a size + auto M_AUX = size(M); + XE_Copy_Aux xe_load_aux = make_tiled_copy(Copy_Atom{}.with( + args.ptr_aux, N_AUX, M_AUX, N_AUX), + Layout>{}, + make_layout(make_shape(get<0>(typename Trait_Aux::Shape_MN{}), + get<1>(typename Trait_Aux::Shape_MN{}) / SubgroupSize{}))); + + bool use_default = false; + if constexpr (EnableNullptr) { + use_default = args.ptr_aux == nullptr; + } + + return Params{xe_load_aux, args.null_default, use_default}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + XeAuxLoad() { } + + CUTLASS_HOST_DEVICE + XeAuxLoad(Params const& params, SharedStorage const&) : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const 
{ + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (params_ptr->use_default && params_ptr->null_default == Element(0)); + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const&) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CTensor rw_coord; // (EPI_V, EPI_M, EPI_N) + XE_Copy_Aux xe_copy_aux; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + Params const* params_ptr; + + CUTLASS_DEVICE + ConsumerStoreCallbacks(CTensor rw_coord, XE_Copy_Aux xe_copy_aux, RTensor&& tC_rAux, Params const* params_ptr) + : rw_coord(cute::forward(rw_coord)), xe_copy_aux(xe_copy_aux), tC_rAux(cute::forward(tC_rAux)), params_ptr(params_ptr) { } + + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + + copy(xe_copy_aux, rw_coord(_, epi_m, epi_n), tC_rAux); + } + + // here is where we return values from the aux tile being processed + template + CUTLASS_DEVICE Array + visit(Array const&, int epi_v, int, int) { + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + return tC_rAux_frg(epi_v); + + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto xe_copy_aux = params_ptr->xe_load_aux; + Tensor rw_coord = args.cD; + Tensor trAux = make_tensor_like(args.tCrC); + + return ConsumerStoreCallbacks( + rw_coord, xe_copy_aux, cute::move(trAux), params_ptr + ); + } +}; + diff --git a/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp b/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp index 6eaeb7afda..6b9fa4df6a 100644 --- a/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp +++ b/test/unit/gemm/device/xe_gemm_bf16_bf16_fp32_tensor_op_fp32_evt.cpp @@ -106,3 +106,63 @@ TEST(XE_Device_Gemm_bf16t_bf16t_f32t_tensor_op_gmma_f32_epilogue, 256x256x32_Lin EXPECT_TRUE(passed); } +TEST(Xe_Gemm_bf16t_bf16t_f32_tensor_op_gmma_f32_epilogue_drelu, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + using LayoutAux = cutlass::layout::RowMajor; + + using TileShape_MNK = Shape<_256, _256, _32>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = bfloat16_t; + using ElementInputB = bfloat16_t; + using ElementOutput = float; + + constexpr int AlignmentA = sizeof(ElementInputA); + constexpr int AlignmentB = sizeof(ElementInputB); + constexpr int AlignmentC = sizeof(ElementAccumulator); + constexpr int AlignmentD = sizeof(ElementOutput); + + using FusionCallbacks = cutlass::epilogue::fusion::LinCombDeEltAct< + LayoutC, + cutlass::epilogue::thread::dReLU, + ElementOutput, + ElementComputeEpilogue>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::IntelPVC, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementComputeEpilogue, ElementAccumulator, + ElementAccumulator, LayoutC, AlignmentC, + ElementOutput, LayoutD, AlignmentD, + 
+    EpilogueSchedule,
+    FusionCallbacks
+  >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    cutlass::arch::IntelPVC, cutlass::arch::OpClassTensorOp,
+    ElementInputA, LayoutA, AlignmentA,
+    ElementInputB, LayoutB, AlignmentB,
+    ElementAccumulator,
+    TileShape_MNK, ClusterShape_MNK,
+    cutlass::gemm::collective::StageCountAuto,
+    cutlass::gemm::collective::KernelScheduleAuto
+  >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+    Shape<int, int, int, int>,
+    CollectiveMainloop,
+    CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  bool passed = test::gemm::device::TestXe<Gemm>(1.0, 1.0);
+  EXPECT_TRUE(passed);
+}
+
diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h
index 3c312f5ff8..365413e47b 100644
--- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h
+++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h
@@ -107,6 +107,25 @@ __global__ void
 }
 }
 
+template