diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index fb8a4389f3..01461b81af 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -12,9 +12,9 @@ from codegen.cmake_config import * from codegen.cpp_symbol_map import * +import codegen.ops.fmha_fwd from codegen.ops.fmha_fwd import ( FmhaFwdTileSize, - FmhaFwdApiTrait, FMHA_FWD_KERNEL_HEADER, FMHA_FWD_API_PER_DTYPE, FMHA_FWD_API_PER_HDIM_CASE, @@ -48,7 +48,7 @@ using fmha_mask_{F_idx} = {F_mask}; namespace {{ -template +template struct kernel_runner {{ using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>; @@ -68,7 +68,7 @@ {F_lse}, {F_squant}, {F_pagedkv}, - kHasUnevenSplits, + kIsMultipleSplits && kHasUnevenSplits, {F_occupancy}>; using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem< @@ -81,7 +81,11 @@ typename FmhaFwdTypeConfig::LSEDataType, typename FmhaFwdTypeConfig::PDataType, typename FmhaFwdTypeConfig::OaccDataType, - typename FmhaFwdTypeConfig::OaccDataType, + std::conditional_t< + kIsMultipleSplits, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType + >, fmha_shape, {F_mode}, fmha_mask_{F_idx}, @@ -90,10 +94,17 @@ using fmha_pipeline = {F_pipeline}< fmha_pipeline_problem>; +/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving +/// store_tile_raw() data corruption issue using fmha_epilogue = - ck_tile::Default2DEpilogue::OaccDataType, - typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, - {F_spad}, {F_dvpad}>>; + ck_tile::Default2DEpilogue::OaccDataType, + std::conditional_t< + kIsMultipleSplits, + typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType + >, + false, false>>; using fmha_kernel = ck_tile::FmhaFwdSplitKVKernel; @@ -118,25 +129,30 @@ template<> void 
fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if constexpr({F_mode} == false) {{ // batch mode - // we don't check every seqlen_k values for kvcache - if (a.seqlen_k_ptr != nullptr) {{ - kernel_runner::run(s, a); - // make sure F_bn0 is divisible by F_bk1 - }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ - kernel_runner::run(s, a); - }} else {{ - kernel_runner::run(s, a); + if (1 < a.num_splits) {{ + constexpr bool kIsMultipleSplits = true; + if constexpr({F_mode} == false) {{ // batch mode + // we don't check every seqlen_k values for kvcache + if (a.seqlen_k_ptr != nullptr) {{ + kernel_runner::run(s, a); + // make sure F_bn0 is divisible by F_bk1 + }} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{ + kernel_runner::run(s, a); + }} else {{ + kernel_runner::run(s, a); + }} + }} else {{ // group mode + kernel_runner::run(s, a); }} }} else {{ - kernel_runner::run(s, a); + kernel_runner::run(s, a); }} }} template<> std::string fmha_fwd_splitkv_get_name_() {{ - using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type + using k_ = kernel_runner::fmha_kernel; /// FIXME: choose real kernel type return k_::GetName(); }} """ @@ -220,19 +236,32 @@ FMHA_FWD_SPLITKV_API=""" #include -template +template float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{ - if(s.log_level_ > 0) - std::cout - << ", " << fmha_fwd_splitkv_get_name_() - << ", " << fmha_fwd_splitkv_combine_get_name_() - << std::flush; - - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} - ); + // fmha_fwd_splitkv_combine_traits_=void, launch splitkv kernel only + if constexpr (std::is_same_v) {{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& 
s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }} + ); + // launch both splitkv & combine kernels + }} else {{ + if(s.log_level_ > 0) + std::cout + << ", " << fmha_fwd_splitkv_get_name_() + << ", " << fmha_fwd_splitkv_combine_get_name_() + << std::flush; + + return ck_tile::launch_kernel(s, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + ); + }} }} float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{ @@ -244,28 +273,45 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && ((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ - using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - - // get combine kernel tile sizes - using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; - constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; - - // make sure we can reuse the padding flags in combine kernels - static_assert({F_bm0} % kM0 == 0); - static_assert({F_bn1} % 32 == 0); - - if (t.has_lse) {{ - if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ - return -1; + + if (1 < a.num_splits) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + + // get combine kernel tile sizes + using OaccDataType = typename 
FmhaFwdTypeConfig<{F_dtype}>::OaccDataType; + constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes::kM0; + + // make sure we can reuse the padding flags in combine kernels + static_assert({F_bm0} % kM0 == 0); + static_assert({F_bn1} % 32 == 0); + + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + return -1; + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>; + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; return fmha_fwd_splitkv_(s, a); }} }} else {{ - using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>; + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + return -1; + }} else {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + + return fmha_fwd_splitkv_(s, a); + }} + }} else {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, false, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; - return fmha_fwd_splitkv_(s, a); + return fmha_fwd_splitkv_(s, a); + }} }} }} """ @@ -605,7 +651,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : 
FmhaFwdSplitKVCombineTileSize(32, -1), - } + } elif dtype == 'fp8' or dtype == 'bf8': return { '64' : FmhaFwdSplitKVCombineTileSize(32, -1), @@ -629,26 +675,28 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): + for mask, bias, lse, pagedkv in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): # TODO: use async pipeline when compiler is more stable if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: - # if True: - pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + + pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) else: - pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) - pipelines.append(Pipeline('qr_async', 
'col', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) + pipelines.append(Pipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) if receipt == 1: - pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim - pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, 't', squant, pagedkv, mask)) # TODO: cover arbitraty hdim + pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitrary hdim + pipelines.append(Pipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, pagedkv, mask)) # TODO: cover arbitrary hdim elif dtype in ['fp8', 'bf8']: - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 't', squant, 'f', mask)) + for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ['t', 'f']): + pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, 'f', mask)) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None @@ -660,18 +708,27 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: api_pool = FmhaFwdSplitKVApiPool(mask_impl) for dtype in FWD_DTYPE_MAP.keys(): - d = get_fmha_fwd_tile_dict_from_dtype(dtype) - if d == None: + prefill_tiles = codegen.ops.fmha_fwd.get_fmha_fwd_tile_dict_from_dtype(dtype) + decode_tiles = get_fmha_fwd_tile_dict_from_dtype(dtype) + if decode_tiles == None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in
itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] + + # make sure all the hdim str keys in decode_tiles are also available in prefill_tiles + assert all(tile in prefill_tiles.keys() for tile in decode_tiles.keys()) + + for hdim_str, mode in itertools.product(decode_tiles.keys(), MODE_MAP.keys()): + prefill_tile = prefill_tiles[hdim_str] + decode_tile = decode_tiles[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue + + is_prefill = (dtype in ['fp16', 'bf16'] and mode == "group" and pipeline.F_pagedkv == 't') + tile = prefill_tile if is_prefill else decode_tile + k = Kernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -683,8 +740,11 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: if not fnmatch.fnmatch(k.name, kernel_filter): continue if receipt == 2: + is_chunked_prefill = (mode == 'group' and pipeline.F_pagedkv == 't') + cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' + # use vlayout=col for chunked prefill, vlayout=row otherwise + cond = cond and ((pipeline.F_vlayout == 'row' and not is_chunked_prefill) or (pipeline.F_vlayout == 'col' and is_chunked_prefill)) cond &= pipeline.F_bias in ['no', 'alibi'] cond &= pipeline.F_squant == 'f' if not cond: @@ -723,12 +783,11 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]: d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype) if d == None: continue - #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): tile = d[hdim_str] hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): - if mode == "group": + if mode == 'group': if pipeline.F_spad != 't': # in group mode, spad/skpad must be true, since we
can't predict if seqlen of current batch need pad or not continue diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index b3855e59df..a595188b92 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -176,61 +177,14 @@ auto get_elimit(std::string init_method) } } -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) -{ - // If we have enough to almost fill the SMs, then just use 1 split - if(batch_nhead_mblocks >= 0.8f * num_SMs) - { - return 1; - } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); - float max_efficiency = 0.f; - std::vector efficiency; - efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. 
- auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || - ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; - for(int num_splits = 1; num_splits <= max_splits; num_splits++) - { - if(!is_split_eligible(num_splits)) - { - efficiency.push_back(0.f); - } - else - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); - } - } - for(int num_splits = 1; num_splits <= max_splits; num_splits++) - { - if(!is_split_eligible(num_splits)) - { - continue; - } - if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) - { - // printf("num_splits chosen = %d\n", num_splits); - return num_splits; - } - } - return 1; -} - -int override_num_splits_if_necessary( - int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) +int override_num_splits_if_necessary(int batch, + int nhead, + int max_seqlen_q, + int hdim_q, + int hdim_v, + float p_drop, + bool is_prefill, + int num_splits) { int device; auto status = hipGetDevice(&device); @@ -246,17 +200,41 @@ int override_num_splits_if_necessary( return num_splits; } - // tile size should match the generate.py - const int kM0 = 64; - const int kN1 = hdim_v; + const int kM0 = [&] { + // get kM0 for prefill phase + if(is_prefill) + { + return 128; + } + + // get kM0 for decode phase + /// TODO: take dtype=fp8/bf8 into consideration + const std::map hdim_to_m0 = { + {32, 32}, + {64, 64}, + // {96, 64}, + {128, 64}, + {256, 64}, + }; + + for(auto [hdim, m0] : hdim_to_m0) + { + if(hdim_q <= hdim && hdim_v <= hdim) + { + return m0; + } + } + + return 64; // meet unsupported hdim_q/hdim_v + }(); + // const int kN1 = hdim_v; const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - const int num_n_blocks = 
ck_tile::integer_divide_ceil(hdim_v, kN1); + // const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1 if(num_splits < 1 && p_drop == 0.0f) { - return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + return num_splits_heuristic(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8); } return num_splits; @@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser) // legalize num_splits according to other options if(num_splits < 1) { - num_splits = override_num_splits_if_necessary( - batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits); + num_splits = override_num_splits_if_necessary(batch, + nhead, + max_seqlen_q, + hdim_q, + hdim_v, + p_drop, + /*is_prefill=*/mode == mode_enum::group && + 0 < page_block_size, + num_splits); } if(128 < num_splits) { @@ -632,17 +617,18 @@ bool run(const ck_tile::ArgParser& arg_parser) auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); + // lse_acc_host & o_acc_host are only used when 1 < num_splits ck_tile::HostTensor lse_acc_host( - 1 < num_splits || use_kvcache + 1 < num_splits ? std::array{shape_batch, nhead, num_splits, shape_seqlen_q} : std::array{1, 1, 1, 1}); ck_tile::HostTensor o_acc_host( - 1 < num_splits || use_kvcache ? std::array{shape_batch, - nhead, - num_splits, - shape_seqlen_q, - hdim_v} - : std::array{1, 1, 1, 1, 1}); + 1 < num_splits ?
std::array{shape_batch, + nhead, + num_splits, + shape_seqlen_q, + hdim_v} + : std::array{1, 1, 1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } else if constexpr(std::is_same_v>) { - args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); - args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); - + // lse_acc_buf & o_acc_buf are only used when 1 < num_splits args.block_table_ptr = (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr); args.batch_stride_block_table = batch_stride_block_table; @@ -1057,13 +1041,33 @@ bool run(const ck_tile::ArgParser& arg_parser) args.num_splits = num_splits; - args.stride_o_acc = stride_o_acc; - args.nhead_stride_lse_acc = nhead_stride_lse_acc; - args.nhead_stride_o_acc = nhead_stride_o_acc; - args.batch_stride_lse_acc = batch_stride_lse_acc; - args.batch_stride_o_acc = batch_stride_o_acc; - args.split_stride_lse_acc = split_stride_lse_acc; - args.split_stride_o_acc = split_stride_o_acc; + if(1 < num_splits) + { + args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer(); + args.o_acc_ptr = o_acc_buf.GetDeviceBuffer(); + + args.stride_o_acc = stride_o_acc; + args.nhead_stride_lse_acc = nhead_stride_lse_acc; + args.nhead_stride_o_acc = nhead_stride_o_acc; + args.batch_stride_lse_acc = batch_stride_lse_acc; + args.batch_stride_o_acc = batch_stride_o_acc; + args.split_stride_lse_acc = split_stride_lse_acc; + args.split_stride_o_acc = split_stride_o_acc; + } + else + { + // following attributes are ignored by fmha_fwd_splitkv() + args.lse_acc_ptr = nullptr; + args.o_acc_ptr = nullptr; + + args.stride_o_acc = 0; + args.nhead_stride_lse_acc = 0; + args.nhead_stride_o_acc = 0; + args.batch_stride_lse_acc = 0; + args.batch_stride_o_acc = 0; + args.split_stride_lse_acc = 0; + args.split_stride_o_acc = 0; + } } } }; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp
b/example/ck_tile/01_fmha/fmha_fwd.hpp index 0368de352f..38fee3384c 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -12,6 +12,7 @@ #include "mask.hpp" #include "rotary.hpp" +#include #include #include #include @@ -422,91 +423,93 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args) // create group mode kernel arguments if constexpr(Kernel::kIsGroupMode) { - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.is_gappy, - args.scale_s, - args.scale_p, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_o_acc, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.batch_stride_k, // only used for paged-kvcache - args.batch_stride_v, // only used for paged-kvcache - args.split_stride_lse_acc, - args.split_stride_o_acc, - args.window_size_left, - args.window_size_right, - args.mask_type); + return Kernel::MakeKargs( + args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + (1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr), + (1 < args.num_splits ? args.o_acc_ptr : args.o_ptr), + args.batch, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.is_gappy, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + (1 < args.num_splits ? 
args.stride_o_acc : args.stride_o), + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + (1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse), + (1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o), + args.batch_stride_k, // only used for paged-kvcache + args.batch_stride_v, // only used for paged-kvcache + (1 < args.num_splits ? args.split_stride_lse_acc : 0), + (1 < args.num_splits ? args.split_stride_o_acc : 0), + args.window_size_left, + args.window_size_right, + args.mask_type); } else { // create batch mode kernel arguments - return Kernel::MakeKargs(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_acc_ptr, - args.o_acc_ptr, - args.batch, - args.seqlen_q, - args.seqlen_k, - args.seqlen_k_ptr, - args.hdim_q, - args.hdim_v, - args.nhead_q, - args.nhead_q / args.nhead_k, - args.num_splits, - args.block_table_ptr, - args.batch_stride_block_table, - args.page_block_size, - args.cache_batch_idx, - args.scale_s, - args.scale_p, - args.stride_q, - args.stride_k, - args.stride_v, - args.stride_bias, - args.stride_o_acc, - args.nhead_stride_q, - args.nhead_stride_k, - args.nhead_stride_v, - args.nhead_stride_bias, - args.nhead_stride_lse_acc, - args.nhead_stride_o_acc, - args.batch_stride_q, - args.batch_stride_k, - args.batch_stride_v, - args.batch_stride_bias, - args.batch_stride_lse_acc, - args.batch_stride_o_acc, - args.split_stride_lse_acc, - args.split_stride_o_acc, - args.window_size_left, - args.window_size_right, - args.mask_type); + return Kernel::MakeKargs( + args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + (1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr), + (1 < args.num_splits ? 
args.o_acc_ptr : args.o_ptr), + args.batch, + args.seqlen_q, + args.seqlen_k, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q, + args.nhead_q / args.nhead_k, + args.num_splits, + args.block_table_ptr, + args.batch_stride_block_table, + args.page_block_size, + args.cache_batch_idx, + args.scale_s, + args.scale_p, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + (1 < args.num_splits ? args.stride_o_acc : args.stride_o), + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + (1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse), + (1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o), + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + (1 < args.num_splits ? args.batch_stride_lse_acc : args.batch_stride_lse), + (1 < args.num_splits ? args.batch_stride_o_acc : args.batch_stride_o), + (1 < args.num_splits ? args.split_stride_lse_acc : 0), + (1 < args.num_splits ? 
args.split_stride_o_acc : 0), + args.window_size_left, + args.window_size_right, + args.mask_type); } }(); @@ -821,3 +824,40 @@ struct fmha_fwd_appendkv_traits float fmha_fwd_appendkv(fmha_fwd_appendkv_traits, fmha_fwd_appendkv_args, const ck_tile::stream_config&); + +template +Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits) +{ + // If we have enough to almost fill the SMs, then just use 1 split + if(batch_nhead_mblocks >= 0.8f * num_SMs) + { + return 1; + } + + max_splits = std::min({max_splits, num_SMs}); + + constexpr std::array num_splits_array = {1, 2, 4, 8, 16}; + + float max_efficiency = 0.f; + std::array efficiency; + + for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) + { + float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs; + float eff = n_blocks / std::ceil(n_blocks); + + if(eff > max_efficiency) + { + max_efficiency = eff; + } + efficiency[idx] = eff; + } + for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx) + { + if(efficiency[idx] >= 0.85 * max_efficiency) + { + return num_splits_array[idx]; + } + } + return 1; +} diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 41f3383c7f..02ce449912 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 2f3a302eea..440b306705 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index d06d8529ac..8b5302257c 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 1510f18a30..9b9bf30ad3 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index cd1e43fb8c..15fa269740 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index c24744bdbc..95ead2645e 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index ba76e3070d..616db2fa5b 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index d5920f4837..4cbb59e95b 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp index 5d158f9fb3..6447371e4f 100644 --- a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp @@ -144,11 +144,15 @@ struct PageBlockNavigator const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin); const index_t new_block_index = get_block_index(global_window_origin); - /// TODO: only update necessary attributes - tile_window.bottom_tensor_view_.desc_ = - (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor(); - tile_window.set_window_origin(local_window_origin); - tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + if(block_index != new_block_index) + { + /// TODO: only update necessary attributes + tile_window.bottom_tensor_view_.desc_ = + (is_last_block(new_block_index) ? 
last_view : complete_view) + .get_tensor_descriptor(); + tile_window.set_window_origin(local_window_origin); + tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index)); + } return new_block_index; } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index 10ab25119b..acf42fb263 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -60,7 +60,7 @@ struct FmhaFwdSplitKVKernel template <> struct t2s { static constexpr const char * name = "bf8"; }; // clang-format on - __host__ static std::string GetName() + CK_TILE_HOST static std::string GetName() { // sync with generate.py // clang-format off @@ -237,7 +237,7 @@ struct FmhaFwdSplitKVKernel using Kargs = std::conditional_t; template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -361,7 +361,7 @@ struct FmhaFwdSplitKVKernel } template - __host__ static constexpr std::enable_if_t + CK_TILE_HOST static constexpr std::enable_if_t MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, @@ -482,10 +482,20 @@ struct FmhaFwdSplitKVKernel ck_tile::index_t num_splits) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, - nhead, - batch_size); + if constexpr(kIsGroupMode) + { + return dim3(nhead, + batch_size, + ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits); + } + else + { + return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits, + nhead, + batch_size); + } } CK_TILE_DEVICE static constexpr auto 
GetTileIndex(const Kargs& kargs) @@ -498,15 +508,27 @@ struct FmhaFwdSplitKVKernel return ck_tile::make_tuple(quotient, modulus); }; - const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits); - const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; + if constexpr(kIsGroupMode) + { + const auto [mn, i_split] = f(blockIdx.z, kargs.num_splits); + const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); + const index_t i_nhead = blockIdx.x; + const index_t i_batch = blockIdx.y; + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } + else + { + const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits); + const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1); + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch); + } } - __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp index 04aa85644d..b5ff1f3447 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp @@ -272,10 +272,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS aligned_physical_seqlen_k_start)}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window( - v_dram_block_window_lengths, - {0, aligned_physical_seqlen_k_start}, // TODO: hdim split? 
- Policy::template MakeVDramTileDistribution()); + auto [i_page_block_v, v_dram_block_window] = v_page_block_navigator.make_tile_window( + v_dram_block_window_lengths, {0, aligned_physical_seqlen_k_start}); auto q_tile = tile_elementwise_in(q_element_func, q); @@ -289,10 +287,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS do { // STAGE 1, QK gemm + // K DRAM tile window for load auto k_dram_window = make_tile_window( - k_dram_block_window, - Policy::template MakeKDramTileDistribution()); // K DRAM tile window for - // load + k_dram_block_window, Policy::template MakeKDramTileDistribution()); auto k_block_tile = load_tile(k_dram_window); { @@ -334,6 +331,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS k_block_tile = load_tile(k_dram_window); // global read i + 2 }); } + // V DRAM tile window for load + auto v_dram_window = make_tile_window( + v_dram_block_window, Policy::template MakeVDramTileDistribution()); const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile { // tail @@ -402,27 +402,31 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } move_tile_window(bias_dram_window, {0, kN0}); - /// TODO: only check in first/last iteration without increasing code size + // only check in first/last iterations if constexpr(kHasUnevenSplits) { - const auto k_origin = k_page_block_navigator.to_global_window_origin( - i_page_block_k, k_dram_block_window.get_window_origin()); - set_tile_if( - s_acc, - -numeric::infinity(), - [&, - physical_seqlen_k_start_ = physical_seqlen_k_start, - physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { - const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - if constexpr(kIsPagedKV) - { - return col < physical_seqlen_k_start_ || physical_seqlen_k_end_ <= col; - } - else - { - return physical_seqlen_k_end_ <= col; - } - }); + if(1 < num_splits && (i_total_loops == 0 || i_total_loops == num_total_loop - 1)) + { + const auto k_origin = k_page_block_navigator.to_global_window_origin( + i_page_block_k, 
k_dram_block_window.get_window_origin()); + set_tile_if(s_acc, + -numeric::infinity(), + [&, + physical_seqlen_k_start_ = physical_seqlen_k_start, + physical_seqlen_k_end_ = physical_seqlen_k_end](auto tile_idx) { + const auto col = + k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + if constexpr(kIsPagedKV) + { + return col < physical_seqlen_k_start_ || + physical_seqlen_k_end_ <= col; + } + else + { + return physical_seqlen_k_end_ <= col; + } + }); + } } if constexpr(kPadSeqLenK || FmhaMask::IsMasking) @@ -445,6 +449,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS } } + __builtin_amdgcn_sched_barrier(0); + // move K tile window + i_page_block_k = k_page_block_navigator.move_tile_window( + i_page_block_k, k_dram_block_window, {kN0, 0}); + __builtin_amdgcn_sched_barrier(0); + const auto s = cast_tile(s_acc); // S{j} auto m_local = block_tile_reduce( s, @@ -457,6 +467,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS tile_elementwise_inout( [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j} + __builtin_amdgcn_sched_barrier(0); + // move V tile window + i_page_block_v = v_page_block_navigator.move_tile_window( + i_page_block_v, v_dram_block_window, {0, kN0}); + __builtin_amdgcn_sched_barrier(0); + auto p_compute = make_static_distributed_tensor( s.get_tile_distribution()); // Pcompute{j} @@ -549,8 +565,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch } - i_page_block_v = - v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1}); + // moving v_dram_window is an in-page-block operation, so there is + // no need to invoke v_page_block_navigator.move_tile_window() here. 
+ move_tile_window(v_dram_window, {0, kK1}); const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); @@ -582,13 +599,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS store_tile(v_lds_window, tile_elementwise_in(v_element_func, v)); // store next v } - i_page_block_v_ = v_page_block_navigator.move_tile_window( - i_page_block_v_, v_dram_window_, {0, kK1}); + move_tile_window(v_dram_window, {0, kK1}); }); } - // move K tile windows - i_page_block_k = k_page_block_navigator.move_tile_window( - i_page_block_k, k_dram_block_window, {kN0, 0}); + // tail { block_sync_lds(); diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index d23af0af8d..d2d328fc46 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2d38ef5925..5bbe0601b7 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 2b02bcc5d2..d54b7f60d6 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 711c5d8595..47d986e1c2 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 02d8eabd8a..9392f8b439 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 990e9ecc03..f3abe84e46 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index aa617ee2b4..b817d09c72 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 8d075dc5fa..f75f05140a 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 24a59b45b0..3fe1b5b213 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 4df34e1e0d..391609622a 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index fcae3e02dc..40b9edd72f 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index cc7dbffee4..efc1d17637 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once