Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add block seq len and attention/activation dtype args to LLM export script #391

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions sharktank/sharktank/examples/export_paged_llm_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ def main():
type=lambda arg: [int(bs) for bs in arg.split(",")],
default="4",
)
parser.add_argument(
"--block-seq-stride",
help="Block sequence stride for a paged KV cache.",
type=int,
default=LlamaModelConfig.default_block_seq_stride,
)
parser.add_argument(
"--verbose",
help="Include verbose logging",
Expand All @@ -59,6 +65,16 @@ def main():
default="decomposed",
choices=["decomposed", "torch"],
)
parser.add_argument(
"--attention-dtype",
type=str,
default=dtype_to_serialized_name(LlamaModelConfig.default_attention_dtype),
)
parser.add_argument(
"--activation-dtype",
type=str,
default=dtype_to_serialized_name(LlamaModelConfig.default_activation_dtype),
)

args = cli.parse(parser)
dataset_type = cli.get_input_data_files(args)
Expand All @@ -73,11 +89,14 @@ def main():
)
llama_config = LlamaModelConfig(
hp,
block_seq_stride=args.block_seq_stride,
tensor_parallelism_size=tensor_parallelism_size,
use_hf=False,
static_tables=False, # Rely on the compiler for hoisting tables.
kv_cache_type="direct" if args.bs == [1] else "paged",
attention_kernel=args.attention_kernel,
attention_dtype=serialized_name_to_dtype(args.attention_dtype),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This depends on the other PR. It should be included there and not bleed into this PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already merged the other PR.

activation_dtype=serialized_name_to_dtype(args.activation_dtype),
)

if llama_config.hp.expert_count:
Expand Down
11 changes: 7 additions & 4 deletions sharktank/sharktank/layers/configs/llm_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""

from dataclasses import dataclass
from typing import Any, Optional
from typing import Any, Optional, ClassVar
import torch

__all__ = ["LlamaHParams", "LlamaModelConfig"]
Expand Down Expand Up @@ -121,7 +121,8 @@ class LlamaModelConfig:

# Block sequence stride for a paged KV cache. This must divide evenly
# into the context length.
block_seq_stride: int = 16
default_block_seq_stride: ClassVar[int] = 16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need ClassVar here? Seeing the default_block_seq_stride default feels bad.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted a way to reference the default from the export script, so that the CLI argument has the same default.

block_seq_stride: int = default_block_seq_stride

# Either "paged" or "direct".
kv_cache_type: str = "paged"
Expand All @@ -130,10 +131,12 @@ class LlamaModelConfig:
device: Optional[torch.device] = None

# Dtype to use for general FP activations not otherwise configured.
activation_dtype: torch.dtype = torch.float16
default_activation_dtype: ClassVar[torch.dtype] = torch.float16
activation_dtype: torch.dtype = default_activation_dtype

# Dtype to use for attention.
attention_dtype: torch.dtype = torch.float16
default_attention_dtype: ClassVar[torch.dtype] = torch.float16
attention_dtype: torch.dtype = default_attention_dtype

# How many devices are involved for tensor parallel sharding.
# If greater than 1, the model will expect sharded model parameters and function
Expand Down
Loading