diff --git a/sharktank/sharktank/examples/export_paged_llm_v1.py b/sharktank/sharktank/examples/export_paged_llm_v1.py
index 84b174bba..17193bd17 100644
--- a/sharktank/sharktank/examples/export_paged_llm_v1.py
+++ b/sharktank/sharktank/examples/export_paged_llm_v1.py
@@ -43,6 +43,12 @@ def main():
         type=lambda arg: [int(bs) for bs in arg.split(",")],
         default="4",
     )
+    parser.add_argument(
+        "--block-seq-stride",
+        help="Block sequence stride for a paged KV cache.",
+        type=int,
+        default=LlamaModelConfig.default_block_seq_stride,
+    )
     parser.add_argument(
         "--verbose",
         help="Include verbose logging",
@@ -59,6 +65,16 @@ def main():
         default="decomposed",
         choices=["decomposed", "torch"],
     )
+    parser.add_argument(
+        "--attention-dtype",
+        type=str,
+        default=dtype_to_serialized_name(LlamaModelConfig.default_attention_dtype),
+    )
+    parser.add_argument(
+        "--activation-dtype",
+        type=str,
+        default=dtype_to_serialized_name(LlamaModelConfig.default_activation_dtype),
+    )
     args = cli.parse(parser)
 
     dataset_type = cli.get_input_data_files(args)
@@ -73,11 +89,14 @@ def main():
     )
     llama_config = LlamaModelConfig(
         hp,
+        block_seq_stride=args.block_seq_stride,
         tensor_parallelism_size=tensor_parallelism_size,
         use_hf=False,
         static_tables=False,  # Rely on the compiler for hoisting tables.
         kv_cache_type="direct" if args.bs == [1] else "paged",
         attention_kernel=args.attention_kernel,
+        attention_dtype=serialized_name_to_dtype(args.attention_dtype),
+        activation_dtype=serialized_name_to_dtype(args.activation_dtype),
     )
 
     if llama_config.hp.expert_count:
diff --git a/sharktank/sharktank/layers/configs/llm_configs.py b/sharktank/sharktank/layers/configs/llm_configs.py
index 4dee24701..d4a8a2f77 100644
--- a/sharktank/sharktank/layers/configs/llm_configs.py
+++ b/sharktank/sharktank/layers/configs/llm_configs.py
@@ -15,7 +15,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, Optional, ClassVar
 import torch
 
 __all__ = ["LlamaHParams", "LlamaModelConfig"]
@@ -121,7 +121,8 @@ class LlamaModelConfig:
 
     # Block sequence stride for a paged KV cache. This must divide evenly
     # into the context length.
-    block_seq_stride: int = 16
+    default_block_seq_stride: ClassVar[int] = 16
+    block_seq_stride: int = default_block_seq_stride
 
     # Either "paged" or "direct".
     kv_cache_type: str = "paged"
@@ -130,10 +131,12 @@
     device: Optional[torch.device] = None
 
     # Dtype to use for general FP activations not otherwise configured.
-    activation_dtype: torch.dtype = torch.float16
+    default_activation_dtype: ClassVar[torch.dtype] = torch.float16
+    activation_dtype: torch.dtype = default_activation_dtype
 
     # Dtype to use for attention.
-    attention_dtype: torch.dtype = torch.float16
+    default_attention_dtype: ClassVar[torch.dtype] = torch.float16
+    attention_dtype: torch.dtype = default_attention_dtype
 
     # How many devices are involved for tensor parallel sharding.
     # If greater than 1, the model will expect sharded model parameters and function
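
Reviewer note: the change keeps a single source of truth for each default by publishing it as a `ClassVar` on `LlamaModelConfig`, which the argparse declarations then reference; the dtype flags round-trip through the existing `dtype_to_serialized_name` / `serialized_name_to_dtype` helpers between `torch.dtype` values and their string names. A minimal, standalone sketch of the same pattern (hypothetical names, not the actual sharktank code):

```python
# Sketch of the ClassVar-default pattern used in this PR (assumed,
# simplified names): the dataclass publishes its defaults as ClassVars
# so CLI code can reference them without instantiating the config.
import argparse
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class ExampleConfig:
    # ClassVar annotations are ignored by @dataclass, so they do not become
    # constructor arguments; they only expose the default value on the class.
    default_block_seq_stride: ClassVar[int] = 16
    block_seq_stride: int = default_block_seq_stride


parser = argparse.ArgumentParser()
parser.add_argument(
    "--block-seq-stride",
    type=int,
    # argparse reads its default from the class, so changing the dataclass
    # default also changes the CLI default in one place.
    default=ExampleConfig.default_block_seq_stride,
)
args = parser.parse_args([])  # empty argv just for illustration
config = ExampleConfig(block_seq_stride=args.block_seq_stride)
print(config.block_seq_stride)  # -> 16
```

Because the `ClassVar` fields are excluded from the generated `__init__`, existing call sites that construct `LlamaModelConfig` are unaffected; only callers that want to mirror the defaults (such as the new CLI flags here) need to reference them.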