Skip to content

Commit

Permalink
modify tpu model runner
Browse files (browse the repository at this point in the history)
  • Loading branch information
yaochengji committed Feb 12, 2025
1 parent 78c6be2 commit d08e04a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions vllm/worker/tpu_model_runner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -839,12 +839,14 @@ def forward(
if seq_len > 1:
# prefill
head_indicies *= num_blocks * block_size // MIN_PREFILL_SEQ_LEN
slot_mapping = slot_mapping.repeat(num_kv_heads, 1)
slot_mapping = slot_mapping + head_indicies.view(-1, 1)
else:
# decoding
head_indicies *= num_blocks * block_size
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indicies.view(1, -1)
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indicies.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping

Expand Down

0 comments on commit d08e04a

Please sign in to comment.