diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py index df24309f06..0d74f031fa 100644 --- a/serve/mlc_serve/model/tvm_model.py +++ b/serve/mlc_serve/model/tvm_model.py @@ -630,13 +630,18 @@ def init_tvm_model( init_cache_func = tvm.get_global_func(allocate_func_name) try: - model.cache_blocks = init_cache_func( + alloc_args = [ head_size, model_artifact_config.num_hidden_layers, num_kv_heads, block_size, num_blocks, - ) + ] + + if model_artifact_config.paged_kv_cache_type == "vllm": + alloc_args.append("float16") + + model.cache_blocks = init_cache_func(*alloc_args) except tvm.error.InternalError: raise RuntimeError(f"Failed to allocate {num_blocks} cache blocks.")