Add llama2-7b direct kv cache test
Signed-off-by: aviator19941 <[email protected]>
aviator19941 committed Aug 22, 2024
1 parent a9a5788 commit 5ebced9
Showing 4 changed files with 469 additions and 3 deletions.
sharktank/sharktank/examples/paged_llm_v1.py (7 changes: 5 additions & 2 deletions)

@@ -20,6 +20,7 @@
 from ..models.llama.llama import *
 from ..utils.debugging import trace_tensor
 from ..utils.tokenizer import InferenceTokenizer, load_tokenizer
+from transformers import LlamaTokenizer
 
 
 class TorchGenerator:
@@ -48,7 +49,7 @@ def block_seq_stride(self) -> int:

     def begin_batch(self, prompts: list[str]):
         token_ids, seq_lens = self.tokenizer.encode(
-            prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
+            prompts, pad_to_multiple_of=1 # self.model.cache.pad_sequence_stride
         )
         token_ids = torch.tensor(token_ids, device=self.model.device)
         seq_lens = torch.tensor(seq_lens, device=self.model.device)
@@ -80,6 +81,7 @@ def __init__(
         cache_state: list[torch.Tensor],
     ):
         self.bs = token_ids.shape[0]
+        print("bs:", self.bs)
         assert seq_lens.shape[0] == self.bs
         self.parent = parent
         self.token_ids = token_ids
@@ -227,10 +229,11 @@ def main():
     dataset = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)
     prompts = args.prompt
+    print("prompts:", prompts, type(prompts))
 
     config = LlamaModelConfig(
         hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
-        block_seq_stride=16,
+        block_seq_stride=1, # 16,
         kv_cache_type=args.kv_cache_type,
         device=device,
         activation_dtype=activation_dtype,
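Both edits in this file loosen padding for the direct-cache experiment: in a paged KV cache, each sequence is rounded up to a whole number of blocks of block_seq_stride tokens, so stride 16 can waste up to 15 padded slots per sequence, while stride 1 (together with pad_to_multiple_of=1) stores one token per block and needs no padding. A minimal sketch of that rounding follows; num_cache_blocks is a hypothetical helper for illustration, not a function in this repo.

# Hypothetical helper illustrating block rounding; not part of sharktank.
import math

def num_cache_blocks(seq_len: int, block_seq_stride: int) -> int:
    """Cache blocks needed to hold seq_len tokens at a given stride."""
    return math.ceil(seq_len / block_seq_stride)

assert num_cache_blocks(17, 16) == 2   # rounds up: 15 padded slots
assert num_cache_blocks(17, 1) == 17   # stride 1: one block per token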
sharktank/sharktank/layers/causal_llm.py (1 change: 1 addition & 0 deletions)

@@ -142,5 +142,6 @@ def extract_tokens_from_logits(
         results = []
         for batch, seq_len in enumerate(seq_lens):
             step_logits = logits[batch, seq_len - 1]
+            # print('step_logits:', step_logits[torch.argmax(step_logits)])
             results.append(torch.argmax(step_logits))
         return results
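For context, extract_tokens_from_logits implements greedy decoding: each batch entry takes the argmax over the vocabulary at its last valid position. A self-contained sketch of the same loop, with assumed shapes (batch 2, sequence 8, vocab 32000, llama2's vocabulary size):

# Standalone sketch of the greedy extraction above; shapes are assumed.
import torch

logits = torch.randn(2, 8, 32000)  # [batch, seq, vocab]
seq_lens = [5, 8]                  # valid length per batch entry

results = []
for batch, seq_len in enumerate(seq_lens):
    step_logits = logits[batch, seq_len - 1]   # logits at the final token
    results.append(torch.argmax(step_logits))
print([int(t) for t in results])   # two greedy next-token ids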
sharktank/sharktank/layers/norm.py (2 changes: 1 addition & 1 deletion)

@@ -23,7 +23,7 @@ def __init__(
         theta: Theta,
         *,
         weight_name: str = "weight",
-        epsilon: float = 1e-6,
+        epsilon: float = 1e-5,
         dtype: torch.dtype = torch.float32,
     ):
         super().__init__(theta)
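The epsilon default moves from 1e-6 to 1e-5, matching llama2's RMS-norm epsilon (the Hugging Face Llama-2-7b config sets rms_norm_eps to 1e-5). A minimal sketch of RMS normalization, not this repo's RMSNormLayer, showing where epsilon enters:

# Minimal RMSNorm sketch (not sharktank's RMSNormLayer).
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             epsilon: float = 1e-5) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)
    # epsilon keeps rsqrt finite when the variance is near zero
    return x * torch.rsqrt(variance + epsilon) * weight

x = torch.randn(2, 8)
print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 8])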
1 more changed file (the new test, 462 additions & 0 deletions): diff not rendered in this view.
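Since the test file itself is not shown, here is a hedged, framework-free analogue of the invariant a direct KV cache test exercises, not the commit's test code: decoding one token at a time against an appended K/V cache must reproduce full causal attention over the same sequence.

# Toy analogue of a direct KV cache check; not the commit's test code.
import torch

def attend(q, k, v):
    scale = k.shape[-1] ** 0.5
    return torch.softmax(q @ k.T / scale, dim=-1) @ v

torch.manual_seed(0)
T, D = 6, 16
q, k, v = (torch.randn(T, D) for _ in range(3))

# Incremental decode: append each step's K/V to the cache, then attend.
ks, vs, outs = [], [], []
for t in range(T):
    ks.append(k[t : t + 1])
    vs.append(v[t : t + 1])
    outs.append(attend(q[t : t + 1], torch.cat(ks), torch.cat(vs)))

# Reference: full causal attention over the whole sequence at once.
mask = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
scores = (q @ k.T / D ** 0.5).masked_fill(mask, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v

torch.testing.assert_close(torch.cat(outs), ref)  # cached == full attention
print("direct-cache decode matches full attention")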
