Skip to content

Commit

Permalink
fix modify
Browse files Browse the repository at this point in the history
  • Loading branch information
zeroRains committed Jan 11, 2025
1 parent 60b4ffb commit 7a0f294
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 25 deletions.
3 changes: 2 additions & 1 deletion paddle/phi/kernels/fusion/gpu/block_attn.h
Original file line number Diff line number Diff line change
Expand Up @@ -1601,8 +1601,9 @@ void dispatch_blha_impl_headsize(const phi::GPUContext &dev_ctx,
params, dev_ctx.stream(), load_func, store_func, use_cachekv_int8);
break;
case 96:
dispatch_blha_impl_blocksize<T, 96, 96>(
dispatch_blha_impl_blocksize<T, 96, 128>(
params, dev_ctx.stream(), load_func, store_func, use_cachekv_int8);
break;
case 128:
dispatch_blha_impl_blocksize<T, 128, 128>(
params, dev_ctx.stream(), load_func, store_func, use_cachekv_int8);
Expand Down
24 changes: 0 additions & 24 deletions paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,10 +301,6 @@ struct Qk_vec_<float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_<float, 96> {
using Type = float4;
};
template <>
struct Qk_vec_<float, 128> {
using Type = float4;
};
Expand All @@ -321,10 +317,6 @@ struct Qk_vec_<float16, 64> {
using Type = uint32_t;
};
template <>
struct Qk_vec_<float16, 96> {
using Type = uint2;
};
template <>
struct Qk_vec_<float16, 128> {
using Type = uint2;
};
Expand All @@ -342,10 +334,6 @@ struct Qk_vec_<bfloat16, 64> {
using Type = __nv_bfloat162;
};
template <>
struct Qk_vec_<bfloat16, 96> {
using Type = bf16_4_t;
};
template <>
struct Qk_vec_<bfloat16, 128> {
using Type = bf16_4_t;
};
Expand All @@ -367,10 +355,6 @@ struct Qk_vec_RoPE_<float16, float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<float16, float, 96> {
using Type = float4;
};
template <>
struct Qk_vec_RoPE_<float16, float, 128> {
using Type = float4;
};
Expand All @@ -387,10 +371,6 @@ struct Qk_vec_RoPE_<float, float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<float, float, 96> {
using Type = float4;
};
template <>
struct Qk_vec_RoPE_<float, float, 128> {
using Type = float4;
};
Expand All @@ -408,10 +388,6 @@ struct Qk_vec_RoPE_<bfloat16, float, 64> {
using Type = float2;
};
template <>
struct Qk_vec_RoPE_<bfloat16, float, 96> {
using Type = float4;
};
template <>
struct Qk_vec_RoPE_<bfloat16, float, 128> {
using Type = float4;
};
Expand Down

0 comments on commit 7a0f294

Please sign in to comment.