diff --git a/benchmarks/common/benchmark_runner.hpp b/benchmarks/common/benchmark_runner.hpp index 337cd28ef4..6b6cb11fd8 100644 --- a/benchmarks/common/benchmark_runner.hpp +++ b/benchmarks/common/benchmark_runner.hpp @@ -46,70 +46,94 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#else +#include "cutlass/util/reference/device/tensor_fill.h" +#endif #include "cutlass/util/print_error.hpp" #include -template -static void fill_matrix(std::vector &M) -{ - std::generate(std::begin(M), std::end(M), [&] - { return static_cast( 2 * (rand() / double(RAND_MAX)) - 1); }); -} - using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + // Command line options parsing struct Options { - bool help; - bool error; - - int m, n, k, l, iterations; - float alpha, beta; - - Options(): - help(false), - error(false), - m(4096), n(4096), k(4096), l(1), - alpha(1.f), beta(0.f) - { } - - // Parses the command line - void parse(int argc, char const **args) { - cutlass::CommandLine cmd(argc, args); - - if (cmd.check_cmd_line_flag("help")) { 
- help = true; - return; - } - - cmd.get_cmd_line_argument("m", m, 4096); - cmd.get_cmd_line_argument("n", n, 4096); - cmd.get_cmd_line_argument("k", k, 4096); - cmd.get_cmd_line_argument("l", l, 1); - cmd.get_cmd_line_argument("alpha", alpha, 1.f); - cmd.get_cmd_line_argument("beta", beta, 0.f); - } + bool help; + bool error; - /// Prints the usage statement. - std::ostream & print_usage(std::ostream &out) const { - - out << "PVC GEMM Benchmark\n\n" - << "Options:\n\n" - << " --help If specified, displays this usage statement\n\n" - << " --m= Sets the M extent of the GEMM\n" - << " --n= Sets the N extent of the GEMM\n" - << " --k= Sets the K extent of the GEMM\n" - << " --l= Sets the L extent (batch count) of the GEMM\n" - << " --alpha= Epilogue scalar alpha\n" - << " --beta= Epilogue scalar beta\n\n" - << " --iterations= Iterations\n\n"; - - return out; + int m, n, k, l, iterations; + float alpha, beta; + + Options(): + help(false), + error(false), + m(5120), n(4096), k(4096), l(1), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; } + + cmd.get_cmd_line_argument("m", m, 5120); + cmd.get_cmd_line_argument("n", n, 4096); + cmd.get_cmd_line_argument("k", k, 4096); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "PVC GEMM Benchmark\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n"; + + return out; + } }; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -117,213 +141,184 @@ struct Options { template struct BenchmarkRunner { - using StrideA = typename Gemm::GemmKernel::StrideA; - using StrideB = typename Gemm::GemmKernel::StrideB; - using StrideC = typename Gemm::GemmKernel::StrideC; - using StrideD = typename Gemm::GemmKernel::StrideD; - - using LayoutA = typename Gemm::LayoutA; - using LayoutB = typename Gemm::LayoutB; - using LayoutC = typename Gemm::LayoutC; - using LayoutD = typename Gemm::LayoutD; - - using ElementA = typename Gemm::ElementA; - using ElementB = typename Gemm::ElementB; - using ElementAcc = typename Gemm::ElementAccumulator; - - using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; - using ElementC = typename Gemm::ElementC; - using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; - - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - // - // Data members - // - - /// Initialization - StrideA stride_A; - StrideB stride_B; - StrideC stride_C; - StrideD stride_D; - - cutlass::DeviceAllocation block_A; - cutlass::DeviceAllocation block_B; - cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_D; - cutlass::DeviceAllocation block_ref_D; - - ElementOutput epsilon; - ElementOutput nonzero_floor; - - BenchmarkRunner(std::string 
test_name) : epsilon(static_cast(0.1f)), - nonzero_floor(static_cast(0.1f)), test_name(test_name) { - int argc = 0; - benchmark::SetDefaultTimeUnit(benchmark::kMillisecond); - benchmark::Initialize(&argc, nullptr); - }; - - BenchmarkRunner(ElementOutput epsilon, ElementOutput nonzeroFloor, std::string test_name) : - epsilon(epsilon), nonzero_floor(nonzeroFloor), test_name(test_name) { - int argc = 0; - benchmark::SetDefaultTimeUnit(benchmark::kMillisecond); - benchmark::Initialize(&argc, nullptr); - } - - // - // Methods - // - - bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { - auto [M, N, K, L] = problem_size; - - cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); - cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); - cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); - cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); - - cutlass::reference::device::GemmComplex( - {M, N, K}, - alpha, - ref_A, - cutlass::ComplexTransform::kNone, - ref_B, - cutlass::ComplexTransform::kNone, - beta, - ref_C, - ref_D, - ElementAccumulator(0), - L, // batch_count - M * K, // batch_stride_A - K * N, // batch_stride_B - M * N, // batch_stride_C - M * N // batch_stride_D - ); + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + 
using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + BenchmarkRunner(std::string test_name) : test_name(test_name) { + int argc = 0; + benchmark::SetDefaultTimeUnit(benchmark::kMillisecond); + benchmark::Initialize(&argc, nullptr); + }; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, ElementCompute alpha, ElementCompute beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + alpha, + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + beta, + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); #if defined(CUTLASS_ENABLE_SYCL) - syclcompat::wait(); + syclcompat::wait(); #else - cudaDeviceSynchronize(); + cudaDeviceSynchronize(); #endif - // Check if output from CUTLASS kernel and reference kernel are relatively equal or not - // need to set a larger error margin for comparison to succeed - auto epsilon = static_cast(0.1f); - auto nonzero_floor = static_cast(0.1f); - - bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( - block_ref_D.get(), 
block_D.get(), block_D.size(), - epsilon, nonzero_floor); + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); - return passed; - } + return passed; + } - /// Initialize operands to be used in the GEMM and reference GEMM - virtual void initialize(const ProblemShapeType& problem_size) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); - - block_A.reset(M * K * L); - block_B.reset(K * N * L); - block_C.reset(M * N * L); - block_D.reset(M * N * L); - block_ref_D.reset(M * N * L); - - // TODO: Enable initialization on device directly once RNG is - // available through SYCL. 
- std::vector a(K * M * L); - std::vector b(K * N * L); - std::vector c(M * N * L); - std::vector d(M * N * L, ElementC{-1}); - std::vector ref_d(M * N * L, ElementC{-2}); - - fill_matrix(a); - fill_matrix(b); - fill_matrix(c); - - block_A.copy_from_host(a.data(), a.size()); - block_B.copy_from_host(b.data(), b.size()); - block_C.copy_from_host(c.data(), c.size()); - block_D.copy_from_host(d.data(), d.size()); - block_ref_D.copy_from_host(ref_d.data(), d.size()); - } + /// Initialize operands to be used in the GEMM and reference GEMM + virtual void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } - virtual void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { - benchmark::ClearRegisteredBenchmarks(); - ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + virtual void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + benchmark::ClearRegisteredBenchmarks(); + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; - initialize(problem_size); + initialize(problem_size); - typename Gemm::GemmKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {block_A.get(), stride_A, block_B.get(), 
stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, - hw_info - }; + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; - Gemm gemm_op; + Gemm gemm_op; - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); - gemm_op.can_implement(arguments); + gemm_op.can_implement(arguments); - gemm_op.initialize(arguments, workspace.get()); + gemm_op.initialize(arguments, workspace.get()); - // Run the GEMM - gemm_op.run(); + // Run the GEMM + gemm_op.run(); #if defined(CUTLASS_ENABLE_SYCL) - syclcompat::wait(); + syclcompat::wait(); #else - cudaDeviceSynchronize(); + cudaDeviceSynchronize(); #endif - // Verify that the result is correct - bool passed = verify(problem_size, options.alpha, options.beta); - if(not passed) { - throw std::runtime_error("Disposition Failed."); - } - - std::stringstream full_test_name; - full_test_name << test_name << "/"; - std::string test_name_suffix = std::to_string(options.m) + "x" + - std::to_string(options.n) + "x" + - std::to_string(options.k) + "x" + - std::to_string(options.l); - full_test_name << test_name_suffix; - benchmark::RegisterBenchmark(full_test_name.str().c_str(), run_benchmark, options, gemm_op) - ->UseManualTime(); - benchmark::RunSpecifiedBenchmarks(); + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + if(not passed) { + throw std::runtime_error("Disposition Failed."); + } + + std::stringstream full_test_name; + full_test_name << test_name << "/"; + std::string test_name_suffix = std::to_string(options.m) + "x" + + 
std::to_string(options.n) + "x" + + std::to_string(options.k) + "x" + + std::to_string(options.l); + full_test_name << test_name_suffix; + benchmark::RegisterBenchmark(full_test_name.str().c_str(), run_benchmark, options, gemm_op) + ->UseManualTime(); + benchmark::RunSpecifiedBenchmarks(); } ~BenchmarkRunner() { benchmark::Shutdown(); } - - private: - static void run_benchmark(benchmark::State& state, const Options& options, Gemm gemm_op) { - state.counters["runtime_ms"] = 0; - for(auto _ : state) { - GPU_Clock timer; - timer.start(); - gemm_op.run(); - auto ms_elapsed = timer.milliseconds(); - state.counters["runtime_ms"] += ms_elapsed; - state.SetIterationTime(ms_elapsed / 1000); - } - state.counters["runtime_ms"] /= state.iterations(); - state.counters["TFlops"] = ((2.0 * options.m * options.n * options.k * options.l) * 1e-12) / - (state.counters["runtime_ms"] / 1000); + +private: + static void run_benchmark(benchmark::State& state, const Options& options, Gemm gemm_op) { + state.counters["runtime_ms"] = 0; + for(auto _ : state) { + GPU_Clock timer; + timer.start(); + gemm_op.run(); + auto ms_elapsed = timer.milliseconds(); + state.counters["runtime_ms"] += ms_elapsed; + state.SetIterationTime(ms_elapsed / 1000); } + state.counters["runtime_ms"] /= state.iterations(); + state.counters["TFlops"] = ((2.0 * options.m * options.n * options.k * options.l) * 1e-12) / + (state.counters["runtime_ms"] / 1000); + } - std::string test_name; + std::string test_name; }; diff --git a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm_cute.cu b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm_cute.cu index a5a924af9e..4c163b946f 100644 --- a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm_cute.cu +++ b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm_cute.cu @@ -70,7 +70,9 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include 
"cutlass/util/reference/device/tensor_compare.h" -#if !defined(CUTLASS_ENABLE_SYCL) +#if defined(CUTLASS_ENABLE_SYCL) +#include "cutlass/util/reference/device/sycl_tensor_fill.h" +#else #include "cutlass/util/reference/device/tensor_fill.h" #endif #include "helper.h" @@ -185,46 +187,8 @@ bool initialize_block( scope_min = -8; } -#if defined(CUTLASS_ENABLE_SYCL) - using FloatType = typename std::conditional< - (sizeof(Element) > 4), - double, - float>::type; - - using IntType = typename std::conditional< - (sizeof(Element) > 4), - int64_t, - int>::type; - - srand(seed); - Element range = static_cast(scope_max - scope_min); - Element max = static_cast(scope_max); - int int_scale = 0; - - Element float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits - Element float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); - - // Random values are cast to integer after scaling by a power of two to facilitate error - // testing - auto const size = block.size(); - auto h_vector = std::vector(size); - for (int j = 0; j < size; ++j) { - FloatType rnd = rand() / double(RAND_MAX); - rnd = max - range * rnd; - - if (int_scale >= 0) { - rnd = FloatType(IntType(std::llround(rnd * float_scale_up))); - h_vector[j] = Element(IntType(rnd * float_scale_down)); - } - else { - h_vector[j] = Element(rnd); - } - } - syclcompat::memcpy(block.get(), h_vector.data(), size); -#else cutlass::reference::device::BlockFillRandomUniform( - block.get(), block.size(), seed, scope_max, scope_min, 0); -#endif + block.get(), block.size(), seed, scope_max, scope_min, 0); return true; } diff --git a/examples/sycl/pvc/common.h b/examples/sycl/pvc/common.h new file mode 100644 index 0000000000..cd11b1c7c9 --- /dev/null +++ b/examples/sycl/pvc/common.h @@ -0,0 +1,60 @@ +/*************************************************************************************************** +* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + return true; +} diff --git a/examples/sycl/pvc/pvc_gemm.cpp b/examples/sycl/pvc/pvc_gemm.cpp index c0da127ca7..ea5d499219 100644 --- a/examples/sycl/pvc/pvc_gemm.cpp +++ b/examples/sycl/pvc/pvc_gemm.cpp @@ -45,14 +45,8 @@ #include "cutlass/util/packed_stride.hpp" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_compare.h" +#include "common.h" -template -static void fill_matrix(std::vector &vector) -{ - std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast( (rand() / double(RAND_MAX)) ); - }); -} using namespace cute; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -69,7 +63,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(20), + m(5120), n(4096), k(4096), l(1), iterations(20), alpha(1.f), beta(0.f) { } @@ -82,13 +76,13 @@ struct Options { return; } - cmd.get_cmd_line_argument("m", m, 4096); + cmd.get_cmd_line_argument("m", m, 5120); cmd.get_cmd_line_argument("n", n, 4096); cmd.get_cmd_line_argument("k", k, 4096); cmd.get_cmd_line_argument("l", l, 1); cmd.get_cmd_line_argument("alpha", 
alpha, 1.f); cmd.get_cmd_line_argument("beta", beta, 0.f); - cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("iterations", iterations, 100); } /// Prints the usage statement. @@ -147,6 +141,7 @@ struct ExampleRunner { StrideB stride_B; StrideC stride_C; StrideD stride_D; + uint64_t seed = 0; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; @@ -186,14 +181,9 @@ struct ExampleRunner { syclcompat::wait(); - // Check if output from CUTLASS kernel and reference kernel are relatively equal or not - // need to set a larger error margin for comparison to succeed - auto epsilon = static_cast(0.1f); - auto nonzero_floor = static_cast(0.1f); - - bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( - block_ref_D.get(), block_D.get(), block_D.size(), - epsilon, nonzero_floor); + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); return passed; } @@ -214,22 +204,9 @@ struct ExampleRunner { block_D.reset(M * N * L); block_ref_D.reset(M * N * L); - // TODO: Enable initialization on device directly once RNG is - // available through SYCL. 
- std::vector a(K * M * L); - std::vector b(K * N * L); - std::vector b_vnni(b.size()); - std::vector c(M * N * L); - std::vector d(M * N * L, ElementC{0}); - - fill_matrix(a); - fill_matrix(b); - fill_matrix(c); - - syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); - syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); - syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); } void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { diff --git a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp index 521d64f7d5..5f4a7e8254 100644 --- a/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp +++ b/examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp @@ -49,13 +49,7 @@ #include "cutlass/tensor_view.h" #include "cutlass/coord.h" -template -static void fill_matrix(std::vector &vector) -{ - std::generate(std::begin(vector), std::end(vector), [&] { - return static_cast( (rand() / double(RAND_MAX)) ); - }); -} +#include "common.h" using namespace cute; @@ -73,7 +67,7 @@ struct Options { Options(): help(false), error(false), - m(4096), n(4096), k(4096), l(1), iterations(100), + m(5120), n(4096), k(4096), l(1), iterations(100), alpha(1.f), beta(0.f) { } @@ -86,7 +80,7 @@ struct Options { return; } - cmd.get_cmd_line_argument("m", m, 4096); + cmd.get_cmd_line_argument("m", m, 5120); cmd.get_cmd_line_argument("n", n, 4096); cmd.get_cmd_line_argument("k", k, 4096); cmd.get_cmd_line_argument("l", l, 1); @@ -151,6 +145,7 @@ struct ExampleRunner { StrideB stride_B; StrideC stride_C; StrideD stride_D; + uint64_t seed = 0; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; @@ -196,14 +191,9 @@ struct ExampleRunner { syclcompat::wait(); - 
// Check if output from CUTLASS kernel and reference kernel are relatively equal or not - // need to set a larger error margin for comparison to succeed - auto epsilon = static_cast(0.1f); - auto nonzero_floor = static_cast(0.1f); - - bool passed = cutlass::reference::device::BlockCompareRelativelyEqual( - block_ref_D.get(), block_D.get(), block_D.size(), - epsilon, nonzero_floor); + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + block_ref_D.get(), block_D.get(), block_D.size()); return passed; } @@ -224,21 +214,9 @@ struct ExampleRunner { block_D.reset(M * N * L); block_ref_D.reset(M * N * L); - // TODO: Enable initialization on device directly once RNG is - // available through SYCL. - std::vector a(K * M * L); - std::vector b(K * N * L); - std::vector c(M * N * L); - std::vector d(M * N * L, ElementC{0}); - - fill_matrix(a); - fill_matrix(b); - fill_matrix(c); - - syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA)); - syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB)); - syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC)); - syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC)); + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); } void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { diff --git a/tools/util/include/cutlass/util/reference/device/sycl_tensor_fill.h b/tools/util/include/cutlass/util/reference/device/sycl_tensor_fill.h new file mode 100644 index 0000000000..7a039fd63c --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/sycl_tensor_fill.h @@ -0,0 +1,180 @@ +/*************************************************************************************************** +* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" + + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template ///< Element type +struct RandomUniformFunc { + + using FloatType = typename std::conditional< + (sizeof(Element) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Element) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType max; + FloatType min; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + + /// Default ctor + CUTLASS_HOST + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + Element max_ = 1, + Element min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), + max(static_cast(max_)), + min(static_cast(min_)), + int_scale(int_scale_) { + + float_scale_up = FloatType(IntType(2) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(2) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + std::default_random_engine generator; + std::normal_distribution distribution; + + // + // Methods + // + + explicit RandomUniformFunc(Params const &params): + params(params), + generator(params.seed), + distribution(static_cast(params.min), static_cast(params.max)) { + } + + /// Compute random value and update RNG state + CUTLASS_HOST + Element operator()() { + FloatType rnd = distribution(generator); + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd = FloatType(IntType(std::llround(rnd * params.float_scale_up))); + result = Element(IntType(rnd * params.float_scale_down)); + } + else { + result = Element(rnd); + } + + return result; + } +}; + +} // namespace detail + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution + int bits = -1 ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ ) { + + using RandomFunc = detail::RandomUniformFunc; + + typename RandomFunc::Params params(seed, max, min, bits); + + auto rand = RandomFunc(params); + auto h_vector = std::vector(capacity); + for (int j = 0; j < capacity; ++j) { + h_vector[j] = rand(); + } + syclcompat::memcpy(ptr, h_vector.data(), capacity); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass