improve kv_cache update perf
Signed-off-by: Chengji Yao <[email protected]>
yaochengji committed Feb 12, 2025
1 parent 14b7899 commit 7437110
Showing 2 changed files with 71 additions and 32 deletions.
35 changes: 28 additions & 7 deletions vllm/attention/backends/pallas.py
@@ -11,6 +11,8 @@
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState

MIN_PREFILL_SEQ_LEN = 16


class PallasAttentionBackend(AttentionBackend):

@@ -183,7 +185,8 @@ def forward(
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping,
seq_len > 1)

query = query * self.scale
if attn_metadata.num_prefills > 0:
@@ -296,16 +299,34 @@ def write_to_kv_cache(
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
is_prefill: bool,
) -> None:
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)

key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
if is_prefill:
batch_size, _, head_num, head_size = key.shape
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
key = key.view(batch_size, head_num, -1, MIN_PREFILL_SEQ_LEN,
head_size)
value = value.view(batch_size, head_num, -1, MIN_PREFILL_SEQ_LEN,
head_size)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
key_cache = key_cache.view(-1, MIN_PREFILL_SEQ_LEN, head_size)
value_cache = value_cache.flatten(0, 2)
value_cache = value_cache.view(-1, MIN_PREFILL_SEQ_LEN, head_size)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
else:
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)


def paged_attention(
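
The change to write_to_kv_cache above copies prefill KV data into the cache in MIN_PREFILL_SEQ_LEN-sized chunks instead of one row per token, so index_copy_ only needs seq_len // 16 indices rather than seq_len. A minimal standalone sketch of that idea (not taken from either changed file; toy shapes, a single head, plain CPU tensors):

import torch

MIN_PREFILL_SEQ_LEN = 16
head_size = 8
total_slots = 64 * MIN_PREFILL_SEQ_LEN           # toy flat cache: 1024 token slots
cache = torch.zeros(total_slots, head_size)

seq_len = 64                                     # prefill length, multiple of 16
new_kv = torch.randn(seq_len, head_size)         # new keys (or values) for one head
first_slot = 128                                 # contiguous, 16-aligned cache region

# Per-token write (old path): seq_len index entries, one per token.
token_slots = torch.arange(first_slot, first_slot + seq_len)
per_token = cache.clone()
per_token.index_copy_(0, token_slots, new_kv)

# Chunked write (new prefill path): view cache and KV as 16-token chunks and
# copy whole chunks, so only seq_len // 16 index entries are needed.
chunk_slots = torch.arange(first_slot, first_slot + seq_len,
                           MIN_PREFILL_SEQ_LEN) // MIN_PREFILL_SEQ_LEN
chunked = cache.clone().view(-1, MIN_PREFILL_SEQ_LEN, head_size)
chunked.index_copy_(0, chunk_slots, new_kv.view(-1, MIN_PREFILL_SEQ_LEN, head_size))

assert torch.equal(per_token, chunked.view(total_slots, head_size))
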
68 changes: 43 additions & 25 deletions vllm/worker/tpu_model_runner.py
@@ -14,6 +14,7 @@
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.pallas import MIN_PREFILL_SEQ_LEN
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
@@ -108,6 +109,9 @@ def __init__(
self.is_driver_worker = is_driver_worker

self.block_size = self.cache_config.block_size
assert self.block_size % MIN_PREFILL_SEQ_LEN == 0, (
f"block size is required to be multiple of {MIN_PREFILL_SEQ_LEN}"
"for better performance")
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
self.block_size)
self.block_tables = np.zeros(
@@ -172,16 +176,18 @@ def _dummy_run(
) -> None:
exec_mode = ExecutionMode(exec_mode)
if exec_mode.is_prefill():
seq_len = (seq_len + 15) // 16 * 16
seq_len = (seq_len + MIN_PREFILL_SEQ_LEN -
1) // MIN_PREFILL_SEQ_LEN * MIN_PREFILL_SEQ_LEN
token_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
slot_mapping = torch.zeros(
(batch_size, seq_len // MIN_PREFILL_SEQ_LEN),
dtype=torch.int64,
device=self.device)
input_lens = torch.ones((batch_size, ),
dtype=torch.int32,
device=self.device)
@@ -258,10 +264,10 @@ def _dummy_run(
# be re-compiled for every different shapes. This overhead is inevitable
# in the first run, but can be skipped afterwards as we cache the XLA
# graphs in the disk (VLLM_XLA_CACHE_PATH).
# NOTE(chengjiyao): During prefill, seq_len cannot be marked as dynamic
# because it is used to calculate the output shape of a view op.
if exec_mode.is_prefill():
# Prefll
torch._dynamo.mark_dynamic(token_ids, 1)
torch._dynamo.mark_dynamic(position_ids, 1)
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
else:
# Decode
@@ -286,7 +292,7 @@ def warmup_model(
logger.info("Compiling the model with different input shapes...")
start = time.time()
for batch_size in [1]:
seq_len = 16
seq_len = MIN_PREFILL_SEQ_LEN
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
@@ -308,7 +314,7 @@ def warmup_model(
"prefix prefill...")
start = time.time()
for batch_size in [1]:
seq_len = 16
seq_len = MIN_PREFILL_SEQ_LEN
while seq_len <= self.model_config.max_model_len:
self._dummy_run(batch_size,
seq_len,
@@ -340,7 +346,8 @@ def warmup_model(

if batch_size >= self.scheduler_config.max_num_seqs:
break
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
batch_size = (batch_size + MIN_PREFILL_SEQ_LEN if batch_size
>= MIN_PREFILL_SEQ_LEN else batch_size * 2)

end = time.time()
logger.info("Compilation for decode done in %.2f s.", end - start)
@@ -384,11 +391,10 @@ def _prepare_prompt(

assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(num_computed_tokens, seq_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
assert num_computed_tokens % MIN_PREFILL_SEQ_LEN == 0
for i in range(num_computed_tokens, seq_len, MIN_PREFILL_SEQ_LEN):
block_number = block_table[i // MIN_PREFILL_SEQ_LEN]
slot_mapping.append(block_number)
if num_computed_tokens > 0:
self.block_tables[batch_idx, :len(block_table)] = block_table

@@ -402,7 +408,8 @@ def _prepare_prompt(
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings
slot_mapping += [_PAD_SLOT_ID
] * (num_paddings // MIN_PREFILL_SEQ_LEN)

assert len(prompt_lens) > 0
num_prefills = len(prompt_lens)
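
Following the same chunking, _prepare_prompt above now records one block index per MIN_PREFILL_SEQ_LEN tokens and pads at chunk granularity. A toy walk-through (illustrative values only, with block_size kept equal to MIN_PREFILL_SEQ_LEN so both schemes read the same block table) shows how much shorter the prefill slot_mapping becomes:

MIN_PREFILL_SEQ_LEN = 16
_PAD_SLOT_ID = -1                     # stand-in pad value for this sketch
block_size = 16                       # equal to MIN_PREFILL_SEQ_LEN here
block_table = [7, 3, 9, 4]            # toy physical blocks for one sequence
prompt_len = 40
padded_len = 64                       # _get_padded_prefill_len(40) -> 64

# Old scheme: one slot per token.
old_mapping = [
    block_table[i // block_size] * block_size + i % block_size
    for i in range(prompt_len)
] + [_PAD_SLOT_ID] * (padded_len - prompt_len)

# New scheme: one block index per 16-token chunk.
new_mapping = [
    block_table[i // MIN_PREFILL_SEQ_LEN]
    for i in range(0, prompt_len, MIN_PREFILL_SEQ_LEN)
] + [_PAD_SLOT_ID] * ((padded_len - prompt_len) // MIN_PREFILL_SEQ_LEN)

print(len(old_mapping), len(new_mapping))   # 64 4
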
@@ -645,13 +652,15 @@ def execute_model(
model_input.attn_metadata.effective_query_lens
batch_size = model_input.input_lens.shape[0]
start_idx = 0
start_slot_idx = 0
next_token_ids = []
for i in range(batch_size):
# Get the actual prefill_len.
prefill_len = model_input.input_lens[i:i + 1].item()
prefill_len = _get_padded_prefill_len(prefill_len)
end_idx = start_idx + prefill_len

end_slot_idx = (start_slot_idx +
prefill_len // MIN_PREFILL_SEQ_LEN)
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
self.device)
position_ids = model_input.position_ids[None,
@@ -660,7 +669,7 @@ def execute_model(
attn_metadata = model_input.attn_metadata
attn_metadata.num_prefills = 1
attn_metadata.slot_mapping = orig_slot_mapping[
None, start_idx:end_idx].to(self.device)
None, start_slot_idx:end_slot_idx].to(self.device)
if orig_context_lens[i].item() > 0:
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
self.device)
@@ -684,6 +693,7 @@ def execute_model(
kv_caches)
next_token_ids.append(output_token_ids[0])
start_idx = end_idx
start_slot_idx = end_slot_idx

if model_input.async_callback is not None:
model_input.async_callback()
@@ -826,10 +836,17 @@ def forward(
num_kv_heads,
device=slot_mapping.device,
dtype=slot_mapping.dtype)
head_indicies *= block_size * num_blocks
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
-1, num_kv_heads)
slot_mapping = slot_mapping + head_indicies.view(1, -1)
if seq_len > 1:
# prefill
head_indicies *= num_blocks * block_size // MIN_PREFILL_SEQ_LEN
slot_mapping = slot_mapping.repeat(num_kv_heads, 1)
slot_mapping = slot_mapping + head_indicies.view(-1, 1)
else:
# decoding
head_indicies *= num_blocks * block_size
slot_mapping = slot_mapping.repeat_interleave(
num_kv_heads).view(-1, num_kv_heads)
slot_mapping = slot_mapping + head_indicies.view(1, -1)
slot_mapping = slot_mapping.flatten()
attn_metadata.slot_mapping = slot_mapping
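
The branch above spreads one shared slot (or chunk) mapping across KV heads by adding a per-head offset into the flattened cache; the prefill offset is divided by MIN_PREFILL_SEQ_LEN because its indices address 16-token chunks. A small sketch with made-up sizes (two KV heads, four blocks of 16 slots), not part of the commit:

import torch

MIN_PREFILL_SEQ_LEN = 16
num_kv_heads, num_blocks, block_size = 2, 4, 16

# Prefill: slot_mapping holds chunk indices, one per 16 tokens.
prefill_slots = torch.tensor([[0, 1, 2]])               # (batch=1, num_chunks)
head_offsets = torch.arange(num_kv_heads) * (num_blocks * block_size //
                                             MIN_PREFILL_SEQ_LEN)
prefill_flat = (prefill_slots.repeat(num_kv_heads, 1) +
                head_offsets.view(-1, 1)).flatten()
print(prefill_flat)                                      # tensor([0, 1, 2, 4, 5, 6])

# Decode: slot_mapping holds one token slot per sequence.
decode_slots = torch.tensor([5, 21])                     # (batch=2,)
head_offsets = torch.arange(num_kv_heads) * (num_blocks * block_size)
decode_flat = (decode_slots.repeat_interleave(num_kv_heads)
               .view(-1, num_kv_heads) + head_offsets.view(1, -1)).flatten()
print(decode_flat)                                       # tensor([ 5, 69, 21, 85])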

@@ -867,10 +884,11 @@ def forward(

def _get_padded_prefill_len(x: int) -> int:
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
# length to be a multiple of 16. We pad the prompt length to the nearest
# multiple of 16. This is also good for performance.
if x <= 16:
return 16
# length to be a multiple of MIN_PREFILL_SEQ_LEN. We pad the prompt length
# to the nearest multiple of MIN_PREFILL_SEQ_LEN. This is also good for
# performance.
if x <= MIN_PREFILL_SEQ_LEN:
return MIN_PREFILL_SEQ_LEN
return 1 << (x - 1).bit_length()
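
For reference, the helper above rounds short prompts up to MIN_PREFILL_SEQ_LEN and longer ones up to the next power of two (which is always a multiple of 16). A quick standalone check, with the helper re-stated here only so the snippet runs on its own:

MIN_PREFILL_SEQ_LEN = 16

def padded_prefill_len(x: int) -> int:
    if x <= MIN_PREFILL_SEQ_LEN:
        return MIN_PREFILL_SEQ_LEN
    return 1 << (x - 1).bit_length()

for n in (1, 16, 17, 40, 100):
    print(n, padded_prefill_len(n))   # -> 16, 16, 32, 64, 128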


