Skip to content

Commit

Permalink
SWA working?
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 12, 2023
1 parent 051f57f commit 5b29a05
Showing 1 changed file with 109 additions and 33 deletions.
142 changes: 109 additions & 33 deletions serve/mlc_serve/model/paged_cache_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math
import os
import copy
from collections import defaultdict
from typing import List, Union, Optional
from pathlib import Path
Expand Down Expand Up @@ -35,40 +34,76 @@


class DecodeBlockTable:
    """Block table for one decode sequence in the paged KV cache.

    Holds the (immutable, possibly shared) prompt block ids plus the decode
    block ids appended as new tokens are generated.

    When sliding-window attention (SWA) is active and the prompt is shared
    across the sequences of a parallel-sampling request, the shared prompt
    blocks form a circular region of ``prompt_blocks`` delimited by
    ``[prompt_cursor_head, prompt_cursor_tail)`` (wrapping past the end of
    the list when head > tail).
    """

    def __init__(
        self,
        prompt_blocks: list[int],
        num_prompt_tokens: int,
        block_size: int,
        block_sliding_window: Optional[int] = None,
        prompt_shared: bool = False,
    ):
        self.num_prompt_blocks = len(prompt_blocks)
        self.prompt_blocks = prompt_blocks  # immutable
        self.block_sliding_window = block_sliding_window
        self.prompt_shared = prompt_shared

        if (
            self.block_sliding_window
            and self.num_prompt_blocks >= self.block_sliding_window
            and prompt_shared
        ):
            # The prompt filled the circular window at least once, so the
            # logically-oldest block sits at the circular write position of
            # the last prompt token, not at index 0.
            self.prompt_cursor_head = (
                num_prompt_tokens // block_size
            ) % block_sliding_window
            self.prompt_cursor_tail = self.prompt_cursor_head
        else:
            self.prompt_cursor_head = 0
            self.prompt_cursor_tail = self.num_prompt_blocks

        # Prompt blocks between [prompt_cursor_head, prompt_cursor_tail) are shared
        # with other sequences in a parallel-sampling request.
        self.decode_blocks: list[int] = []

    def append(self, new_block_id: int):
        """Append a freshly-allocated block id for this decode sequence."""
        self.decode_blocks.append(new_block_id)

    def __len__(self):
        # NOTE: counts all remaining prompt blocks plus decode blocks; in the
        # shared-SWA case this may differ from len(self.get_blocks()).
        return self.num_prompt_blocks + len(self.decode_blocks)

    def __getitem__(self, index: int):
        """Return the block id at *index*; prompt blocks come before decode ones.

        Only -1 is accepted among negative indices (the most recent block).
        """
        if index == -1:
            if len(self.decode_blocks) == 0:
                return self.prompt_blocks[-1]

            return self.decode_blocks[-1]

        assert index >= 0

        if index < self.num_prompt_blocks:
            return self.prompt_blocks[index]

        return self.decode_blocks[index - self.num_prompt_blocks]

    def get_blocks(self):
        """Return the logical block list for this sequence, oldest first.

        Without SWA or without a shared prompt, that is simply all prompt
        blocks followed by the decode blocks.  With a shared SWA prompt,
        only the circular region [prompt_cursor_head, prompt_cursor_tail)
        of the prompt blocks is included, unwrapped into linear order.
        """
        if not self.block_sliding_window or not self.prompt_shared:
            return self.prompt_blocks + self.decode_blocks

        if self.prompt_cursor_head <= self.prompt_cursor_tail:
            return (
                self.prompt_blocks[self.prompt_cursor_head : self.prompt_cursor_tail]
                + self.decode_blocks
            )

        # The shared region wraps around the end of the circular buffer.
        return (
            self.prompt_blocks[self.prompt_cursor_head :]
            + self.prompt_blocks[: self.prompt_cursor_tail]
            + self.decode_blocks
        )

    def replace_head_prompt_block_with(self, new_block):
        """Retire the oldest shared prompt block, substituting *new_block*.

        Used when a shared prompt block must be copied before being
        overwritten by the circular SWA buffer: the copy becomes a private
        decode block and the head cursor advances past the shared original.
        """
        self.append(new_block)
        self.prompt_cursor_head += 1
        # NOTE(review): raises ZeroDivisionError if num_prompt_blocks is 0
        # here — presumably callers never exhaust all prompt blocks; verify.
        self.prompt_cursor_head %= self.num_prompt_blocks
        self.num_prompt_blocks -= 1


Expand Down Expand Up @@ -148,6 +183,8 @@ def __init__(
else:
self.block_sliding_window = None

self.sliding_window = sliding_window

def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]):
for id, size in zip(sequence_ids, target_sizes):
num_needed_block = math.ceil(size / self.block_size)
Expand All @@ -172,32 +209,52 @@ def set_size(self, sequence_ids: List[SequenceId], target_sizes: List[int]):

elif id in self.kv_cache.decode_block_tables:
decode_block_table = self.kv_cache.decode_block_tables[id]

if len(decode_block_table) < num_needed_block:
# Need to allocate a new block for this request
assert len(decode_block_table) + 1 == num_needed_block
decode_block_table.append(self.free_blocks.pop())

pos = size - 1
prompt_seq_id = get_prompt_sequence_id(id.request_id)
prompt_ref_counts = self.kv_cache.prompt_blocks_ref_counts[
prompt_seq_id
]

if self.block_sliding_window:
index = (pos // self.block_size) % self.block_sliding_window
else:
index = -1

block_number, is_prompt = decode_block_table[index]
def get_block_circular_index(token_pos):
assert self.block_sliding_window
return (token_pos // self.block_size) % self.block_sliding_window

if (
is_prompt
and self.block_sliding_window
and len(self.kv_cache.prompt_blocks_ref_counts[id]) > 0
and self.kv_cache.prompt_blocks_ref_counts[id][index] > 1
self.block_sliding_window
and self.sliding_window
and size >= self.sliding_window
and len(prompt_ref_counts) > 0
):
self.kv_cache.prompt_blocks_ref_counts[id][index] -= 1
# TODO(masahi): The engine should take into account this additional
# free block allocation requirement.
assert len(self.free_blocks) > 0, "No more free block in the cache."
block_number = self.free_blocks.pop()
decode_block_table.replace_head_prompt_block_with(block_number)
if (
decode_block_table.prompt_cursor_head
== get_block_circular_index(pos)
and prompt_ref_counts[get_block_circular_index(pos)] >= 1
):
prompt_ref_counts[decode_block_table.prompt_cursor_head] -= 1

# TODO(masahi): The engine should take into account this additional
# free block allocation requirement.
assert (
len(self.free_blocks) > 0
), "No more free block in the cache."
block_number = self.free_blocks.pop()
decode_block_table.replace_head_prompt_block_with(block_number)
else:
block_number = decode_block_table[-1]

else:
if self.block_sliding_window:
index = get_block_circular_index(pos)
else:
index = -1

block_number = decode_block_table[index]

block_offset = pos % self.block_size
slot = block_number * self.block_size + block_offset
Expand Down Expand Up @@ -249,13 +306,20 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
num_tokens % self.block_size != 0
)

if self.block_sliding_window:
last_block_partially_shared &= (
num_tokens < self.block_sliding_window * self.block_size
)

if last_block_partially_shared:
self.allocated_prompt_tokens[prompt_seq_id] -= num_tokens % self.block_size

prompt_blocks = self.kv_cache.prompt_block_tables[prompt_seq_id]
assert prompt_blocks

if self.block_sliding_window and num_sequences > 1:
prompt_shared = num_sequences > 1

if self.block_sliding_window and prompt_shared:
self.kv_cache.prompt_blocks_ref_counts[prompt_seq_id] = [
num_sequences
] * len(prompt_blocks)
Expand All @@ -266,7 +330,11 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
if not last_block_partially_shared:
self.allocated_decode_tokens[decode_seq_id] = 0
self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable(
prompt_blocks
prompt_blocks,
num_tokens,
self.block_size,
self.block_sliding_window,
prompt_shared,
)
else:
# Tokens in the partially-shared prompt block are considered to be part of each decode sequence
Expand All @@ -277,7 +345,11 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
if i < num_sequences:
# Need to copy the last block in self.kv_cache.block_tables[prompt_seq_id]
self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable(
prompt_blocks[:-1]
prompt_blocks[:-1],
num_tokens,
self.block_size,
self.block_sliding_window,
prompt_shared,
)
last_block_copy = self.free_blocks.pop()
self.kv_cache.decode_block_tables[decode_seq_id].append(
Expand All @@ -290,7 +362,11 @@ def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
# The last sequence can directly overwrite the last block without copying it,
# since other sequences have its own copy of the last block.
self.kv_cache.decode_block_tables[decode_seq_id] = DecodeBlockTable(
prompt_blocks
prompt_blocks,
num_tokens,
self.block_size,
self.block_sliding_window,
prompt_shared,
)

def extend(self, sequence_id: SequenceId, new_tokens: int):
Expand Down

0 comments on commit 5b29a05

Please sign in to comment.