
Config refactor: Remove gemm_plugin from config, add quantization config with calibration size #1398

Merged · 11 commits · Feb 21, 2025
truss/base/trt_llm_config.py · 43 changes: 38 additions & 5 deletions
@@ -21,13 +21,16 @@


 class TrussTRTLLMModel(str, Enum):
+    ENCODER = "encoder"
+    DECODER = "decoder"
+    # auto migrated settings
+    PALMYRA = "palmyra"
+    QWEN = "qwen"
     LLAMA = "llama"
     MISTRAL = "mistral"
     DEEPSEEK = "deepseek"
+    # deprecated workflow
     WHISPER = "whisper"
-    QWEN = "qwen"
-    ENCODER = "encoder"
-    PALMYRA = "palmyra"


 class TrussTRTLLMQuantizationType(str, Enum):
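Because the members keep their string values, legacy configs that name a specific architecture still parse, while new configs can rely on the generic decoder default. A minimal sketch of what this implies, assuming standard str-Enum semantics (not part of the diff itself):

from truss.base.trt_llm_config import TrussTRTLLMModel

# "llama" still resolves to a member, so auto-migrated configs keep working.
assert TrussTRTLLMModel("llama") is TrussTRTLLMModel.LLAMA

# The new generic entry point, used as the build default further below.
assert TrussTRTLLMModel.DECODER.value == "decoder"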
Expand All @@ -43,11 +46,38 @@ class TrussTRTLLMQuantizationType(str, Enum):

class TrussTRTLLMPluginConfiguration(BaseModel):
paged_kv_cache: bool = True
gemm_plugin: str = "auto"
use_paged_context_fmha: bool = True
use_fp8_context_fmha: bool = False


class TrussTRTQuantizationConfiguration(BaseModel):
"""Configuration for quantization of TRT models

Args:
calib_size (int, optional): Size of calibration dataset. Defaults to 1024.
recommended to increase for production runs (e.g. 1536), or decrease e.g. to 256 for quick testing.
calib_dataset (str, optional): Hugginface dataset to use for calibration. Defaults to 'cnn_dailymail'.
uses split='train' and quantized based on 'text' column.
calib_max_seq_length (int, optional): Maximum sequence length for calibration. Defaults to 2048.
"""

calib_size: int = 1024
calib_dataset: str = "cnn_dailymail"
calib_max_seq_length: int = 2048

def __init__(self, **data):
super().__init__(**data)
self.validate_cuda_friendly("calib_size")
self.validate_cuda_friendly("calib_max_seq_length")

def validate_cuda_friendly(self, key):
value = getattr(self, key)
if value < 64 or value > 16384:
raise ValueError(f"{key} must be between 64 and 16384, but got {value}")
elif value % 64 != 0:
raise ValueError(f"{key} must be a multiple of 64, but got {value}")


class CheckpointSource(str, Enum):
HF = "HF"
GCS = "GCS"
@@ -97,7 +127,7 @@ class TrussTRTLLMRuntimeConfiguration(BaseModel):


 class TrussTRTLLMBuildConfiguration(BaseModel):
-    base_model: TrussTRTLLMModel
+    base_model: TrussTRTLLMModel = TrussTRTLLMModel.DECODER
     max_seq_len: int
     max_batch_size: int = 256
     max_num_tokens: int = 8192
@@ -109,6 +139,9 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
     quantization_type: TrussTRTLLMQuantizationType = (
         TrussTRTLLMQuantizationType.NO_QUANT
     )
+    quantization_config: TrussTRTQuantizationConfiguration = (
+        TrussTRTQuantizationConfiguration()
+    )
     tensor_parallel_count: int = 1
     pipeline_parallel_count: int = 1
     plugin_configuration: TrussTRTLLMPluginConfiguration = (
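With the new field, a build config picks up sensible calibration defaults without any extra keys, and overrides nest under quantization_config. A hedged sketch of constructing the refactored config (the checkpoint_repository value is borrowed from the test fixtures below and passed as a plain dict, assuming pydantic coerces it into the nested model):

from truss.base.trt_llm_config import (
    TrussTRTLLMBuildConfiguration,
    TrussTRTQuantizationConfiguration,
)

# base_model now defaults to TrussTRTLLMModel.DECODER and gemm_plugin is gone;
# calibration knobs live under quantization_config.
build = TrussTRTLLMBuildConfiguration(
    max_seq_len=4096,
    checkpoint_repository={"source": "HF", "repo": "meta/llama4-500B"},  # as in the fixtures
    quantization_config=TrussTRTQuantizationConfiguration(
        calib_size=1536,  # larger calibration set, per the docstring's production hint
    ),
)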
truss/tests/conftest.py · 2 changes: 0 additions & 2 deletions
@@ -825,7 +825,6 @@ def trtllm_spec_dec_config_full(trtllm_config) -> Dict[str, Any]:
         "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
         "plugin_configuration": {
             "paged_kv_cache": True,
-            "gemm_plugin": "auto",
             "use_paged_context_fmha": True,
         },
         "speculator": {
@@ -857,7 +856,6 @@ def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
         "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
         "plugin_configuration": {
             "paged_kv_cache": True,
-            "gemm_plugin": "auto",
             "use_paged_context_fmha": True,
         },
         "speculator": {