From 9bbb38959b133c8ca5896c0eb8584de64fa3c421 Mon Sep 17 00:00:00 2001
From: aviator19941
Date: Fri, 23 Aug 2024 16:44:22 -0500
Subject: [PATCH] Move config variables to setup

Signed-off-by: aviator19941
---
 sharktank/tests/models/llama/kv_cache_test.py | 109 ++++++++----------
 1 file changed, 50 insertions(+), 59 deletions(-)

diff --git a/sharktank/tests/models/llama/kv_cache_test.py b/sharktank/tests/models/llama/kv_cache_test.py
index 299e31a30..f3024c14c 100644
--- a/sharktank/tests/models/llama/kv_cache_test.py
+++ b/sharktank/tests/models/llama/kv_cache_test.py
@@ -11,53 +11,46 @@
 from ...layers import *
 
 
-default_arguments = {
-    "block_count": 1,
-    "seq_len": 16,
-    "head_count": 32,
-    "head_dim": 128,
-    "ffn_dim": 8640,
-}
-block_count = 1
-seq_len = 16
-head_count = 32
-head_dim = 128
-ffn_dim = 11008  # 4096
-head_count_kv = 32
-block_seq_stride = 16
-rms_epsilon = 1e-5
-rope_dimension_count = 128
-max_seq_len = 4096
-start_positions = torch.tensor([8])
-bs = 1
-
-
 class KVCacheTest(unittest.TestCase):
     def setUp(self):
+        self.block_count = 1
+        self.seq_len = 16
+        self.head_count = 32
+        self.head_dim = 128
+        self.ffn_dim = 11008
+        self.head_count_kv = 32
+        self.block_seq_stride = 16
+        self.rms_epsilon = 1e-5
+        self.rope_dimension_count = 128
+        self.max_seq_len = 4096
+        self.start_positions = torch.tensor([8])
+        self.bs = 1
         self.attention_block_theta = make_attention_block_theta(
-            feature_dim=head_count * head_dim, ffn_dim=ffn_dim, dtype=torch.float32
+            feature_dim=self.head_count * self.head_dim,
+            ffn_dim=self.ffn_dim,
+            dtype=torch.float32,
         )
         self.paged_kv_cache = PagedKVCache(
-            transformer_block_count=head_count,
-            attn_head_count=head_count,
-            attn_head_dim=head_dim,
+            transformer_block_count=self.head_count,
+            attn_head_count=self.head_count,
+            attn_head_dim=self.head_dim,
             cache_partition_count=2,  # One for each of K/V.
-            block_seq_stride=block_seq_stride,
+            block_seq_stride=self.block_seq_stride,
             device="cpu",
             dtype=torch.float32,
         )
         self.direct_kv_cache = DirectKVCache(
-            block_seq_stride=block_seq_stride,
-            transformer_block_count=head_count,
-            attn_head_count=head_count,
-            attn_head_dim=head_dim,
-            seq_length=seq_len,
+            block_seq_stride=self.block_seq_stride,
+            transformer_block_count=self.head_count,
+            attn_head_count=self.head_count,
+            attn_head_dim=self.head_dim,
+            seq_length=self.seq_len,
             device="cpu",
             dtype=torch.float32,
         )
         self.attention_embedding = RotaryEmbeddingLayer(
-            rope_dimension_count=rope_dimension_count,
-            max_seqlen=max_seq_len,
+            rope_dimension_count=self.rope_dimension_count,
+            max_seqlen=self.max_seq_len,
             device="cpu",
             use_hf=False,
         )
@@ -67,13 +60,13 @@ def setUp(self):
                     self.attention_block_theta,
                     block_index=n,
                     cache=self.paged_kv_cache,
-                    head_count=head_count,
-                    head_dim=head_dim,
-                    head_count_kv=head_count_kv,
-                    rms_epsilon=rms_epsilon,
+                    head_count=self.head_count,
+                    head_dim=self.head_dim,
+                    head_count_kv=self.head_count_kv,
+                    rms_epsilon=self.rms_epsilon,
                     use_hf=False,
                 )
-                for n in range(block_count)
+                for n in range(self.block_count)
             ]
         )
         self.direct_attn_blocks = nn.ModuleList(
@@ -82,13 +75,13 @@ def setUp(self):
                     theta=self.attention_block_theta,
                     block_index=n,
                     cache=self.direct_kv_cache,
-                    head_count=head_count,
-                    head_dim=head_dim,
-                    head_count_kv=head_count_kv,
-                    rms_epsilon=rms_epsilon,
+                    head_count=self.head_count,
+                    head_dim=self.head_dim,
+                    head_count_kv=self.head_count_kv,
+                    rms_epsilon=self.rms_epsilon,
                     use_hf=False,
                 )
-                for n in range(block_count)
+                for n in range(self.block_count)
             ]
         )
 
@@ -109,7 +102,7 @@ def testDirectAndPagedKVCachePrefill(self):
         )
 
         paged_input_tensor = make_rand_torch(
-            (1, seq_len, head_count * head_dim), dtype=torch.float32
+            (1, self.seq_len, self.head_count * self.head_dim), dtype=torch.float32
         )
         direct_input_tensor = paged_input_tensor.detach().clone()
 
@@ -119,9 +112,8 @@ def testDirectAndPagedKVCachePrefill(self):
             paged_input_tensor = paged_block(
                 paged_input_tensor,
                 embedding=self.attention_embedding,
-                attention_mask=self.prefill_attention_mask,
                 start_index=0,
-                cache_state=paged_cache_state,
+                cache_state=self.paged_cache_state,
                 seq_block_ids=paged_seq_block_ids,
             )
         # Iterate over direct attention blocks.
@@ -130,7 +122,6 @@ def testDirectAndPagedKVCachePrefill(self):
             direct_input_tensor = direct_block(
                 direct_input_tensor,
                 embedding=self.attention_embedding,
-                attention_mask=self.prefill_attention_mask,
                 start_index=0,
                 cache_state=direct_cache_state,
                 seq_block_ids=direct_seq_block_ids,
@@ -164,31 +155,31 @@ def testDirectAndPagedKVCacheDecode(self):
         )
 
         token_paged_input_tensor = make_rand_torch(
-            (1, 1, head_count * head_dim), dtype=torch.float32
+            (1, 1, self.head_count * self.head_dim), dtype=torch.float32
         )
         print(token_paged_input_tensor)
         token_direct_input_tensor = token_paged_input_tensor.detach().clone()
 
         embedding_batch_mask = self.attention_embedding.compute_batch_mask(
-            start_positions, batch_seq_len=1
+            self.start_positions, batch_seq_len=1
         )
         xk_temp = torch.empty(
             [
-                bs,
-                max_seq_len,
-                head_count_kv,
-                head_dim,
+                self.bs,
+                self.max_seq_len,
+                self.head_count_kv,
+                self.head_dim,
             ],
             dtype=torch.float32,
             device="cpu",
        )
         xv_temp = torch.empty(
             [
-                bs,
-                max_seq_len,
-                head_count_kv,
-                head_dim,
+                self.bs,
+                self.max_seq_len,
+                self.head_count_kv,
+                self.head_dim,
             ],
             dtype=torch.float32,
             device="cpu",
         )
@@ -218,7 +209,7 @@ def testDirectAndPagedKVCacheDecode(self):
         for block_idx, paged_block in enumerate(self.paged_attn_blocks):
             token_paged_input_tensor = paged_block(
                 token_paged_input_tensor,
-                start_positions=start_positions,
+                start_positions=self.start_positions,
                 embedding=self.attention_embedding,
                 embedding_batch_mask=embedding_batch_mask,
                 attention_mask=attention_mask,
@@ -232,7 +223,7 @@ def testDirectAndPagedKVCacheDecode(self):
        for block_idx, direct_block in enumerate(self.direct_attn_blocks):
             token_direct_input_tensor = direct_block(
                 token_direct_input_tensor,
-                start_positions=start_positions,
+                start_positions=self.start_positions,
                 embedding=self.attention_embedding,
                 embedding_batch_mask=embedding_batch_mask,
                 attention_mask=attention_mask,
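Note on the refactor: the patch replaces module-level configuration constants with instance attributes assigned in setUp(), so each test method reads its configuration through self and no shared module state leaks between test files. The following is a minimal, self-contained sketch of that pattern only; the class, test, and attribute names are illustrative, and none of the sharktank layers used above (PagedKVCache, DirectKVCache, RotaryEmbeddingLayer) are reproduced here.

# Sketch of the setUp() pattern applied by the patch (illustrative names only).
import unittest

import torch


class ExampleConfigTest(unittest.TestCase):
    def setUp(self):
        # Formerly module-level globals; now per-test-instance attributes.
        self.seq_len = 16
        self.head_count = 32
        self.head_dim = 128
        self.start_positions = torch.tensor([8])

    def testShapesUseInstanceConfig(self):
        # Every test reads the configuration through `self`.
        x = torch.rand(1, self.seq_len, self.head_count * self.head_dim)
        self.assertEqual(x.shape, (1, self.seq_len, self.head_count * self.head_dim))


if __name__ == "__main__":
    unittest.main()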