Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rtmadduri committed Jan 16, 2025
1 parent fb73588 commit 4eea3bd
Showing 1 changed file with 9 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
ComputeTypeA,
ComputeTypeB>;

// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
using CGridDesc_M_N =
remove_cvref_t<decltype(GridwiseGemm::MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))>;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;

using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
// Block2CTileMap configuration parameter.
using Block2ETileMap = typename GridwiseGemm::Block2CTileMap;;

using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

Expand Down Expand Up @@ -289,7 +285,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GridwiseGemm::MakeCGridDescriptor_M_N(M, m_padded, N, n_padded, stride_c);

const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
Block2ETileMap{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

const index_t block_start = grid_size_;
Expand Down Expand Up @@ -339,7 +335,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
karg.M, karg.N, m_padded, n_padded, karg.StrideC);

const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
Block2ETileMap{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

const index_t block_start = grid_size_;
Expand Down Expand Up @@ -444,9 +440,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size(),
PassThrough{},
PassThrough{},
PassThrough{});
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation);
};

constexpr index_t minimum_occupancy = BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
Expand Down

0 comments on commit 4eea3bd

Please sign in to comment.