Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Reverted the change of cD to rw_coord in consumer store args #180

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/cutlass/epilogue/collective/xe_epilogue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ class CollectiveEpilogue<
tile_coord_mnkl,
tiled_mma,
SubgroupTileShape{}, // Epilogue tile
params.xe_load_c,
rw_coord,
params.xe_store_d,
cD,
residue_mn,
cD,
residue_mn,
Expand Down
11 changes: 6 additions & 5 deletions include/cutlass/epilogue/fusion/xe_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ struct FusionCallbacks<
/////////////////////////////////////////////////////////////////////////////////////////////////

template<
class CtaTileShapeMNK,
class StrideAux,
class CopyOpG2R,
template <class> class ActivationFn,
Expand All @@ -184,7 +185,7 @@ template<
using XeLinCombDeEltAct =
Sm90EVT<Sm90Compute<ActivationFn, ElementOutput, ElementCompute, RoundStyle>, // activation(beta * C + (alpha * acc), aux)
Sm90LinearCombination<ElementCompute, ElementCompute, ElementSource, ElementScalar, RoundStyle>, // beta * C + (alpha * acc)
XeAuxLoad<ElementAux, StrideAux, CopyOpG2R> // aux
XeAuxLoad<CtaTileShapeMNK, ElementAux, StrideAux, CopyOpG2R> // aux
>;

// Z = Aux
Expand Down Expand Up @@ -215,17 +216,17 @@ struct FusionCallbacks<
EpilogueTile,
CopyOpG2R
> : XeLinCombDeEltAct<
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput_,
ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput_, ElementCompute_, ElementAux, ElementSource, ElementScalar, RoundStyle
> {

using ElementOutput = ElementOutput_;
using ElementCompute = ElementCompute_;

using Impl =
XeLinCombDeEltAct<
cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn, ElementOutput,
ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
CtaTileShapeMNK, cutlass::gemm::TagToStrideC_t<GmemLayoutTagAux>, CopyOpG2R, ActivationFn,
ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, RoundStyle
>;
using Operation =
fusion::LinCombDeEltAct<
Expand Down
27 changes: 26 additions & 1 deletion include/cutlass/epilogue/fusion/xe_visitor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ using namespace cutlass::epilogue::fusion;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
class CtaTileShapeMNK,
class Element,
class StrideMNL,
class CopyOpG2R,
Expand Down Expand Up @@ -190,9 +191,33 @@ struct XeAuxLoad {
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto xe_copy_aux = params_ptr->xe_load_aux;
Tensor rw_coord = args.cD;
Tensor trAux = make_tensor_like<Element>(args.tCrC);

using TiledMma = decltype(args.tiled_mma);
using MmaAtomShape = typename TiledMma::AtomShape_MNK;

static constexpr auto BLK_M = get<0>(CtaTileShapeMNK{});
static constexpr auto BLK_N = get<1>(CtaTileShapeMNK{});

static constexpr auto ATOM_M = get<1>(typename TiledMma::ThrLayoutVMNK{}.shape());
static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());

static constexpr auto SG_M = BLK_M / ATOM_M;
static constexpr auto SG_N = BLK_N / ATOM_N;

static constexpr int FragsM = SG_M / get<0>(MmaAtomShape()); // A frags per sub_group
static constexpr int FragsN = SG_N / get<1>(MmaAtomShape()); // B frags per sub_group

auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl;
auto m_offset = m_coord * BLK_M + (get_sub_group_id() / ATOM_N) * SG_M;
auto n_offset = n_coord * BLK_N + (get_sub_group_id() % ATOM_N) * SG_N;
Tensor tOuti = args.tiled_copy.get_pvc_tensor(
make_coord(m_offset, n_offset, 0),
make_shape(_, Int<FragsM>{}, Int<FragsN>{}, L),
make_stride(Int<get<0>(MmaAtomShape{})>{}, Int<get<1>(MmaAtomShape{})>{}, _1{}));
Tensor rw_coord = tOuti(_,_,_,l_coord);

return ConsumerStoreCallbacks(
rw_coord, xe_copy_aux, cute::move(trAux), params_ptr
);
Expand Down
Loading