diff --git a/sharktank/sharktank/evaluate/perplexity_iree.py b/sharktank/sharktank/evaluate/perplexity_iree.py
index f42a4cf4a..3d075fc6b 100644
--- a/sharktank/sharktank/evaluate/perplexity_iree.py
+++ b/sharktank/sharktank/evaluate/perplexity_iree.py
@@ -69,6 +69,8 @@ def __init__(
         tensor_parallelism_size,
         attention_kernel,
         block_seq_stride,
+        activation_dtype=torch.float16,
+        attention_dtype=torch.float16,
     ):
         self.torch_device = torch_device
         self.iree_device = iree_device
@@ -76,8 +78,8 @@ def __init__(
         self.iree_hal_target_device = iree_hal_target_device
         self.kv_cache_type = kv_cache_type
         self.block_seq_stride = block_seq_stride
-        self.activation_dtype = torch.float16
-        self.attention_dtype = torch.float16
+        self.activation_dtype = activation_dtype
+        self.attention_dtype = attention_dtype
         self.tensor_parallelism_size = tensor_parallelism_size
         self.attention_kernel = attention_kernel
@@ -430,18 +432,6 @@ def run_perplexity(
 def main(argv):
     parser = cli.create_parser()
-    parser.add_argument(
-        "--attention-kernel",
-        type=str,
-        default="decomposed",
-        choices=["decomposed", "torch_sdpa"],
-    )
-    parser.add_argument(
-        "--block-seq-stride",
-        help="Block sequence stride for paged KV cache, must divide evenly into the context length",
-        type=int,
-        default=32,
-    )
     parser.add_argument("--iree-device", help="List an IREE device (e.g., 'hip://0')")
     parser.add_argument(
         "--iree-hip-target",
@@ -455,29 +445,29 @@ def main(argv):
         default="hip",
         help="Specify the iree-hal target device (e.g., hip, cpu)",
     )
-    parser.add_argument("--kv-cache-type", default="paged", help="KV cache type")
     parser.add_argument(
         "--num-prompts",
         type=int,
         default=100,
         help="Number of prompts for perplexity test (1 to 100)",
     )
-    parser.add_argument(
-        "--tensor-parallelism-size",
-        type=int,
-        default=1,
-        help="Number of devices for tensor parallel sharding",
-    )
-    parser.add_argument("--torch-device", help="Torch device (or default)")
+    cli.add_model_options(parser)
     cli.add_tokenizer_options(parser)
     cli.add_input_dataset_options(parser)
     args = cli.parse(parser, args=argv)

-    torch_device = torch.device(args.torch_device) if args.torch_device else None
+    torch_device = torch.device(args.device) if args.device else None
     weight_path = cli.get_input_dataset(args)
     tokenizer = cli.get_tokenizer(args)

+    # Override flag if dataset disagrees
+    tensor_parallelism_size = (
+        weight_path.properties["tensor_parallelism_size"]
+        if "tensor_parallelism_size" in weight_path.properties
+        else args.tensor_parallelism_size
+    )
+
     ppl = run_perplexity(
         weight_path=weight_path,
         weight_path_str=str(args.irpa_file),
@@ -486,8 +476,7 @@ def main(argv):
         iree_device=args.iree_device,
         iree_hip_target=args.iree_hip_target,
         iree_hal_target_device=args.iree_hal_target_device,
-        kv_cache_type=args.kv_cache_type,
-        tensor_parallelism_size=args.tensor_parallelism_size,
+        tensor_parallelism_size=tensor_parallelism_size,
         attention_kernel=args.attention_kernel,
         num_prompts=args.num_prompts,
         block_seq_stride=args.block_seq_stride,
diff --git a/sharktank/sharktank/evaluate/perplexity_torch.py b/sharktank/sharktank/evaluate/perplexity_torch.py
index 9ece1e4a0..4974de1c2 100644
--- a/sharktank/sharktank/evaluate/perplexity_torch.py
+++ b/sharktank/sharktank/evaluate/perplexity_torch.py
@@ -58,12 +58,10 @@ class Perplexity_torch:
     def __init__(
         self,
         device,
-        kv_cache_type,
        activation_dtype=torch.float32,
        attention_dtype=torch.float32,
     ):
         self.device = device
-        self.kv_cache_type = kv_cache_type
         self.activation_dtype = activation_dtype
         self.attention_dtype = attention_dtype
@@ -115,7 +113,6 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern
         self.config = LlamaModelConfig(
             hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
             block_seq_stride=16,
-            kv_cache_type=self.kv_cache_type,
             device=self.device,
             activation_dtype=self.activation_dtype,
             attention_dtype=self.attention_dtype,
@@ -298,14 +295,13 @@ def run_perplexity_torch(
     dataset,
     tokenizer,
     device,
-    kv_cache_type,
     tensor_parallelism_size,
     attention_kernel,
     num_prompts,
 ):
     start = time.time()
-    perplexity = Perplexity_torch(device=device, kv_cache_type=kv_cache_type)
+    perplexity = Perplexity_torch(device=device)
     perplexity.get_prompts(num_prompts=num_prompts)
     perplexity.load_model(dataset, tokenizer, tensor_parallelism_size, attention_kernel)
     ppl = perplexity.get_perplexity()
diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 24ec55cf5..32680ac3c 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -45,12 +45,6 @@ def main():
         type=lambda arg: [int(bs) for bs in arg.split(",")],
         default="4",
     )
-    parser.add_argument(
-        "--block-seq-stride",
-        help="Block sequence stride for paged KV cache, must divide evenly into the context length",
-        type=int,
-        default=32,
-    )
     parser.add_argument(
         "--verbose",
         help="Include verbose logging",
diff --git a/sharktank/sharktank/examples/paged_llm_v1.py b/sharktank/sharktank/examples/paged_llm_v1.py
index 768575441..07b707211 100644
--- a/sharktank/sharktank/examples/paged_llm_v1.py
+++ b/sharktank/sharktank/examples/paged_llm_v1.py
@@ -42,8 +42,8 @@ def __init__(
     ):
         self.model = model
         self.tokenizer = tokenizer
-        if model.cache.is_paged:
-            self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
+        if self.model.config.kv_cache_type == "paged":
+            self.shared_cache_state = model.cache.allocate(page_cache_size)
             self.free_pages = list(range(1, page_cache_size))
         else:
             self.shared_cache_state = None
@@ -63,18 +63,18 @@ def begin_batch(self, prompts: list[str]):
         if self.shared_cache_state is not None:
             cache_state = self.shared_cache_state
         else:
-            cache_state = self.model.cache.direct.allocate(bs=len(prompts))
+            cache_state = self.model.cache.allocate(bs=len(prompts))
         return Batch(self, token_ids, seq_lens, cache_state)

     def alloc_page(self) -> int:
-        if self.model.cache.is_direct:
+        if self.model.config.kv_cache_type == "direct":
             # We don't allocate block ids for the direct cache.
             return 0
         return self.free_pages.pop()

     def release_page(self, index: int):
-        if self.model.cache.is_direct:
+        if self.model.config.kv_cache_type == "direct":
             return
         self.free_pages.append(index)
@@ -238,12 +238,6 @@ def main():
         "--save_intermediates_path",
         help="save module forward outputs to safetensors, ex: run_0 will save to run_0_prefill.savetensors",
     )
-    parser.add_argument(
-        "--tensor-parallelism-size",
-        type=int,
-        default=1,
-        help="How many devices are involved for tensor parallel sharding.",
-    )
     cli.add_input_dataset_options(parser)
     cli.add_tokenizer_options(parser)
     cli.add_quantization_options(parser)
@@ -255,7 +249,7 @@ def main():
     prompts = args.prompt
     config = LlamaModelConfig(
         hp=configs.LlamaHParams.from_gguf_props(dataset.properties),
-        block_seq_stride=16,
+        block_seq_stride=args.block_seq_stride,
         device=device,
         activation_dtype=args.activation_dtype,
         attention_dtype=args.activation_dtype,
diff --git a/sharktank/sharktank/export_layer/export_paged_attention.py b/sharktank/sharktank/export_layer/export_paged_attention.py
index e9d284111..644d44258 100644
--- a/sharktank/sharktank/export_layer/export_paged_attention.py
+++ b/sharktank/sharktank/export_layer/export_paged_attention.py
@@ -44,25 +44,14 @@ def paged_attention(
     # Full sequence length.
     kv_seq_len = seq_block_ids.shape[1] * attention_block.cache.block_seq_stride

-    if attention_block.cache.is_paged:
-        xk, xv = attention_block.transact_cache_paged(
-            xk_cache_update=xk,
-            xv_cache_update=xv,
-            seq_block_ids=seq_block_ids,
-            kv_seq_len=kv_seq_len,
-            start_positions=start_positions,
-            cache_state=cache_state,
-        )
-    elif attention_block.cache.is_direct:
-        xk, xv = attention_block.transact_cache_direct(
-            xk_cache_update=xk,
-            xv_cache_update=xv,
-            start_positions=start_positions,
-            kv_seq_len=kv_seq_len,
-            cache_state=cache_state,
-        )
-    else:
-        raise NotImplementedError(f"Unsupported KV cache type: {type(cache)}")
+    xk, xv = attention_block.transact_cache(
+        xk_cache_update=xk,
+        xv_cache_update=xv,
+        seq_block_ids=seq_block_ids,
+        kv_seq_len=kv_seq_len,
+        start_positions=start_positions,
+        cache_state=cache_state,
+    )

     # Expand kv heads for GQA.
     gqa_n_rep = attention_block.head_count // attention_block.head_count_kv
diff --git a/sharktank/sharktank/utils/cli.py b/sharktank/sharktank/utils/cli.py
index 3dc946817..52aa379a6 100644
--- a/sharktank/sharktank/utils/cli.py
+++ b/sharktank/sharktank/utils/cli.py
@@ -108,6 +108,12 @@ def add_model_options(parser: argparse.ArgumentParser):
         default=1,
         help="Number of devices for tensor parallel sharding. Will be overridden by dataset.properties if present",
     )
+    parser.add_argument(
+        "--block-seq-stride",
+        help="Block sequence stride for paged KV cache, must divide evenly into the context length",
+        type=int,
+        default=32,
+    )


 def add_quantization_options(parser: argparse.ArgumentParser):
diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py
index 47d9f0244..4296cb151 100644
--- a/sharktank/sharktank/utils/load_llm.py
+++ b/sharktank/sharktank/utils/load_llm.py
@@ -45,11 +45,11 @@ def begin_batch(
         token_ids = torch.tensor(token_ids, device=self.model.device)
         seq_lens = torch.tensor(seq_lens, device=self.model.device)

-        if self.model.cache.is_paged:
-            cache_state = self.model.cache.paged.allocate(page_cache_size)
+        if self.model.config.kv_cache_type == "paged":
+            cache_state = self.model.cache.allocate(page_cache_size)
             self.free_pages = list(range(1, page_cache_size))
-        else:
-            cache_state = self.model.cache.direct.allocate(bs=len(prompts))
+        elif self.model.config.kv_cache_type == "direct":
+            cache_state = self.model.cache.allocate(bs=1)
         return Batch(self, token_ids, seq_lens, cache_state)

     def begin_eval_batch(
@@ -59,22 +59,22 @@
         self,
         token_batch: torch.tensor,
         seq_lens_batch: torch.tensor,
         bs: int,
         page_cache_size: int = 128,
     ):
-        if self.model.cache.is_paged:
-            cache_state = self.model.cache.paged.allocate(page_cache_size)
+        if self.model.config.kv_cache_type == "paged":
+            cache_state = self.model.cache.allocate(page_cache_size)
             self.free_pages = list(range(1, page_cache_size))
-        else:
-            cache_state = self.model.cache.direct.allocate(bs=bs)
+        elif self.model.config.kv_cache_type == "direct":
+            cache_state = self.model.cache.allocate(bs=1)
         return Batch(self, token_batch, seq_lens_batch, cache_state)

     def alloc_page(self) -> int:
-        if self.model.cache.is_direct:
+        if self.model.config.kv_cache_type == "direct":
             # We don't allocate block ids for the direct cache.
             return 0
         return self.free_pages.pop()

     def release_page(self, index: int):
-        if self.model.cache.is_direct:
+        if self.model.config.kv_cache_type == "direct":
             return
         self.free_pages.append(index)
diff --git a/sharktank/tests/evaluate/perplexity_iree_test.py b/sharktank/tests/evaluate/perplexity_iree_test.py
index dc655af59..8302317cd 100644
--- a/sharktank/tests/evaluate/perplexity_iree_test.py
+++ b/sharktank/tests/evaluate/perplexity_iree_test.py
@@ -86,7 +86,7 @@ def test_llama3_8B_f16(self):
                 f"--iree-hal-target-device={self.iree_hal_target_device}",
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
-                f"--attention-kernel=torch_sdpa",
+                f"--attention-kernel=torch",
                 f"--num-prompts={self.batch_size}",
             ]
         )
@@ -158,7 +158,7 @@ def test_llama3_8B_fp8(self):
                 f"--iree-hal-target-device={self.iree_hal_target_device}",
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size=1",
-                f"--attention-kernel=torch_sdpa",
+                f"--attention-kernel=torch",
                 f"--num-prompts={self.batch_size}",
             ]
         )
@@ -232,7 +232,7 @@ def test_llama3_405B_f16(self):
                 f"--iree-hal-target-device={self.iree_hal_target_device}",
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size={self.tensor_parallelism_size}",
-                f"--attention-kernel=torch_sdpa",
+                f"--attention-kernel=torch",
                 f"--num-prompts={self.batch_size}",
             ]
         )
@@ -304,7 +304,7 @@ def test_llama3_405B_fp8(self):
                 f"--iree-hal-target-device={self.iree_hal_target_device}",
                 f"--iree-hip-target={self.iree_hip_target}",
                 f"--tensor-parallelism-size={self.tensor_parallelism_size}",
-                f"--attention-kernel=torch_sdpa",
+                f"--attention-kernel=torch",
                 f"--num-prompts={self.batch_size}",
             ]
         )
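
For context, a minimal sketch of how the consolidated options might be exercised after this change, assuming only the entry point and flag names visible above; the IRPA path and HIP target value are placeholders, and the tokenizer flags registered by cli.add_tokenizer_options() are omitted because their names do not appear in this diff:

# Hypothetical invocation sketch, not part of the change itself.
# --attention-kernel, --block-seq-stride, and --tensor-parallelism-size now
# come from cli.add_model_options() instead of per-script parser definitions.
from sharktank.evaluate import perplexity_iree

perplexity_iree.main(
    [
        "--irpa-file=/path/to/model.irpa",  # placeholder dataset path
        "--iree-device=hip://0",
        "--iree-hip-target=gfx942",  # placeholder HIP target
        "--iree-hal-target-device=hip",
        "--attention-kernel=torch",
        "--block-seq-stride=32",
        "--tensor-parallelism-size=1",
        "--num-prompts=4",
        # ...plus the tokenizer flags added by cli.add_tokenizer_options(),
        # whose names are not shown in this diff.
    ]
)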