Skip to content

Commit

Permalink
Fix bug in decode, add direct/paged 7b tests
Browse files (browse the repository at this point in the history)
Signed-off-by: aviator19941 <[email protected]>
  • Loading branch information
aviator19941 committed Aug 22, 2024
1 parent 5ebced9 commit 4e81822
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 429 deletions.
6 changes: 2 additions & 4 deletions in sharktank/sharktank/examples/paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def block_seq_stride(self) -> int:

def begin_batch(self, prompts: list[str]):
token_ids, seq_lens = self.tokenizer.encode(
prompts, pad_to_multiple_of=1 # self.model.cache.pad_sequence_stride
prompts, pad_to_multiple_of=self.model.cache.pad_sequence_stride
)
token_ids = torch.tensor(token_ids, device=self.model.device)
seq_lens = torch.tensor(seq_lens, device=self.model.device)
Expand Down Expand Up @@ -81,7 +81,6 @@ def __init__(
cache_state: list[torch.Tensor],
):
self.bs = token_ids.shape[0]
print("bs:", self.bs)
assert seq_lens.shape[0] == self.bs
self.parent = parent
self.token_ids = token_ids
Expand Down Expand Up @@ -229,11 +228,10 @@ def main():
dataset = cli.get_input_dataset(args)
tokenizer = cli.get_tokenizer(args)
prompts = args.prompt
print("prompts:", prompts, type(prompts))

config = LlamaModelConfig(
hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
block_seq_stride=1, # 16,
block_seq_stride=16,
kv_cache_type=args.kv_cache_type,
device=device,
activation_dtype=activation_dtype,
Expand Down
1 change: 0 additions & 1 deletion in sharktank/sharktank/layers/causal_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,5 @@ def extract_tokens_from_logits(
results = []
for batch, seq_len in enumerate(seq_lens):
step_logits = logits[batch, seq_len - 1]
# print('step_logits:', step_logits[torch.argmax(step_logits)])
results.append(torch.argmax(step_logits))
return results
2 changes: 1 addition & 1 deletion in sharktank/sharktank/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
theta: Theta,
*,
weight_name: str = "weight",
epsilon: float = 1e-5,
epsilon: float = 1e-6,
dtype: torch.dtype = torch.float32,
):
super().__init__(theta)
Expand Down
2 changes: 1 addition & 1 deletion in sharktank/sharktank/models/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def transact_cache_paged(
xv_cache_update,
],
transformer_block_index=self.block_index,
seq_positions=start_positions + 1,
seq_positions=start_positions,
page_ids=seq_block_ids,
)

Expand Down
Loading

0 comments on commit 4e81822

Please sign in to comment.