diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py index fdb07afb2..190c41771 100644 --- a/sharktank/tests/models/llama/kv_cache_test.py +++ b/sharktank/tests/models/llama/kv_cache_test.py @@ -31,10 +31,12 @@ def setUp(self): self.max_seq_len = 4096 self.start_positions = torch.tensor([8]) self.bs = 1 + self.device = "cpu" + self.attention_dtype = torch.float32 self.attention_block_theta = make_attention_block_theta( feature_dim=self.head_count * self.head_dim, ffn_dim=self.ffn_dim, - dtype=torch.float32, + dtype=self.attention_dtype, ) self.paged_kv_cache = PagedKVCache( transformer_block_count=self.head_count, @@ -42,8 +44,8 @@ def setUp(self): attn_head_dim=self.head_dim, cache_partition_count=2, # One for each of K/V. block_seq_stride=self.block_seq_stride, - device="cpu", - dtype=torch.float32, + device=self.device, + dtype=self.attention_dtype, ) self.direct_kv_cache = DirectKVCache( block_seq_stride=self.block_seq_stride, @@ -51,13 +53,13 @@ def setUp(self): attn_head_count=self.head_count, attn_head_dim=self.head_dim, seq_length=self.max_seq_len, - device="cpu", - dtype=torch.float32, + device=self.device, + dtype=self.attention_dtype, ) self.attention_embedding = RotaryEmbeddingLayer( rope_dimension_count=self.rope_dimension_count, max_seqlen=self.max_seq_len, - device="cpu", + device=self.device, use_hf=False, ) self.paged_attn_blocks = nn.ModuleList( @@ -116,7 +118,8 @@ def testDirectAndPagedKVCachePrefill(self): torch.set_default_dtype(torch.float32) paged_input_tensor = make_rand_torch( - (1, self.seq_len, self.head_count * self.head_dim), dtype=torch.float32 + (1, self.seq_len, self.head_count * self.head_dim), + dtype=self.attention_dtype, ) direct_input_tensor = paged_input_tensor.detach().clone() # Iterate over paged attention blocks. @@ -142,9 +145,29 @@ def testDirectAndPagedKVCachePrefill(self): page_table = self.paged_kv_cache.unflatten_page_table(self.paged_cache_state) index_written = self.start_positions.item() page_id = self.paged_seq_block_ids[0][0].item() + """ + direct_cache_state is a list of num_transformer_blocks * 2 (one for K and one for V), + so here we index into the first transformer block's keys with self.direct_cache_state[0] + and the first transformer block's values with self.direct_cache_state[1]. Each row + in direct_cache_state is a tensor of [bs, seq_len , attn_heads, attn_dim], so we make sure + the first 8 (start_position) tensors starting at sequence 0 of the seq_len are written to. + """ updated_direct_cache_state = self.direct_cache_state[0][ :, :index_written ].squeeze(0) + """ + paged_cache_state is a list of a single tensor that represents a flattened page table. + Indexing into self.paged_cache_state[0] and unflattening the page table columns to a 6D tensor of: + * transformer block + * cache partition (K or V cache) + * block sequence stride (number of sequence positions per block) + * attention heads + * attention dimensionality + allows us to access the cache partitions for a certain transformer block and sequence in a + certain page_id. For example, page_table[page_id][0, 0, :index_written] lets us access the + first transformer block's K cache for the first 8 (start_positions) tensors starting at + sequence 0. + """ updated_paged_cache_state = page_table[page_id][0, 0, :index_written] assert updated_direct_cache_state.shape == updated_paged_cache_state.shape torch.testing.assert_close( @@ -158,6 +181,9 @@ def testDirectAndPagedKVCachePrefill(self): paged_prefill_attn_output, direct_prefill_attn_output ) + @unittest.skip( + "Bug in Windows decode test for paged_decode_attn_output vs. direct_decode_attn_output" + ) def testDirectAndPagedKVCacheDecode(self): torch.set_default_dtype(torch.float32) self.start_positions.add_(1) @@ -169,7 +195,7 @@ def testDirectAndPagedKVCacheDecode(self): ) token_paged_input_tensor = make_rand_torch( - (1, 1, self.head_count * self.head_dim), dtype=torch.float32 + (1, 1, self.head_count * self.head_dim), dtype=self.attention_dtype ) token_direct_input_tensor = token_paged_input_tensor.detach().clone() @@ -180,8 +206,8 @@ def testDirectAndPagedKVCacheDecode(self): self.head_count_kv, self.head_dim, ], - dtype=torch.float32, - device="cpu", + dtype=self.attention_dtype, + device=self.device, ) xv_temp = torch.empty( [ @@ -190,8 +216,8 @@ def testDirectAndPagedKVCacheDecode(self): self.head_count_kv, self.head_dim, ], - dtype=torch.float32, - device="cpu", + dtype=self.attention_dtype, + device=self.device, ) # Iterate over paged attention blocks.