Move config variables to setup
Signed-off-by: aviator19941 <[email protected]>
aviator19941 committed Aug 23, 2024
1 parent 735d2e2 commit 9bbb389
Showing 1 changed file with 50 additions and 59 deletions.
109 changes: 50 additions & 59 deletions sharktank/tests/models/llama/kv_cache_test.py
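The refactor follows the standard unittest pattern: module-level constants become attributes assigned in setUp(), so each test method reads its configuration from self and nothing leaks through the module namespace. A minimal, hedged sketch of that pattern, using a few of the values from this diff (trimmed for illustration only; the real file also configures both KV caches, the rotary embedding, and the attention blocks, as shown in the diff below):

import unittest

import torch


class KVCacheTest(unittest.TestCase):
    def setUp(self):
        # Formerly module-level constants, now per-test-instance attributes.
        self.head_count = 32
        self.head_dim = 128
        self.seq_len = 16
        self.start_positions = torch.tensor([8])

    def testUsesSetupConfig(self):
        # Test methods now read the shared configuration via self.*.
        x = torch.zeros(1, self.seq_len, self.head_count * self.head_dim)
        self.assertEqual(x.shape, (1, 16, 4096))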
@@ -11,53 +11,46 @@
 from ...layers import *


-default_arguments = {
-    "block_count": 1,
-    "seq_len": 16,
-    "head_count": 32,
-    "head_dim": 128,
-    "ffn_dim": 8640,
-}
-block_count = 1
-seq_len = 16
-head_count = 32
-head_dim = 128
-ffn_dim = 11008  # 4096
-head_count_kv = 32
-block_seq_stride = 16
-rms_epsilon = 1e-5
-rope_dimension_count = 128
-max_seq_len = 4096
-start_positions = torch.tensor([8])
-bs = 1
-
-
 class KVCacheTest(unittest.TestCase):
     def setUp(self):
+        self.block_count = 1
+        self.seq_len = 16
+        self.head_count = 32
+        self.head_dim = 128
+        self.ffn_dim = 11008
+        self.head_count_kv = 32
+        self.block_seq_stride = 16
+        self.rms_epsilon = 1e-5
+        self.rope_dimension_count = 128
+        self.max_seq_len = 4096
+        self.start_positions = torch.tensor([8])
+        self.bs = 1
         self.attention_block_theta = make_attention_block_theta(
-            feature_dim=head_count * head_dim, ffn_dim=ffn_dim, dtype=torch.float32
+            feature_dim=self.head_count * self.head_dim,
+            ffn_dim=self.ffn_dim,
+            dtype=torch.float32,
         )
         self.paged_kv_cache = PagedKVCache(
-            transformer_block_count=head_count,
-            attn_head_count=head_count,
-            attn_head_dim=head_dim,
+            transformer_block_count=self.head_count,
+            attn_head_count=self.head_count,
+            attn_head_dim=self.head_dim,
             cache_partition_count=2,  # One for each of K/V.
-            block_seq_stride=block_seq_stride,
+            block_seq_stride=self.block_seq_stride,
             device="cpu",
             dtype=torch.float32,
         )
         self.direct_kv_cache = DirectKVCache(
-            block_seq_stride=block_seq_stride,
-            transformer_block_count=head_count,
-            attn_head_count=head_count,
-            attn_head_dim=head_dim,
-            seq_length=seq_len,
+            block_seq_stride=self.block_seq_stride,
+            transformer_block_count=self.head_count,
+            attn_head_count=self.head_count,
+            attn_head_dim=self.head_dim,
+            seq_length=self.seq_len,
             device="cpu",
             dtype=torch.float32,
         )
         self.attention_embedding = RotaryEmbeddingLayer(
-            rope_dimension_count=rope_dimension_count,
-            max_seqlen=max_seq_len,
+            rope_dimension_count=self.rope_dimension_count,
+            max_seqlen=self.max_seq_len,
             device="cpu",
             use_hf=False,
         )
@@ -67,13 +60,13 @@ def setUp(self):
                     self.attention_block_theta,
                     block_index=n,
                     cache=self.paged_kv_cache,
-                    head_count=head_count,
-                    head_dim=head_dim,
-                    head_count_kv=head_count_kv,
-                    rms_epsilon=rms_epsilon,
+                    head_count=self.head_count,
+                    head_dim=self.head_dim,
+                    head_count_kv=self.head_count_kv,
+                    rms_epsilon=self.rms_epsilon,
                     use_hf=False,
                 )
-                for n in range(block_count)
+                for n in range(self.block_count)
             ]
         )
         self.direct_attn_blocks = nn.ModuleList(
@@ -82,13 +75,13 @@ def setUp(self):
                     theta=self.attention_block_theta,
                     block_index=n,
                     cache=self.direct_kv_cache,
-                    head_count=head_count,
-                    head_dim=head_dim,
-                    head_count_kv=head_count_kv,
-                    rms_epsilon=rms_epsilon,
+                    head_count=self.head_count,
+                    head_dim=self.head_dim,
+                    head_count_kv=self.head_count_kv,
+                    rms_epsilon=self.rms_epsilon,
                     use_hf=False,
                 )
-                for n in range(block_count)
+                for n in range(self.block_count)
             ]
         )

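As an aside (not part of the diff): with block_seq_stride equal to seq_len, the paged cache needs exactly one page to cover the prefill sequence. A rough sanity check of that geometry, under the assumption that each page stores one K slab and one V slab of block_seq_stride x head_count x head_dim per transformer block (a layout assumption made for this sketch, not something the diff states):

# Values copied from setUp above; the per-page layout is an assumption.
seq_len = 16
block_seq_stride = 16
head_count = 32
head_dim = 128
cache_partition_count = 2  # one partition for K, one for V

pages_needed = -(-seq_len // block_seq_stride)  # ceil(16 / 16) == 1
floats_per_block_per_page = (
    cache_partition_count * block_seq_stride * head_count * head_dim
)  # 2 * 16 * 32 * 128 == 131072

print(pages_needed, floats_per_block_per_page)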
@@ -109,7 +102,7 @@ def testDirectAndPagedKVCachePrefill(self):
         )

         paged_input_tensor = make_rand_torch(
-            (1, seq_len, head_count * head_dim), dtype=torch.float32
+            (1, self.seq_len, self.head_count * self.head_dim), dtype=torch.float32
         )
         direct_input_tensor = paged_input_tensor.detach().clone()

@@ -119,9 +112,8 @@ def testDirectAndPagedKVCachePrefill(self):
             paged_input_tensor = paged_block(
                 paged_input_tensor,
                 embedding=self.attention_embedding,
-                attention_mask=self.prefill_attention_mask,
                 start_index=0,
-                cache_state=paged_cache_state,
+                cache_state=self.paged_cache_state,
                 seq_block_ids=paged_seq_block_ids,
             )
         # Iterate over direct attention blocks.
@@ -130,7 +122,6 @@ def testDirectAndPagedKVCachePrefill(self):
             direct_input_tensor = direct_block(
                 direct_input_tensor,
                 embedding=self.attention_embedding,
-                attention_mask=self.prefill_attention_mask,
                 start_index=0,
                 cache_state=direct_cache_state,
                 seq_block_ids=direct_seq_block_ids,
@@ -164,31 +155,31 @@ def testDirectAndPagedKVCacheDecode(self):
         )

         token_paged_input_tensor = make_rand_torch(
-            (1, 1, head_count * head_dim), dtype=torch.float32
+            (1, 1, self.head_count * self.head_dim), dtype=torch.float32
         )
         print(token_paged_input_tensor)
         token_direct_input_tensor = token_paged_input_tensor.detach().clone()

         embedding_batch_mask = self.attention_embedding.compute_batch_mask(
-            start_positions, batch_seq_len=1
+            self.start_positions, batch_seq_len=1
         )

         xk_temp = torch.empty(
             [
-                bs,
-                max_seq_len,
-                head_count_kv,
-                head_dim,
+                self.bs,
+                self.max_seq_len,
+                self.head_count_kv,
+                self.head_dim,
             ],
             dtype=torch.float32,
             device="cpu",
         )
         xv_temp = torch.empty(
             [
-                bs,
-                max_seq_len,
-                head_count_kv,
-                head_dim,
+                self.bs,
+                self.max_seq_len,
+                self.head_count_kv,
+                self.head_dim,
             ],
             dtype=torch.float32,
             device="cpu",
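Another aside (not part of the diff): xk_temp and xv_temp in the decode test are scratch tensors whose shapes come straight from the setUp constants, one slot per position up to max_seq_len for every KV head. A quick, illustrative shape check with the same values:

import torch

bs, max_seq_len, head_count_kv, head_dim = 1, 4096, 32, 128

# Same shapes as xk_temp / xv_temp in the test above.
xk_temp = torch.empty([bs, max_seq_len, head_count_kv, head_dim], dtype=torch.float32)
xv_temp = torch.empty_like(xk_temp)
assert xk_temp.shape == (1, 4096, 32, 128)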
@@ -218,7 +209,7 @@ def testDirectAndPagedKVCacheDecode(self):
         for block_idx, paged_block in enumerate(self.paged_attn_blocks):
             token_paged_input_tensor = paged_block(
                 token_paged_input_tensor,
-                start_positions=start_positions,
+                start_positions=self.start_positions,
                 embedding=self.attention_embedding,
                 embedding_batch_mask=embedding_batch_mask,
                 attention_mask=attention_mask,
@@ -232,7 +223,7 @@ def testDirectAndPagedKVCacheDecode(self):
         for block_idx, direct_block in enumerate(self.direct_attn_blocks):
             token_direct_input_tensor = direct_block(
                 token_direct_input_tensor,
-                start_positions=start_positions,
+                start_positions=self.start_positions,
                 embedding=self.attention_embedding,
                 embedding_batch_mask=embedding_batch_mask,
                 attention_mask=attention_mask,
