[CK_TILE] Sync fmha fwd splitkv minor optimizations #1785

Open
wants to merge 8 commits into base: develop
195 changes: 127 additions & 68 deletions example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py

Large diffs are not rendered by default.

164 changes: 84 additions & 80 deletions example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -11,6 +11,7 @@
#include <array>
#include <cstring>
#include <functional>
#include <map>
#include <numeric>
#include <ostream>
#include <string>
@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
}
}

int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
}
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
}
}
return 1;
}

int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
int override_num_splits_if_necessary(int batch,
int nhead,
int max_seqlen_q,
int hdim_q,
int hdim_v,
float p_drop,
bool is_prefill,
int num_splits)
{
int device;
auto status = hipGetDevice(&device);
@@ -246,17 +200,41 @@ int override_num_splits_if_necessary(
return num_splits;
}

// tile size should match the generate.py
const int kM0 = 64;
const int kN1 = hdim_v;
const int kM0 = [&] {
// get kM0 for prefill phase
if(is_prefill)
{
return 128;
}

// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
{64, 64},
// {96, 64},
{128, 64},
{256, 64},
};
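// Note: std::map iterates its keys in ascending order, so the loop below returns
// the m0 of the smallest listed hdim that covers both hdim_q and hdim_v; e.g. for
// hdim_q = hdim_v = 96 the 32 and 64 entries are skipped and the 128 entry wins (kM0 = 64).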

for(auto [hdim, m0] : hdim_to_m0)
{
if(hdim_q <= hdim && hdim_v <= hdim)
{
return m0;
}
}

return 64; // fallback for unsupported hdim_q/hdim_v combinations
}();
// const int kN1 = hdim_v;

const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
// const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1

if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
}

return num_splits;
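Since kN1 == hdim_v makes num_n_blocks always 1, the relocated heuristic only needs the grid size, the SM count, and the split cap (now 8). Below is a minimal sketch of such a 3-argument variant, assuming it keeps the wave-efficiency criterion of the removed num_splits_heuristic shown above; the actual implementation referenced by the call is defined elsewhere and may differ.

// Sketch only (not the PR's implementation): same "fill the SMs with whole waves"
// idea as the removed num_splits_heuristic, with the num_n_blocks/eligibility logic
// dropped because num_n_blocks is always 1 here.
// Assumes <algorithm>, <cmath> and <vector> are available, as in fmha_fwd.cpp.
int num_splits_heuristic_sketch(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
    // enough blocks to nearly fill the SMs without splitting: keep a single split
    if(batch_nhead_mblocks >= 0.8f * num_SMs)
    {
        return 1;
    }
    max_splits = std::min(max_splits, num_SMs);
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    float max_efficiency = 0.f;
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        // fraction of the last wave that is actually occupied
        const float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
        const float eff     = n_waves / std::ceil(n_waves);
        max_efficiency      = std::max(max_efficiency, eff);
        efficiency.push_back(eff);
    }
    // pick the smallest split count whose efficiency is within 85% of the best
    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
    {
        if(efficiency[num_splits - 1] >= 0.85f * max_efficiency)
        {
            return num_splits;
        }
    }
    return 1;
}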
@@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
// legalize num_splits according to other options
if(num_splits < 1)
{
num_splits = override_num_splits_if_necessary(
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
num_splits = override_num_splits_if_necessary(batch,
nhead,
max_seqlen_q,
hdim_q,
hdim_v,
p_drop,
/*is_prefill=*/mode == mode_enum::group &&
0 < page_block_size,
num_splits);
}
if(128 < num_splits)
{
@@ -632,17 +617,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);

// lse_acc_host & o_acc_host are only used when 1 < num_splits
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || use_kvcache
1 < num_splits
? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
nhead,
num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
1 < num_splits ? std::array<ck_tile::index_t, 5>{shape_batch,
nhead,
num_splits,
shape_seqlen_q,
hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
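// with num_splits == 1, tiny 1-element dummy tensors are allocated so the
// corresponding device buffers can still be created without extra branching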

// batch mode of lse data layout is [batch, nhead, seqlen_q]
// group mode of lse data layout is [nhead, total_seqlen_q]
@@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();

// lse_acc_buf & o_acc_buf are only used when 1 < num_splits
args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
Expand All @@ -1057,13 +1041,33 @@ bool run(const ck_tile::ArgParser& arg_parser)

args.num_splits = num_splits;

args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
if(1 < num_splits)
{
args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
args.o_acc_ptr = o_acc_buf.GetDeviceBuffer();

args.stride_o_acc = stride_o_acc;
args.nhead_stride_lse_acc = nhead_stride_lse_acc;
args.nhead_stride_o_acc = nhead_stride_o_acc;
args.batch_stride_lse_acc = batch_stride_lse_acc;
args.batch_stride_o_acc = batch_stride_o_acc;
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
}
else
{
// the following attributes are ignored by fmha_fwd_splitkv()
args.lse_acc_ptr = nullptr;
args.o_acc_ptr = nullptr;

args.stride_o_acc = 0;
args.nhead_stride_lse_acc = 0;
args.nhead_stride_o_acc = 0;
args.batch_stride_lse_acc = 0;
args.batch_stride_o_acc = 0;
args.split_stride_lse_acc = 0;
args.split_stride_o_acc = 0;
}
}
}
};