From 6dfc7a4955d590ed6d6777a13ca93f0278f14872 Mon Sep 17 00:00:00 2001
From: Boian Petkantchin
Date: Thu, 31 Oct 2024 06:02:08 -0500
Subject: [PATCH] In LLM export make dynamic dims compatible with non-defaults

---
 .../sharktank/examples/export_paged_llm_v1.py | 35 ++++++++++++-------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 84b174bba..f22b2ccbd 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -13,6 +13,7 @@
 
 from sharktank.layers import *
 from sharktank.types import *
+from sharktank.utils.math import ceildiv
 
 # TODO: Should be using a base class with the protocol supported.
 from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
@@ -152,13 +153,18 @@ def repack_cache(cache, shard_dim):
         return [SplitPrimitiveTensor(ts=c, shard_dim=shard_dim) for c in cache]
 
     def generate_batch_prefill(bs: int):
-        tokens = torch.empty(bs, 64, dtype=torch.int64)
-        seq_lens = torch.empty(bs, dtype=torch.int64)
-        seq_block_ids = torch.empty(bs, 4, dtype=torch.int64)
-        block_dim = torch.export.Dim(
-            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
-        )
+        # torch.export.Dim would make min at least 2
+        block_dim_min = 2
+        block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
+        block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
         sl_dim = llama_config.block_seq_stride * block_dim
+        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
+        tokens = torch.empty(
+            bs,
+            seq_block_ids.shape[1] * llama_config.block_seq_stride,
+            dtype=torch.int64,
+        )
+        seq_lens = torch.empty(bs, dtype=torch.int64)
 
         cache, cache_shard_dim, cache_dynamic_shapes, arg_affinities = setup_cache(
             model, llama_config.tensor_parallelism_size
@@ -221,13 +227,18 @@ def _(model, tokens, seq_lens, seq_block_ids, cs):
         return logits
 
     def generate_batch_decode(bs: int):
-        tokens = torch.ones(bs, 1, dtype=torch.int64)
-        seq_lens = torch.ones(bs, dtype=torch.int64)
-        start_positions = torch.ones(bs, dtype=torch.int64)
-        seq_block_ids = torch.zeros(bs, 4, dtype=torch.int64)
-        block_dim = torch.export.Dim(
-            "block", max=(hp.context_length - 1) // llama_config.block_seq_stride
+        # torch.export.Dim would make min at least 2
+        block_dim_min = 2
+        block_dim_max = ceildiv(hp.context_length, llama_config.block_seq_stride) - 1
+        block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)
+        tokens = torch.empty(
+            bs,
+            1,
+            dtype=torch.int64,
         )
+        seq_lens = torch.empty(bs, dtype=torch.int64)
+        start_positions = torch.ones(bs, dtype=torch.int64)
+        seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
 
         (
             cache_state,
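
For illustration, a minimal standalone sketch of the bound computation the patch introduces. `context_length`, `block_seq_stride`, and `bs` are placeholder example values (not taken from the patch), and `ceildiv` is assumed to behave like `sharktank.utils.math.ceildiv`, i.e. plain ceiling division:

import torch


def ceildiv(a: int, b: int) -> int:
    # Assumed to match sharktank.utils.math.ceildiv: ceiling division on ints.
    return -(a // -b)


# Placeholder example values standing in for hp.context_length,
# llama_config.block_seq_stride, and the batch size.
context_length = 4096
block_seq_stride = 16
bs = 4

# torch.export.Dim requires min >= 2, so the lower bound is set explicitly.
block_dim_min = 2
# ceildiv(n, s) - 1 equals (n - 1) // s, i.e. the previous upper bound, written
# so it reads naturally for non-default strides and context lengths.
block_dim_max = ceildiv(context_length, block_seq_stride) - 1
block_dim = torch.export.Dim("block", min=block_dim_min, max=block_dim_max)

# Example inputs sized from block_dim_min instead of the hard-coded 4 / 64,
# so their static shapes fall inside the dynamic range declared above.
seq_block_ids = torch.empty(bs, block_dim_min, dtype=torch.int64)
tokens = torch.empty(bs, seq_block_ids.shape[1] * block_seq_stride, dtype=torch.int64)

print(block_dim_min, block_dim_max, tuple(tokens.shape))  # 2 255 (4, 32)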