In LLM export make dynamic dims compatible with non-defaults
sogartar committed Nov 1, 2024
1 parent cf3bf54 commit 6dfc7a4
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -13,6 +13,7 @@
 
 from sharktank.layers import *
 from sharktank.types import *
+from sharktank.utils.math import ceildiv
 
 # TODO: Should be using a base class with the protocol supported.
 from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
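For context on the helper imported above: the sketch below assumes ceildiv is plain ceiling division (the sharktank.utils.math implementation is not shown in this diff), and the context_length and block_seq_stride values are made-up stand-ins, used only to show the bound this commit computes for the dynamic block dimension.

# Sketch (assumption): ceildiv is ordinary ceiling division; the sizes below
# are illustrative, not taken from any real model config.
def ceildiv(a: int, b: int) -> int:
    # ceil(a / b) using integer arithmetic only
    return -(-a // b)

context_length = 4096      # stand-in for hp.context_length
block_seq_stride = 16      # stand-in for llama_config.block_seq_stride

# Upper bound used for the dynamic block dimension in this commit:
block_dim_max = ceildiv(context_length, block_seq_stride) - 1
print(block_dim_max)       # ceildiv(4096, 16) = 256, minus 1 gives 255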
@@ -152,13 +153,18 @@ def repack_cache(cache, shard_dim):
         return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
 
     def generate_batch_prefill(bs: int):
-        tokens = torch.empty(bs, 64, dtype=torch.int64)
-        seq_lens = torch.empty(bs, dtype=torch.int64)
-        seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
-        block_dim = torch.export.Dim(
-            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
-        )
+        # torch.export.Dim would make min at least 2
+        block_dim_min = 2
+        block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
+        block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
         sl_dim = llama_config.block_seq_stride * block_dim
+        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
+        tokens = torch.empty(
+            bs,
+            seq_block_ids.shape[1] * llama_config.block_seq_stride,
+            dtype=torch.int64,
+        )
+        seq_lens = torch.empty(bs, dtype=torch.int64)
 
         cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
             model, llama_config.tensor_parallelism_size
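To make the intent of the new Dim bounds concrete, here is a minimal, self-contained sketch of the same pattern: an explicit min/max Dim for the block count, a derived sequence-length dim, and example inputs sized to the minimum. The module, batch size, stride, and bounds are all illustrative assumptions, not the exporter's real signature.

# Minimal sketch, not the real exporter: a toy module exported with a dynamic
# block dimension and a derived sequence-length dimension.
import torch

class ToyModule(torch.nn.Module):
    # Stand-in for the real prefill function; only the input shapes matter here.
    def forward(self, tokens, seq_block_ids):
        return tokens.sum() + seq_block_ids.sum()

block_seq_stride = 16                               # assumed stride
block_dim = torch.export.Dim("block", min=2, max=255)
sl_dim = block_seq_stride * block_dim               # derived dim: seq_len = stride * blocks

bs = 4
# Example inputs sized to the minimum of 2 blocks, matching the min-of-2
# behavior the diff's comment attributes to torch.export.Dim.
seq_block_ids = torch.empty(bs, 2, dtype=torch.int64)
tokens = torch.empty(bs, 2 * block_seq_stride, dtype=torch.int64)

exported = torch.export.export(
    ToyModule(),
    (tokens, seq_block_ids),
    dynamic_shapes={
        "tokens": {1: sl_dim},
        "seq_block_ids": {1: block_dim},
    },
)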
@@ -221,13 +227,18 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
         return logits
 
     def generate_batch_decode(bs: int):
-        tokens = torch.ones(bs, 1, dtype=torch.int64)
-        seq_lens = torch.ones(bs, dtype=torch.int64)
-        start_positions = torch.ones(bs, dtype=torch.int64)
-        seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
-        block_dim = torch.export.Dim(
-            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
-        )
+        # torch.export.Dim would make min at least 2
+        block_dim_min = 2
+        block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
+        block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
+        tokens = torch.empty(
+            bs,
+            1,
+            dtype=torch.int64,
+        )
+        seq_lens = torch.empty(bs, dtype=torch.int64)
+        start_positions = torch.ones(bs, dtype=torch.int64)
+        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
 
         (
             cache_state,
