From b13775e56a8e2d68fa8a638744a11e0cefcaa38e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E4=B9=9D=E9=93=AE?= <77324096+jiuzhengWang@users.noreply.github.com> Date: Wed, 27 Mar 2024 09:43:47 +0800 Subject: [PATCH] Migrate CUTLASS micro kernel to CuTe (#601) * Migrate CUTLASS micro kernel to CuTe * Remove currently unsupported specializations * Workaround a cmake dependency issue --------- Co-authored-by: Jiuzheng Wang Co-authored-by: Shizhi Tang --- include/schedule/var_reorder.h | 11 +- pyproject.toml | 1 + .../micro_kernel/matmul/cutlass/gemm_sm80.h | 179 +++++++----------- src/schedule/lower_cutlass_micro_block.cc | 51 +++-- src/schedule/var_reorder.cc | 13 +- 5 files changed, 117 insertions(+), 138 deletions(-) diff --git a/include/schedule/var_reorder.h b/include/schedule/var_reorder.h index 65913208c..8cd363e48 100644 --- a/include/schedule/var_reorder.h +++ b/include/schedule/var_reorder.h @@ -14,11 +14,14 @@ class VarReorder : public SymbolTable { ID def_; std::string var_; std::vector order_; + bool forceReorderInMatMul_; bool found_ = false; public: - VarReorder(const ID &def, const std::vector &order) - : def_(def), order_(order) { + VarReorder(const ID &def, const std::vector &order, + bool forceReorderInMatMul) + : def_(def), order_(order), + forceReorderInMatMul_(forceReorderInMatMul) { std::vector numbers; numbers.reserve(order.size()); for (int i = 0, n = order.size(); i < n; i++) { @@ -58,6 +61,10 @@ class VarReorder : public SymbolTable { Stmt visit(const MatMul &op) override; }; +Stmt varReorderImpl(const Stmt &ast, const ID &def, + const std::vector &order, + bool forceReorderInMatMul = false); + Stmt varReorder(const Stmt &ast, const ID &def, const std::vector &order); } // namespace freetensor diff --git a/pyproject.toml b/pyproject.toml index 4da1cf248..9771137eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ doc = [ [build-system] requires = [ "py-build-cmake~=0.1.8", + "importlib_metadata", # Workaround https://github.com/scikit-build/cmake-python-distributions/issues/471 # We can't use pybind11-stubgen here. It will break CMake's incremental compilation "z3-solver", "setuptools", # Required by z3: https://github.com/Z3Prover/z3/issues/2374 diff --git a/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h b/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h index 9f780274a..e901941f8 100644 --- a/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h +++ b/runtime/micro_kernel/matmul/cutlass/gemm_sm80.h @@ -1,131 +1,85 @@ /** * This file is borrowed from - * https://github.com/nox-410/tvm.tl/blob/tl/src/tl/tl_templates/gemm_sm80.h + * https://github.com/nox-410/tvm.tl/blob/tl/src/tl/tl_templates/cute_gemm.h * under Apache Lincense, and modified for use. 
*/ -#ifndef MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H -#define MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H +#pragma once +#include #include #include #include -using cutlass::gemm::GemmShape; +using namespace cute; template struct DispatchInstruction; -template <> -struct DispatchInstruction { - using Shape = GemmShape<16, 8, 16>; -}; -template <> -struct DispatchInstruction { - using Shape = GemmShape<16, 8, 16>; -}; -template <> -struct DispatchInstruction { - using Shape = GemmShape<16, 8, 16>; -}; -template <> -struct DispatchInstruction { - using Shape = GemmShape<16, 8, 8>; -}; template <> struct DispatchInstruction { - using Shape = GemmShape<8, 8, 4>; -}; -template <> struct DispatchInstruction { - using Shape = GemmShape<16, 8, 32>; + using MMA = MMA_Atom; }; -template struct DispatchSharedMemoryLayout; - -template <> struct DispatchSharedMemoryLayout { - using Layout = cutlass::layout::ColumnMajor; -}; -template <> struct DispatchSharedMemoryLayout { - using Layout = cutlass::layout::RowMajor; +template +struct OperandTraits { + static constexpr int stride = K_inner ? K : N; + using Layout = typename std::conditional< + K_inner, Layout, Int>, Shape, _1>>, + Layout, Int>, Shape<_1, Int>>>::type; + using Copy = DefaultCopy; }; -template +template class GemmTensorOp { public: - using A_type = - typename std::conditional::value, - cutlass::tfloat32_t, A_type_raw>::type; - using B_type = - typename std::conditional::value, - cutlass::tfloat32_t, A_type_raw>::type; - using C_type = C_type_raw; - using InstructionShape = - typename DispatchInstruction::Shape; - using SMemLayoutA = typename DispatchSharedMemoryLayout::Layout; - using SMemLayoutB = typename DispatchSharedMemoryLayout::Layout; - - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma< - InstructionShape, 32, A_type, cutlass::layout::RowMajor, B_type, - cutlass::layout::ColumnMajor, C_type, cutlass::layout::RowMajor, - cutlass::arch::OpMultiplyAdd>, - cutlass::MatrixShape<1, 1>>; - - static_assert(Shape::kM % num_warp_m == 0); - static_assert(Shape::kN % num_warp_n == 0); - - using MmaWarp = typename cutlass::gemm::warp::MmaTensorOp< - GemmShape, - A_type, SMemLayoutA, B_type, SMemLayoutB, C_type, - cutlass::layout::RowMajor, Policy, 1, - true /* accumulate in row major */>; - - using TensorRefA = typename MmaWarp::IteratorA::TensorRef; - using TensorRefB = typename MmaWarp::IteratorB::TensorRef; - using FragmentA = typename MmaWarp::FragmentA; - using FragmentB = typename MmaWarp::FragmentB; - using FragmentC = typename MmaWarp::FragmentC; - using IteratorA = typename MmaWarp::IteratorA; - using IteratorB = typename MmaWarp::IteratorB; - - static_assert(Shape::kK % InstructionShape::kK == 0); - static int constexpr kKgroups = Shape::kK / InstructionShape::kK; - - static CUTLASS_DEVICE void body(const A_type_raw *pA, const B_type_raw *pB, - FragmentC &accum, int lda, int ldb, - double alpha, double beta, - const int warp_idx_m, const int warp_idx_n, - const int lane_id) { - MmaWarp mma_op; - FragmentA frag_A; - FragmentB frag_B; - const TensorRefA ref_A((A_type *)pA, lda); - const TensorRefB ref_B((B_type *)pB, ldb); - IteratorA iter_A(ref_A, lane_id); - IteratorB iter_B(ref_B, lane_id); - iter_A.add_tile_offset({warp_idx_m, 0}); - iter_B.add_tile_offset({0, warp_idx_n}); - - // TODO: Check all cases of alpha and beta - // TODO: Static checking of alpha and beta - if (beta == 0) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < FragmentC::kElements; i++) { - accum[i] = 0; - } - } else { - assert(beta == 1); - } 
- - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < kKgroups; ++k) { - iter_A.load(frag_A); - iter_B.load(frag_B); - ++iter_A; - ++iter_B; - mma_op(accum, frag_A, frag_B, accum); + using Instruction = DispatchInstruction; + + using OperandATraits = + OperandTraits::value, M, K, !trans_A>; + using OperandBTraits = + OperandTraits::value, N, K, trans_B>; + using SmemLayoutA = typename OperandATraits::Layout; + using SmemLayoutB = typename OperandBTraits::Layout; + using SmemCopyA = Copy_Atom; + using SmemCopyB = Copy_Atom; + + using TileMma = + TiledMMA, Int, _1>>>; + + static CUTE_DEVICE void body(const A_type *pA, const B_type *pB, C_type *pC, + int lda, int ldb, double alpha, double beta, + int warp_id_m, int warp_id_n, int lane_id) { + int tid = (warp_id_n * num_warp_m + warp_id_m) * 32 + lane_id; + // change the layout!!! + Tensor sA = make_tensor(make_smem_ptr((A_type *)(pA)), SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr((B_type *)(pB)), SmemLayoutB{}); + TileMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tid); + auto tiled_copy_A = make_tiled_copy_A(SmemCopyA{}, tiled_mma); + auto tiled_copy_B = make_tiled_copy_B(SmemCopyB{}, tiled_mma); + auto thr_copy_A = tiled_copy_A.get_thread_slice(tid); + auto thr_copy_B = tiled_copy_B.get_thread_slice(tid); + + Tensor tCrA = thr_mma.partition_fragment_A(sA); + Tensor tCrB = thr_mma.partition_fragment_B(sB); + Tensor tCsA = thr_copy_A.partition_S(sA); + Tensor tCsB = thr_copy_B.partition_S(sB); + + Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); + Tensor tCrB_copy_view = thr_copy_B.retile_D(tCrB); + + Tensor acc = + make_tensor(make_rmem_ptr(reinterpret_cast(pC)), + partition_shape_C(tiled_mma, Shape, Int>{})); + + int num_tile_k = size<2>(tCrA); + CUTE_UNROLL + for (int k = 0; k < num_tile_k; ++k) { + copy(tiled_copy_A, tCsA(_, _, k), tCrA_copy_view(_, _, k)); + copy(tiled_copy_B, tCsB(_, _, k), tCrB_copy_view(_, _, k)); + gemm(tiled_mma, tCrA(_, _, k), tCrB(_, _, k), acc); } } }; @@ -138,12 +92,9 @@ CUTLASS_DEVICE void matmul_thread(const A_type *pA, const B_type *pB, int strideb, int stridec, double alpha, double beta, int warp_id_batch, int warp_id_m, int warp_id_n, int lane_id) { - using MMA = GemmTensorOp, num_warp_m, num_warp_n, - trans_A, trans_B, A_type, B_type, C_type>; - using FragmentC = typename MMA::FragmentC; + using MMA = GemmTensorOp; MMA::body(pA + warp_id_batch * stridea, pB + warp_id_batch * strideb, - *(FragmentC *)(accum /* no thread offset */), lda, ldb, alpha, - beta, warp_id_m, warp_id_n, lane_id); + (accum /* no thread offset */), lda, ldb, alpha, beta, warp_id_m, + warp_id_n, lane_id); } - -#endif // MICRO_KERNEL_MATMUL_CUTLASS_GEMM_SM80_H diff --git a/src/schedule/lower_cutlass_micro_block.cc b/src/schedule/lower_cutlass_micro_block.cc index c54c14189..94d72b343 100644 --- a/src/schedule/lower_cutlass_micro_block.cc +++ b/src/schedule/lower_cutlass_micro_block.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -162,11 +163,11 @@ class LowerCutlassMicroBlock : public SymbolTable { auto batchInWarpPartition = makeEQ(op->indices_[nDimsCAll - 9], prop_->warpIdBatch_); auto mInWarpPartition = - makeEQ(op->indices_[nDimsCAll - 7], prop_->warpIdM_); + makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdM_); auto nInWarpPartition = - makeEQ(op->indices_[nDimsCAll - 4], prop_->warpIdN_); + makeEQ(op->indices_[nDimsCAll - 5], prop_->warpIdN_); auto mInThreadPartition = - makeEQ(op->indices_[nDimsCAll - 5], + makeEQ(op->indices_[nDimsCAll - 3], makeFloorDiv(prop_->laneId_, 
makeIntConst(4))); auto nInThreadPartition = makeEQ(op->indices_[nDimsCAll - 2], @@ -222,10 +223,12 @@ class LowerCutlassMicroBlock : public SymbolTable { ASSERT(nDimsCAll >= 9); // See comments in `lowerCutlassMicroBlock` below c->indices_[nDimsCAll - 9] = warpIdBatch; - c->indices_[nDimsCAll - 7] = warpIdM; - c->indices_[nDimsCAll - 5] = makeFloorDiv(laneId, makeIntConst(4)); - c->indices_[nDimsCAll - 4] = warpIdN; - c->indices_[nDimsCAll - 2] = makeMod(laneId, makeIntConst(4)); + c->indices_[nDimsCAll - 4] = warpIdM; // m warps + c->indices_[nDimsCAll - 3] = + makeFloorDiv(laneId, makeIntConst(4)); // m threads + c->indices_[nDimsCAll - 5] = warpIdN; // n warps + c->indices_[nDimsCAll - 2] = + makeMod(laneId, makeIntConst(4)); // n threads op->backend_ = MatMulBackend::CutlassMicroThread; op->cutlassMicroKernelProperty_ = prop_; @@ -278,13 +281,13 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId, // ...: other leading dims, // -9: batch warps, // -8: batch serial, - // -7: m warps, - // -6: m 8-tiles, - // -5: m threads, - // -4: n warps, - // -3: n 8-tiles, + // -7: n 16-tiles, + // -6: m 16-tiles, + // -5: n warps + // -4: m warps, + // -3: m threads // -2: n threads, - // -1: n 2-tiles + // -1: n 2-tiles, // ] // // See @@ -333,19 +336,31 @@ Stmt lowerCutlassMicroBlock(const Stmt &_ast, const ID &matMulId, } else if (nDimsCBatch == 0) { ast = varUnsqueeze(ast, defIdC, nDimsCOthers); } + // clang-format off ast = varSplit( ast, defIdC, nDimsCOthers + 0, VarSplitMode::FixedSize, -1, nWarpBatch); + ast = varSplit( + ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, 16, -1); ast = varSplit( - ast, defIdC, nDimsCOthers + 2, VarSplitMode::FixedSize, -1, nWarpM); - ast = varSplit( - ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, 8, -1); + ast, defIdC, nDimsCOthers + 3, VarSplitMode::FixedSize, -1, nWarpM); ast = varSplit( - ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, -1, nWarpN); + ast, defIdC, nDimsCOthers + 5, VarSplitMode::FixedSize, 16, -1); ast = varSplit( - ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, 8, -1); + ast, defIdC, nDimsCOthers + 6, VarSplitMode::FixedSize, -1, nWarpN); ast = varSplit( ast, defIdC, nDimsCOthers + 7, VarSplitMode::FixedSize, 2, -1); + std::vector vec; + for(int i=0; i<=nDimsCOthers+1; i++) + vec.push_back(i); + vec.push_back(nDimsCOthers+5); + vec.push_back(nDimsCOthers+2); + vec.push_back(nDimsCOthers+6); + vec.push_back(nDimsCOthers+3); + vec.push_back(nDimsCOthers+4); + vec.push_back(nDimsCOthers+7); + vec.push_back(nDimsCOthers+8); + ast = varReorderImpl(ast, defIdC, vec, true); // clang-format on // Lower to CutlassMicroThread diff --git a/src/schedule/var_reorder.cc b/src/schedule/var_reorder.cc index fa2d18283..2e29c71ec 100644 --- a/src/schedule/var_reorder.cc +++ b/src/schedule/var_reorder.cc @@ -62,16 +62,16 @@ Expr VarReorder::visit(const Load &_op) { } Stmt VarReorder::visit(const MatMul &op) { - if (!var_.empty() && (allReads(op->equivalent_).count(var_) || + if (!var_.empty() && !forceReorderInMatMul_ && (allReads(op->equivalent_).count(var_) || allWrites(op->equivalent_).count(var_))) { throw InvalidSchedule("Please call var_reorder before as_matmul"); } return BaseClass::visit(op); } -Stmt varReorder(const Stmt &_ast, const ID &def, - const std::vector &order) { - VarReorder mutator(def, order); +Stmt varReorderImpl(const Stmt &_ast, const ID &def, + const std::vector &order, bool forceReorderInMatMul) { + VarReorder mutator(def, order, forceReorderInMatMul); auto ast = mutator(_ast); if 
(!mutator.found()) { throw InvalidSchedule(FT_MSG << def << " not found"); } @@ -79,6 +79,11 @@ Stmt varReorder(const Stmt &_ast, return ast; } +Stmt varReorder(const Stmt &ast, const ID &def, + const std::vector<int> &order) { + return varReorderImpl(ast, def, order); +} + void Schedule::varReorder(const ID &def, const std::vector<int> &order) { beginTransaction(); auto log = appendLog(
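
Illustrative sketch (not part of the patch, all names hypothetical): the varSplit/varReorderImpl sequence in lowerCutlassMicroBlock arranges the trailing dimensions of C as [batch warps, batch serial, n 16-tiles, m 16-tiles, n warps, m warps, m threads, n threads, n 2-tiles], and the lowering assigns the thread coordinates from the lane id as laneId / 4 (m) and laneId % 4 (n). The standalone host-side C++ below only illustrates that index mapping for one warp/lane; the real pass builds these indices as AST expressions.

#include <array>
#include <cstdio>

// Trailing dimensions of C after the splits and the reorder (innermost last):
// [batch warps, batch serial, n 16-tiles, m 16-tiles, n warps, m warps,
//  m threads, n threads, n 2-tiles]. Only the warp/lane-partitioned entries
// are filled in here; the tile dimensions remain loop indices in the IR.
std::array<int, 9> cFragmentIndices(int warpIdBatch, int warpIdM, int warpIdN,
                                    int laneId) {
    std::array<int, 9> idx{};
    idx[0] = warpIdBatch; // -9: batch warps
    idx[4] = warpIdN;     // -5: n warps
    idx[5] = warpIdM;     // -4: m warps
    idx[6] = laneId / 4;  // -3: m threads
    idx[7] = laneId % 4;  // -2: n threads
    return idx;
}

int main() {
    auto idx = cFragmentIndices(0, 1, 0, 13);
    std::printf("m thread = %d, n thread = %d\n", idx[6], idx[7]); // prints 3, 1
    return 0;
}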