diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py index 4dde584a8..b2dc4a7e3 100644 --- a/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py +++ b/shortfin/python/shortfin_apps/llm/components/kvcache/base_attention_cache.py @@ -38,17 +38,21 @@ def __init__(self, page_pool, tokens_per_page): self.tokens_per_page = tokens_per_page - def acquire_pages_for_tokens(self, tokens: List[int]) -> tuple[List[PageInfo], int)]: + def acquire_pages_for_tokens(self, tokens: List[int], extra_token_slots: int = 1) -> tuple[List[PageInfo], int]: """ Given a list of tokens, return a list of pages and a start position to continue generation from. + Parameters: + - tokens: all the known tokens for this generation request + - extra_token_slots: number of kvcache slots needed in addition to the ones needed to hold the given tokens. + In the base implementation, this will just allocate all new pages, but in shared-kv implementations, we will fetch cached pages if applicable. The pages are returned in order. No token at idx < n_cached_token should be written to. TODO: consider enforcing this. 
""" - pages_needed = math.ceil(len(tokens) / self.tokens_per_page) + pages_needed = math.ceil((len(tokens) + extra_token_slots) / self.tokens_per_page) pages = self.page_pool.acquire_free_pages(pages_needed) n_cached_tokens = 0 diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index bcd08b756..6fa44fbf9 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -218,7 +218,11 @@ def board_prefills(self, cache: AttnPageCache): needed_pages = math.ceil( len(prefill_request.input_token_ids) / self.page_seq_stride ) - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + prefill_request.input_token_ids, + extra_token_slots=0, # prefill needs no extra kvcache slots to write to + ) if pages is None: logger.debug("Cannot fulfill request for %d pages", needed_pages) continue @@ -254,7 +258,11 @@ def board_decodes(self, cache: AttnPageCache): / self.page_seq_stride ) if needed_pages > len(decode_request.locked_pages): - pages = cache.acquire_free_pages(needed_pages) + # allocate kv cache pages + pages, cache_hit_prefix_length = cache.acquire_pages_for_tokens( + decode_request.input_token_ids, + extra_token_slots=1, # need 1 extra slot to write result. + ) if pages is None: logger.debug( "Cannot fulfill decode request for %d pages", needed_pages