Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Intel gpu backend gemm pipeline #89

Merged
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
07f36e5
apply patch of gemm pipeline
Jiaxingla Jul 3, 2024
d4cf3eb
fix format of copyright
Jiaxingla Jul 10, 2024
665f9be
replace the macro of cache flush and idx
Jiaxingla Jul 10, 2024
59c0ce4
auto format
Jiaxingla Jul 11, 2024
bdadf1e
auto format
Jiaxingla Jul 11, 2024
9e23cd6
fix comments about prefetch
Jiaxingla Jul 14, 2024
60adb24
fix comments of enum and sycl macro
Jiaxingla Jul 15, 2024
1e3f855
update from tensor library repo
Jiaxingla Jul 15, 2024
8e951d1
fix format
Jiaxingla Jul 15, 2024
c92adb3
rm redundancy code
Jiaxingla Jul 15, 2024
6bdda75
resolve conflict
Jiaxingla Jul 16, 2024
3496593
revert the change of nv hpp
Jiaxingla Jul 17, 2024
69d5c2a
Restore invalid changes
Jiaxingla Jul 17, 2024
962766b
refine gemm interface with codeplay epilogue
Jiaxingla Jul 18, 2024
7739df6
fix the issue of batch gemm
Jiaxingla Jul 18, 2024
5b1f514
rm epilogue and revert gemm example
Jiaxingla Jul 23, 2024
f5e23e8
only keep code changes of gemm
Jiaxingla Jul 25, 2024
1c57c36
comments clean
Jiaxingla Jul 26, 2024
13ae1a1
rebase other examples
Jiaxingla Jul 26, 2024
fdb7244
rm vnni_matrix func
Jiaxingla Jul 26, 2024
d09da29
code clean
Jiaxingla Jul 26, 2024
5a3d227
define N-major tensor
Jiaxingla Jul 29, 2024
b50574a
delete useless header
Jiaxingla Jul 30, 2024
2c6d1ba
more comments
Jiaxingla Jul 30, 2024
c97ccd8
modify comments
Jiaxingla Jul 30, 2024
ede5c03
Update pvc_gemm
Jiaxingla Jul 31, 2024
f9aae6f
Update mma_xe
Jiaxingla Jul 31, 2024
7878a7c
more comments
Jiaxingla Jul 31, 2024
4c42645
code clean
Jiaxingla Jul 31, 2024
abbbe4f
fix typo
Jiaxingla Jul 31, 2024
8e9a84f
revert the change of copy_atom
Jiaxingla Jul 31, 2024
ea30c83
rename enum of LSC_LDCC
Jiaxingla Aug 1, 2024
043fbea
fix typo
Jiaxingla Aug 1, 2024
abf38bd
scope enums
Jiaxingla Aug 1, 2024
5193329
modify comment of copy
Jiaxingla Aug 2, 2024
b854995
remove useless copy
Jiaxingla Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 2 additions & 34 deletions benchmarks/common/benchmark_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,42 +320,10 @@ struct PvcBenchmarkRunner : BenchmarkRunner<Gemm> {

using ProblemShapeType = typename Base::ProblemShapeType;

cutlass::DeviceAllocation<ElementB> block_B_vnni;

/// Re-packs `src` into `dst` so that, within every group of `factor`
/// consecutive rows, the `factor` values sharing a column become adjacent.
///
/// The input is read as `batch` consecutive slabs, each contributing
/// `numRows / factor` row groups of `numCols` columns. Rows left over when
/// `numRows` is not a multiple of `factor` are never read or written.
///
/// @param dst     Destination buffer, written at the packed positions.
/// @param src     Source buffer, read in row-major order within each group.
/// @param batch   Number of slabs to process.
/// @param numRows Rows per slab (only `numRows / factor` full groups are used).
/// @param numCols Columns per row.
/// @param factor  Number of rows interleaved into one packed row.
template <typename T>
void vnni_matrix(
    T* dst, const T* src,
    int batch, int numRows, int numCols, int factor)
{
  for (int batchIdx = 0; batchIdx < batch; ++batchIdx) {
    for (int group = 0; group < numRows / factor; ++group) {
      for (int col = 0; col < numCols; ++col) {
        for (int lane = 0; lane < factor; ++lane) {
          // Global index of the current row group across all slabs.
          const int groupIdx = batchIdx * (numRows / factor) + group;
          // Packed layout: `lane` is the fastest-varying dimension.
          const int packedPos = (groupIdx * numCols + col) * factor + lane;
          // Plain layout: `col` is the fastest-varying dimension.
          const int plainPos = (groupIdx * factor + lane) * numCols + col;
          dst[packedPos] = src[plainPos];
        }
      }
    }
  }
}

// Runs the base-class initialization, then fills block_B_vnni with a
// re-laid-out copy of Base::block_B produced by vnni_matrix (factor 2).
void initialize(const ProblemShapeType& problem_size) override {
Base::initialize(problem_size);

// Pad the problem shape to four entries (appending 1) and unpack it.
// NOTE(review): presumably L is the batch count -- confirm against ProblemShapeType.
auto problem_shape_MNKL = cute::append<4>(problem_size, 1);
auto [M, N, K, L] = problem_shape_MNKL;

// Size the re-laid-out buffer to match the original B allocation.
block_B_vnni.reset(Base::block_B.size());

// Host staging buffers: `b` receives the current B data, `b_vnni` the result.
std::vector<ElementB> b(K * N * L);
std::vector<ElementB> b_vnni(b.size());

// Pull B to the host and re-pack it as L slabs of K rows by N columns,
// interleaving pairs of rows (factor = 2).
Base::block_B.copy_to_host(b.data());
vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2);

// Upload the re-packed data to block_B_vnni.
block_B_vnni.copy_from_host(b_vnni.data());
}

void run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) override {
ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l};

Expand All @@ -364,7 +332,7 @@ struct PvcBenchmarkRunner : BenchmarkRunner<Gemm> {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{Base::block_A.get(), Base::stride_A, block_B_vnni.get(), Base::stride_B},
{Base::block_A.get(), Base::stride_A, Base::block_B.get(), Base::stride_B},
{
{options.alpha, options.beta},
Base::block_C.get(), Base::stride_C, Base::block_D.get(), Base::stride_D
Expand Down
6 changes: 3 additions & 3 deletions benchmarks/pvc/bench_pvc_gemm_bf16_bf16_fp32_dpas_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;
using TileShape = Shape<_256, _256, _32>;

using TiledMma = TiledMMA<
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V;

using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
Expand Down
36 changes: 7 additions & 29 deletions examples/sycl/pvc/pvc_gemm.cpp
aacostadiaz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,6 @@ static void fill_matrix(std::vector<T> &vector)
return static_cast<T>( (rand() / double(RAND_MAX)) );
});
}

/// Copies `src` into `dst`, interleaving each group of `factor` consecutive
/// rows so that the `factor` values of one column are stored contiguously.
///
/// `src` is treated as `batch` slabs of `numRows / factor` row groups with
/// `numCols` columns each; any remainder rows (`numRows % factor`) are skipped.
///
/// @param dst     Output buffer receiving the interleaved data.
/// @param src     Input buffer.
/// @param batch   Number of slabs.
/// @param numRows Rows per slab.
/// @param numCols Columns per row.
/// @param factor  Rows merged into each interleaved row.
template <typename T>
static void vnni_matrix(
    T* dst, const T* src,
    int batch, int numRows, int numCols, int factor)
{
  int slab = 0;
  while (slab < batch) {
    for (int rowGroup = 0; rowGroup < numRows / factor; ++rowGroup) {
      for (int column = 0; column < numCols; ++column) {
        for (int inner = 0; inner < factor; ++inner) {
          // Destination: row group, then column, then position inside the group.
          T& target =
              dst[((slab * (numRows / factor) + rowGroup) * numCols + column) * factor + inner];
          // Source: row group, then row inside the group, then column.
          const T& origin =
              src[((slab * (numRows / factor) + rowGroup) * factor + inner) * numCols + column];
          target = origin;
        }
      }
    }
    ++slab;
  }
}

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -89,7 +71,7 @@ struct Options {
Options():
help(false),
error(false),
m(4096), n(4096), k(4096), l(1), iterations(100),
m(4096), n(4096), k(4096), l(1), iterations(20),
alpha(1.f), beta(0.f)
{ }

Expand All @@ -108,7 +90,7 @@ struct Options {
cmd.get_cmd_line_argument("l", l, 1);
cmd.get_cmd_line_argument("alpha", alpha, 1.f);
cmd.get_cmd_line_argument("beta", beta, 0.f);
cmd.get_cmd_line_argument("iterations", iterations, 100);
cmd.get_cmd_line_argument("iterations", iterations, 20);
}

/// Prints the usage statement.
Expand Down Expand Up @@ -170,7 +152,6 @@ struct ExampleRunner {

cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementB> block_B_vnni;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementOutput> block_D;
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
Expand Down Expand Up @@ -231,7 +212,6 @@ struct ExampleRunner {

block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_B_vnni.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
Expand All @@ -247,11 +227,9 @@ struct ExampleRunner {
fill_matrix(a);
fill_matrix(b);
fill_matrix(c);
vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2);

syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA));
syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC));
syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC));
}
Expand All @@ -272,7 +250,7 @@ struct ExampleRunner {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), stride_A, block_B_vnni.get(), stride_B},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
Expand Down Expand Up @@ -362,14 +340,14 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add some explanation of the naming conventions used for the copy functions.

using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V;

// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;
using TileShape = Shape<_256, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile
Tile<_32,_64,_32>>; // Subgroup level-tile

using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelPVCUnpredicated;
using EpilogueDispatchPolicy = cutlass::epilogue::IntelPVCEpilogue;
Expand Down
30 changes: 4 additions & 26 deletions examples/sycl/pvc/pvc_gemm_with_epilogue_relu.cpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,6 @@ static void fill_matrix(std::vector<T> &vector)
});
}

/// Rearranges `src` into `dst` by interleaving every `factor` consecutive rows:
/// element (row group g, row-in-group k, column c) of the source lands at
/// (row group g, column c, slot k) of the destination.
///
/// `batch` slabs are processed, each with `numRows / factor` row groups of
/// `numCols` columns. Rows beyond the last complete group are left untouched.
///
/// @param dst     Destination buffer.
/// @param src     Source buffer.
/// @param batch   Number of slabs.
/// @param numRows Rows per slab.
/// @param numCols Columns per row.
/// @param factor  Interleave factor (rows per group).
template <typename T>
static void vnni_matrix(
    T* dst, const T* src,
    int batch, int numRows, int numCols, int factor)
{
  // Offset of (group, col, slot) in the interleaved destination layout.
  const auto interleavedOffset = [numCols, factor](int group, int col, int slot) {
    return (group * numCols + col) * factor + slot;
  };
  // Offset of (group, slot, col) in the row-major source layout.
  const auto rowMajorOffset = [numCols, factor](int group, int col, int slot) {
    return (group * factor + slot) * numCols + col;
  };

  for (int n = 0; n < batch; n++) {
    for (int g = 0; g < numRows / factor; g++) {
      for (int col = 0; col < numCols; col++) {
        for (int slot = 0; slot < factor; slot++) {
          const int group = n * (numRows / factor) + g;
          dst[interleavedOffset(group, col, slot)] = src[rowMajorOffset(group, col, slot)];
        }
      }
    }
  }
}

using namespace cute;

///////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -171,7 +154,6 @@ struct ExampleRunner {

cutlass::DeviceAllocation<ElementA> block_A;
cutlass::DeviceAllocation<ElementB> block_B;
cutlass::DeviceAllocation<ElementB> block_B_vnni;
cutlass::DeviceAllocation<ElementC> block_C;
cutlass::DeviceAllocation<ElementOutput> block_D;
cutlass::DeviceAllocation<ElementOutput> block_ref_D;
Expand Down Expand Up @@ -238,7 +220,6 @@ struct ExampleRunner {

block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_B_vnni.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
Expand All @@ -247,18 +228,15 @@ struct ExampleRunner {
// available through SYCL.
std::vector<ElementA> a(K * M * L);
std::vector<ElementB> b(K * N * L);
std::vector<ElementB> b_vnni(b.size());
std::vector<ElementC> c(M * N * L);
std::vector<ElementC> d(M * N * L, ElementC{0});

fill_matrix(a);
fill_matrix(b);
fill_matrix(c);
vnni_matrix(b_vnni.data(), b.data(), L, K, N, 2);

syclcompat::memcpy(block_A.get(), a.data(), a.size() * sizeof(ElementA));
syclcompat::memcpy(block_B.get(), b.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_B_vnni.get(), b_vnni.data(), b.size() * sizeof(ElementB));
syclcompat::memcpy(block_C.get(), c.data(), c.size() * sizeof(ElementC));
syclcompat::memcpy(block_D.get(), d.data(), d.size() * sizeof(ElementC));
}
Expand All @@ -271,7 +249,7 @@ struct ExampleRunner {
typename Gemm::GemmKernel::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
problem_size,
{block_A.get(), stride_A, block_B_vnni.get(), stride_B},
{block_A.get(), stride_A, block_B.get(), stride_B},
{{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D},
hw_info
};
Expand Down Expand Up @@ -361,12 +339,12 @@ int main(int argc, const char** argv)
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U16x8x16x4x2_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x1_LD_N;
using GmemTiledCopyB = XE_2D_U16x16x16x2x2_V;

// Workgroup-level tile
using TileShape = Shape<_32, _256, _32>;
using TileShape = Shape<_256, _256, _32>;

using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TN>,
using TiledMma = TiledMMA<MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>,
Layout<Shape<_1,_1,_1>>,
Tile<_32,_64,_32>>; // Subgroup level-tile

Expand Down
Loading
Loading