diff --git a/pyproject.toml b/pyproject.toml
index 977476932..9702d80cd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.63"
+version = "0.9.64rc001"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
@@ -189,7 +189,7 @@ markers = [
 addopts = "--ignore=smoketests"
 
 [tool.ruff]
-src = ["truss", "truss-chains"]
+src = ["truss", "truss-chains", "truss-utils"]
 target-version = "py38"
 line-length = 88
 lint.extend-select = [
diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py
index 9b0ed384c..03537c851 100644
--- a/truss/base/trt_llm_config.py
+++ b/truss/base/trt_llm_config.py
@@ -21,13 +21,16 @@ class TrussTRTLLMModel(str, Enum):
+    ENCODER = "encoder"
+    DECODER = "decoder"
+    # auto migrated settings
+    PALMYRA = "palmyra"
+    QWEN = "qwen"
     LLAMA = "llama"
     MISTRAL = "mistral"
     DEEPSEEK = "deepseek"
+    # deprecated workflow
     WHISPER = "whisper"
-    QWEN = "qwen"
-    ENCODER = "encoder"
-    PALMYRA = "palmyra"
 
 
 class TrussTRTLLMQuantizationType(str, Enum):
@@ -43,11 +46,38 @@ class TrussTRTLLMQuantizationType(str, Enum):
 
 class TrussTRTLLMPluginConfiguration(BaseModel):
     paged_kv_cache: bool = True
-    gemm_plugin: str = "auto"
     use_paged_context_fmha: bool = True
     use_fp8_context_fmha: bool = False
 
 
+class TrussTRTQuantizationConfiguration(BaseModel):
+    """Configuration for quantization of TRT models.
+
+    Args:
+        calib_size (int, optional): Size of the calibration dataset. Defaults to 1024.
+            Recommended to increase for production runs (e.g. 1536), or to decrease (e.g. to 256) for quick testing.
+        calib_dataset (str, optional): Hugging Face dataset to use for calibration. Defaults to 'cnn_dailymail'.
+            Uses split='train' and calibrates on the 'text' column.
+        calib_max_seq_length (int, optional): Maximum sequence length for calibration. Defaults to 2048.
+ """ + + calib_size: int = 1024 + calib_dataset: str = "cnn_dailymail" + calib_max_seq_length: int = 2048 + + def __init__(self, **data): + super().__init__(**data) + self.validate_cuda_friendly("calib_size") + self.validate_cuda_friendly("calib_max_seq_length") + + def validate_cuda_friendly(self, key): + value = getattr(self, key) + if value < 64 or value > 16384: + raise ValueError(f"{key} must be between 64 and 16384, but got {value}") + elif value % 64 != 0: + raise ValueError(f"{key} must be a multiple of 64, but got {value}") + + class CheckpointSource(str, Enum): HF = "HF" GCS = "GCS" @@ -97,7 +127,7 @@ class TrussTRTLLMRuntimeConfiguration(BaseModel): class TrussTRTLLMBuildConfiguration(BaseModel): - base_model: TrussTRTLLMModel + base_model: TrussTRTLLMModel = TrussTRTLLMModel.DECODER max_seq_len: int max_batch_size: int = 256 max_num_tokens: int = 8192 @@ -109,6 +139,9 @@ class TrussTRTLLMBuildConfiguration(BaseModel): quantization_type: TrussTRTLLMQuantizationType = ( TrussTRTLLMQuantizationType.NO_QUANT ) + quantization_config: TrussTRTQuantizationConfiguration = ( + TrussTRTQuantizationConfiguration() + ) tensor_parallel_count: int = 1 pipeline_parallel_count: int = 1 plugin_configuration: TrussTRTLLMPluginConfiguration = ( diff --git a/truss/tests/conftest.py b/truss/tests/conftest.py index a4011c2ca..d86e5932c 100644 --- a/truss/tests/conftest.py +++ b/truss/tests/conftest.py @@ -825,7 +825,6 @@ def trtllm_spec_dec_config_full(trtllm_config) -> Dict[str, Any]: "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"}, "plugin_configuration": { "paged_kv_cache": True, - "gemm_plugin": "auto", "use_paged_context_fmha": True, }, "speculator": { @@ -857,7 +856,6 @@ def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]: "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"}, "plugin_configuration": { "paged_kv_cache": True, - "gemm_plugin": "auto", "use_paged_context_fmha": True, }, "speculator": {