Skip to content

Commit

Permalink
Config refactor: Remove gemm_plugin from config, add quantization con…
Browse files Browse the repository at this point in the history
…fig with calibration size (#1398)

* Update trt_llm_config.py

* add decoder

* rm conftest

* add quantization config

* add quantization config

* add config migration

* TrussTRTLLMModel

* make trt-llm.base_model config optional
  • Loading branch information
michaelfeil authored Feb 21, 2025
1 parent 26deabf commit 0970b7d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.63"
version = "0.9.64rc001"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down Expand Up @@ -189,7 +189,7 @@ markers = [
addopts = "--ignore=smoketests"

[tool.ruff]
src = ["truss", "truss-chains"]
src = ["truss", "truss-chains", "truss-utils"]
target-version = "py38"
line-length = 88
lint.extend-select = [
Expand Down
43 changes: 38 additions & 5 deletions truss/base/trt_llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@


class TrussTRTLLMModel(str, Enum):
ENCODER = "encoder"
DECODER = "decoder"
# auto migrated settings
PALMYRA = "palmyra"
QWEN = "qwen"
LLAMA = "llama"
MISTRAL = "mistral"
DEEPSEEK = "deepseek"
# deprecated workflow
WHISPER = "whisper"
QWEN = "qwen"
ENCODER = "encoder"
PALMYRA = "palmyra"


class TrussTRTLLMQuantizationType(str, Enum):
Expand All @@ -43,11 +46,38 @@ class TrussTRTLLMQuantizationType(str, Enum):

class TrussTRTLLMPluginConfiguration(BaseModel):
paged_kv_cache: bool = True
gemm_plugin: str = "auto"
use_paged_context_fmha: bool = True
use_fp8_context_fmha: bool = False


class TrussTRTQuantizationConfiguration(BaseModel):
"""Configuration for quantization of TRT models
Args:
calib_size (int, optional): Size of calibration dataset. Defaults to 1024.
recommended to increase for production runs (e.g. 1536), or decrease e.g. to 256 for quick testing.
calib_dataset (str, optional): Hugginface dataset to use for calibration. Defaults to 'cnn_dailymail'.
uses split='train' and quantized based on 'text' column.
calib_max_seq_length (int, optional): Maximum sequence length for calibration. Defaults to 2048.
"""

calib_size: int = 1024
calib_dataset: str = "cnn_dailymail"
calib_max_seq_length: int = 2048

def __init__(self, **data):
super().__init__(**data)
self.validate_cuda_friendly("calib_size")
self.validate_cuda_friendly("calib_max_seq_length")

def validate_cuda_friendly(self, key):
value = getattr(self, key)
if value < 64 or value > 16384:
raise ValueError(f"{key} must be between 64 and 16384, but got {value}")
elif value % 64 != 0:
raise ValueError(f"{key} must be a multiple of 64, but got {value}")


class CheckpointSource(str, Enum):
HF = "HF"
GCS = "GCS"
Expand Down Expand Up @@ -97,7 +127,7 @@ class TrussTRTLLMRuntimeConfiguration(BaseModel):


class TrussTRTLLMBuildConfiguration(BaseModel):
base_model: TrussTRTLLMModel
base_model: TrussTRTLLMModel = TrussTRTLLMModel.DECODER
max_seq_len: int
max_batch_size: int = 256
max_num_tokens: int = 8192
Expand All @@ -109,6 +139,9 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
quantization_type: TrussTRTLLMQuantizationType = (
TrussTRTLLMQuantizationType.NO_QUANT
)
quantization_config: TrussTRTQuantizationConfiguration = (
TrussTRTQuantizationConfiguration()
)
tensor_parallel_count: int = 1
pipeline_parallel_count: int = 1
plugin_configuration: TrussTRTLLMPluginConfiguration = (
Expand Down
2 changes: 0 additions & 2 deletions truss/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,6 @@ def trtllm_spec_dec_config_full(trtllm_config) -> Dict[str, Any]:
"checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
"plugin_configuration": {
"paged_kv_cache": True,
"gemm_plugin": "auto",
"use_paged_context_fmha": True,
},
"speculator": {
Expand Down Expand Up @@ -857,7 +856,6 @@ def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
"checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
"plugin_configuration": {
"paged_kv_cache": True,
"gemm_plugin": "auto",
"use_paged_context_fmha": True,
},
"speculator": {
Expand Down

0 comments on commit 0970b7d

Please sign in to comment.