[Wave] Make num_seqs dynamic for Extend Attention (#476)
To avoid recompilation when the number of sequences changes, we modify the kernel to handle a dynamic num_seqs.

Signed-off-by: Stanley Winata <[email protected]>
raikonenfnu authored Feb 10, 2025
1 parent f0b35e0 commit 4214677
Showing 2 changed files with 25 additions and 12 deletions.
iree/turbine/kernel/wave/templates/extend_attention.py (16 additions, 12 deletions)
@@ -74,11 +74,7 @@ def get_extend_attention_kernel(
layer_scaling = (layer_scaling or dk_sqrt) * LOG2E

constraints: list[tkw.Constraint] = []
- constraints += [
-     tkw.WorkgroupConstraint(
-         N_Q, BLOCK_N_Q, 0, iters=math.ceil(shape.max_seq_len / SEQ_TILE_SIZE)
-     )
- ]
+ constraints += [tkw.WorkgroupConstraint(N_Q, BLOCK_N_Q, 0)]
constraints += [tkw.WorkgroupConstraint(D_KV, BLOCK_D_KV, 1)]
constraints += [tkw.WorkgroupConstraint(H, BLOCK_H, 2)]
constraints += [tkw.WorkgroupConstraint(H_KV, BLOCK_H, 2, primary=False)]
@@ -162,6 +158,7 @@ def get_extend_attention_kernel(
block_table_layout = tkl.MemoryLayout(shape=block_table_shape)
k_cache_layout = tkl.MemoryLayout(shape=k_cache_shape)
v_cache_layout = tkl.MemoryLayout(shape=v_cache_shape)
+ num_seqs_layout = tkl.MemoryLayout(shape=[None])

@tkw.wave(constraints)
def extend_attention(
@@ -177,10 +174,18 @@ def extend_attention(
block_table: tkl.Memory[
S, N_KV, GLOBAL_ADDRESS_SPACE, tkl.i32, block_table_layout
],
- request_indices: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, wave_size_dtype],
- sequence_lengths: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, wave_size_dtype],
- sequence_lengths_extend: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32],
- start_indices_extend: tkl.Memory[S, GLOBAL_ADDRESS_SPACE, tkl.i32],
+ request_indices: tkl.Memory[
+     S, GLOBAL_ADDRESS_SPACE, wave_size_dtype, num_seqs_layout
+ ],
+ sequence_lengths: tkl.Memory[
+     S, GLOBAL_ADDRESS_SPACE, wave_size_dtype, num_seqs_layout
+ ],
+ sequence_lengths_extend: tkl.Memory[
+     S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout
+ ],
+ start_indices_extend: tkl.Memory[
+     S, GLOBAL_ADDRESS_SPACE, tkl.i32, num_seqs_layout
+ ],
c: tkl.Memory[N_Q, H, D_KV, GLOBAL_ADDRESS_SPACE, wave_output_dtype, o_layout],
):
c_reg = tkl.Register[H, D_KV, N_Q, tkl.f32](0.0)
@@ -332,10 +337,9 @@ def second_loop(
H_KV: shape.num_kv_heads,
D_KV: shape.head_size_kv,
D_Q: shape.head_size,
- S: shape.num_seqs,
}

- dynamic_symbols = [N_Q, N_KV]
- dynamic_symbols_map = {N_Q: q_shape[0], N_KV: k_shape[0]}
+ dynamic_symbols = [N_Q, N_KV, S]
+ dynamic_symbols_map = {N_Q: q_shape[0], N_KV: k_shape[0], S: shape.num_seqs}

return extend_attention, hyperparams, dynamic_symbols, dynamic_symbols_map
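
With S added to dynamic_symbols and the per-sequence buffers given a shape=[None] layout, the number of sequences is no longer baked into the compiled kernel. Previously S: shape.num_seqs sat in the hyperparameters, so each distinct batch size produced a distinct compile. Below is a minimal, self-contained sketch of the caching behavior this enables; get_compiled, compile_kernel, and the example hyperparameter values (64, 16) are hypothetical stand-ins for illustration, not the Wave runtime API.

_kernel_cache: dict[tuple, str] = {}

def compile_kernel(kernel_name: str, hyperparams: dict[str, int]) -> str:
    # Stand-in for real compilation: returns a token representing a compiled artifact.
    return f"compiled::{kernel_name}::{sorted(hyperparams.items())}"

def get_compiled(kernel_name: str, hyperparams: dict[str, int]) -> str:
    # The cache key uses only compile-time information; dynamic symbols such as
    # N_Q, N_KV, and (after this patch) S are bound per launch instead.
    key = (kernel_name, tuple(sorted(hyperparams.items())))
    if key not in _kernel_cache:
        _kernel_cache[key] = compile_kernel(kernel_name, hyperparams)
    return _kernel_cache[key]

# Two batches with different num_seqs now map to the same compiled artifact,
# because S lives in dynamic_symbols_map rather than in the static hyperparameters.
first = get_compiled("extend_attention", {"BLOCK_N_Q": 64, "BLOCK_H": 16})
second = get_compiled("extend_attention", {"BLOCK_N_Q": 64, "BLOCK_H": 16})
assert first == second and len(_kernel_cache) == 1
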
lit_tests/kernel/wave/attention/extend_attention.py (9 additions, 0 deletions)
@@ -90,6 +90,15 @@ def test_extend_attention():
output,
).module_op
)
+ # This part ensures the correctness of the workgroup distribution for extend attention.
+ # CHECK: stream.executable.export public @extend_attention workgroups(%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index)
+ # CHECK: %[[D0:.+]] = arith.subi %[[ARG0]], %c1 : index
+ # CHECK: %[[D1:.+]] = arith.divui %[[D0]], %c64 : index
+ # CHECK: %[[D2:.+]] = arith.addi %[[D1]], %c1 : index
+ # CHECK: %[[D3:.+]] = arith.cmpi eq, %[[ARG0]], %c0 : index
+ # CHECK: %[[NQ_GRID:.+]] = arith.select %[[D3]], %c0, %[[D2]] : index
+ # CHECK: %[[NUM_SEQ:.+]] = arith.muli %[[ARG2]], %c16 overflow<nsw, nuw> : index
+ # CHECK: stream.return %[[NQ_GRID]], %c1, %[[NUM_SEQ]] : index, index, index

# CHECK-LABEL: func.func @extend_attention
# CHECK-DAG: stream.binding.subspan %{{.*}}[%{{.*}}] : !stream.binding -> memref<?x16x64xf16, strided<[1024, 64, 1], offset: ?>>
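
The new CHECK block encodes the expected workgroup-count arithmetic for the dynamic grid. The sketch below mirrors that IR in plain Python; interpreting %c64 as the N_Q tile size and %c16 as the per-sequence workgroup factor is an inference from the constants in the test, not something stated in the patch.

def expected_grid(n_q: int, num_seqs: int, n_q_tile: int = 64, seq_factor: int = 16):
    # subi/divui/addi compute ceil(n_q / n_q_tile); cmpi/select guard the n_q == 0 case.
    nq_grid = 0 if n_q == 0 else (n_q - 1) // n_q_tile + 1
    # The last grid dimension scales with the now-dynamic number of sequences.
    return (nq_grid, 1, num_seqs * seq_factor)

assert expected_grid(128, 2) == (2, 1, 32)
assert expected_grid(65, 3) == (2, 1, 48)
assert expected_grid(0, 1) == (0, 1, 16)
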
