Skip to content

Commit

Permalink
add splitk benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Jan 7, 2025
1 parent eae5c5b commit 1751904
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 16 deletions.
6 changes: 6 additions & 0 deletions benchmarks/ampere/gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ struct GemmConfiguration<
>;

using Gemm = GemmUniversalAdapter<GemmKernel>;

/// Baseline kernel arguments for this configuration: a value-initialized
/// `GemmKernel::Arguments` with nothing pre-set.
constexpr static typename GemmKernel::Arguments defaultArguments() {
  return {};
}
};

/////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -389,6 +391,8 @@ struct GemmConfiguration<
>;

using Gemm = GemmUniversalAdapter<GemmKernel>;

/// Baseline kernel arguments for this configuration: a value-initialized
/// `GemmKernel::Arguments` with nothing pre-set.
constexpr static typename GemmKernel::Arguments defaultArguments() {
  return {};
}
};

/////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -531,6 +535,8 @@ struct GemmConfiguration<
>;

using Gemm = GemmUniversalAdapter<GemmKernel>;

/// Baseline kernel arguments for this configuration: a value-initialized
/// `GemmKernel::Arguments` with nothing pre-set.
constexpr static typename GemmKernel::Arguments defaultArguments() {
  return {};
}
};

} // namespace device
Expand Down
13 changes: 6 additions & 7 deletions benchmarks/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,12 @@ struct BenchmarkRunnerGemm {

initialize(problem_size);

typename Gemm::GemmKernel::Arguments arguments{
gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A[0].get(), stride_A, block_B[0].get(), stride_B},
{{options.alpha, options.beta}, block_C[0].get(), stride_C, block_D.get(), stride_D},
hw_info
};
typename Gemm::GemmKernel::Arguments arguments = GemmConfiguration::defaultArguments();
arguments.mode = gemm::GemmUniversalMode::kGemm;
arguments.problem_shape = problem_size;
arguments.mainloop = {block_A[0].get(), stride_A, block_B[0].get(), stride_B};
arguments.epilogue = {{options.alpha, options.beta}, block_C[0].get(), stride_C, block_D.get(), stride_D};
arguments.hw_info = hw_info;

Gemm gemm_op;

Expand Down
30 changes: 24 additions & 6 deletions benchmarks/pvc/benchmarks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ using PvcGemmBF16BF16FP32_RRR_1 = cutlass::gemm::device::GemmConfiguration<
float, cutlass::layout::RowMajor,
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V, void>;
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V,
cutlass::gemm::device::Scheduler::Parallel>;

using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
Expand All @@ -51,7 +52,8 @@ using PvcGemmBF16BF16FP32_RRR_2 = cutlass::gemm::device::GemmConfiguration<
float, cutlass::layout::RowMajor,
float, Shape<_128, _512, _32>,
TiledMMA<MMAAtom, Layout<Shape<_4,_8,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V, void>;
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V,
cutlass::gemm::device::Scheduler::Parallel>;

using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
Expand All @@ -60,7 +62,8 @@ using PvcGemmBF16BF16FP32_RRR_3 = cutlass::gemm::device::GemmConfiguration<
float, cutlass::layout::RowMajor,
float, Shape<_256, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V, void>;
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V,
cutlass::gemm::device::Scheduler::Parallel>;

using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
Expand All @@ -69,7 +72,8 @@ using PvcGemmBF16BF16FP32_RRR_4 = cutlass::gemm::device::GemmConfiguration<
float, cutlass::layout::RowMajor,
float, Shape<_128, _256, _16>,
TiledMMA<MMAAtom, Layout<Shape<_4,_8,_1>>>,
XE_2D_U16x32x16_LD_N, XE_2D_U16x16x32_LD_V, void>;
XE_2D_U16x32x16_LD_N, XE_2D_U16x16x32_LD_V,
cutlass::gemm::device::Scheduler::Parallel>;

using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
Expand All @@ -78,7 +82,8 @@ using PvcGemmBF16BF16FP32_RRR_5 = cutlass::gemm::device::GemmConfiguration<
float, cutlass::layout::RowMajor,
float, Shape<_8, _128, _32>,
TiledMMA<MMAAtom, Layout<Shape<_1,_4,_1>>>,
XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V, void>;
XE_2D_U16x8x32_LD_N, XE_2D_U16x32x32_LD_V,
cutlass::gemm::device::Scheduler::Parallel>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
Expand All @@ -94,15 +99,28 @@ using PvcGemmBF16BF16FP32_StreamK_RRR_1 = cutlass::gemm::device::GemmConfigurati
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V,
cutlass::gemm::StreamKScheduler>;
cutlass::gemm::device::Scheduler::StreamK>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_StreamK_RRR_1);

// Split-K benchmark configuration for Intel PVC: bfloat16 row-major operands,
// float row-major result, float accumulator, 256x256x32 tile shape.
// It uses the same tile shape, tiled MMA layout and 2D block-load copies as
// PvcGemmBF16BF16FP32_RRR_1 and PvcGemmBF16BF16FP32_StreamK_RRR_1; only the
// final Scheduler argument (SplitK) differs.
using PvcGemmBF16BF16FP32_SplitK_RRR_1 = cutlass::gemm::device::GemmConfiguration<
cutlass::arch::IntelPVC,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
cutlass::bfloat16_t, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, Shape<_256, _256, _32>,
TiledMMA<MMAAtom, Layout<Shape<_8,_4,_1>>>,
XE_2D_U16x32x32_LD_N, XE_2D_U16x32x32_LD_V,
cutlass::gemm::device::Scheduler::SplitK>;

CUTLASS_CREATE_GEMM_BENCHMARK(PvcGemmBF16BF16FP32_SplitK_RRR_1);

// Invokes CUTLASS_BENCHMARK once for each PVC GEMM configuration declared in
// this header: the five Parallel-scheduler variants, then the StreamK and
// SplitK variants.
// NOTE(review): each name here is also passed to CUTLASS_CREATE_GEMM_BENCHMARK
// earlier in the file; presumably that pairing is required -- confirm against
// the macro definitions. Statement order is kept as-is in case registration
// order matters.
static void register_benchmarks() {
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_1);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_2);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_3);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_4);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_RRR_5);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_StreamK_RRR_1);
CUTLASS_BENCHMARK(PvcGemmBF16BF16FP32_SplitK_RRR_1);
}
26 changes: 23 additions & 3 deletions benchmarks/pvc/gemm_configuration.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ namespace cutlass {
namespace gemm {
namespace device {

/// Compile-time selector for the tile-scheduling strategy of a
/// GemmConfiguration. It is passed as a non-type template argument.
enum class Scheduler {
  /// The kernel is instantiated with `void` as its tile scheduler.
  Parallel,
  /// StreamKScheduler is used, with the SplitK decomposition mode as default.
  SplitK,
  /// StreamKScheduler is used, with the StreamK decomposition mode as default.
  StreamK
};

template<
class ArchTag,
class ElementA, class LayoutA,
Expand All @@ -60,7 +62,7 @@ template<
class ElementAccumulator,
class TileShape, class TiledMma,
class GmemTiledCopyA, class GmemTiledCopyB,
class TileScheduler>
Scheduler TileScheduler>
struct GemmConfiguration {
static_assert(sizeof(ElementA) == 0, "No valid GemmConfiguration configuration exists.");
};
Expand All @@ -70,7 +72,7 @@ struct GemmConfiguration {
// bfloat16

template<typename LayoutA, typename LayoutB, typename LayoutC,
class TileShape, class TiledMma, class GmemTiledCopyA, class GmemTiledCopyB, class TileScheduler>
class TileShape, class TiledMma, class GmemTiledCopyA, class GmemTiledCopyB, Scheduler TileScheduler>
struct GemmConfiguration<
arch::IntelPVC,
bfloat16_t, LayoutA,
Expand Down Expand Up @@ -113,10 +115,28 @@ struct GemmConfiguration<
Shape<int, int, int, int>,
CollectiveMainloop,
CollectiveEpilogue,
TileScheduler
std::conditional_t<TileScheduler == Scheduler::Parallel, void, cutlass::gemm::StreamKScheduler>
>;

using Gemm = GemmUniversalAdapter<GemmKernel>;

/// Builds the baseline kernel arguments for this configuration.
///
/// The result is a value-initialized `GemmKernel::Arguments`. For the StreamK
/// and SplitK schedulers the `scheduler` member is additionally set so that
/// the matching decomposition mode is selected; for Parallel nothing else is
/// set. Callers are expected to fill in the remaining fields themselves.
///
/// @return the default arguments for the selected `TileScheduler`.
constexpr static typename GemmKernel::Arguments defaultArguments() {
  using StreamKMode =
    cutlass::gemm::kernel::detail::PersistentTileSchedulerXeStreamKParams::DecompositionMode;

  // Reject unexpected enumerator values at compile time instead of silently
  // handing back arguments that do not match the instantiated scheduler.
  static_assert(TileScheduler == Scheduler::Parallel ||
                TileScheduler == Scheduler::StreamK ||
                TileScheduler == Scheduler::SplitK,
                "Unsupported Scheduler value for GemmConfiguration.");

  // One object and one return statement: every control path yields a value,
  // so compilers that do not see through an `if constexpr` chain no longer
  // report a missing return.
  typename GemmKernel::Arguments arguments{};

  // `scheduler` is only touched for the StreamK-based schedulers; the
  // discarded branches are never instantiated for Parallel.
  // NOTE(review): the leading 1 is presumably the split count -- confirm
  // against the scheduler's Arguments definition.
  if constexpr (TileScheduler == Scheduler::StreamK) {
    arguments.scheduler = {1, StreamKMode::StreamK};
  }
  else if constexpr (TileScheduler == Scheduler::SplitK) {
    arguments.scheduler = {1, StreamKMode::SplitK};
  }
  return arguments;
}
};

} // namespace device
Expand Down
17 changes: 17 additions & 0 deletions benchmarks/pvc/input.in
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,20 @@ PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16
PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_StreamK_RRR_1 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128

PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=8192
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=8192 --n=32768
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=512 --k=32768 --n=8192
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=1024
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=1024 --n=8192
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=8192 --n=4096
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=16384 --k=4096 --n=8192
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=4096 --k=16384 --n=8192
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=4096
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=1024 --k=16384 --n=8192
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=1 --m=8192 --k=16384 --n=1024
# PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=128 --n=16384
# PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=4096 --m=8 --k=16384 --n=128
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=128 --n=4096
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=4 --m=32768 --k=4096 --n=128
PvcGemmBF16BF16FP32_SplitK_RRR_1 --bm_name=bf16_bf16_fp32 --l=32 --m=4096 --k=4096 --n=128

0 comments on commit 1751904

Please sign in to comment.