Skip to content

Commit

Permalink
Better TRT-LLM building on the fly engines (#150)
Browse files Browse the repository at this point in the history
This change adds the ability to specify TRT-LLM engine configurations using a
convenient DSL, and uses cached pre-built engines from Hugging Face when
possible
  • Loading branch information
Timur Abishev authored Jan 12, 2024
1 parent 8ede1d3 commit 1655512
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 12 deletions.
11 changes: 6 additions & 5 deletions mistral/mistral-7b-trt-llm-build-engine/config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
base_image:
image: docker.io/baseten/triton_trt_llm:v2
image: baseten/trtllm-build-server:r23.12_baseten_v0.7.1_20240111
python_executable_path: /usr/bin/python3
description: Generate text from a prompt with this seven billion parameter language model.
model_metadata:
avatar_url: https://cdn.baseten.co/production/static/explore/mistral_logo.png
cover_image_url: https://cdn.baseten.co/production/static/explore/mistral.png
engine_build:
args: --remove_input_padding --use_gpt_attention_plugin float16 --enable_context_fmha --use_gemm_plugin float16 --max_batch_size
64 --use_inflight_batching --max_input_len 2000 --max_output_len 2000 --paged_kv_cache
cmd: examples/llama/build.py
engine:
args:
max_batch_size: 64
max_input_len: 2000
max_output_len: 2000
example_model_input: {"messages": [{"role": "user", "content": "What is the mistral wind?"}]}
pipeline_parallelism: 1
tags:
Expand Down
28 changes: 25 additions & 3 deletions mistral/mistral-7b-trt-llm-build-engine/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from pathlib import Path
from threading import Thread

import build_engine_utils
import numpy as np
from build_engine_utils import BuildConfig, build_engine
from client import TritonClient, UserData
from transformers import AutoTokenizer
from utils import download_engine, prepare_grpc_tensor, server_loaded
Expand Down Expand Up @@ -57,9 +57,9 @@ def load(self):
tokenizer_repository = self._config["model_metadata"]["tokenizer_repository"]
if "engine_build" in self._config["model_metadata"]:
if not is_external_engine_repo:
build_engine(
build_engine_utils.build_engine(
model_repo=tokenizer_repository,
config=BuildConfig(
config=build_engine_utils.BuildConfig(
**self._config["model_metadata"]["engine_build"]
),
dst=self._data_dir,
Expand All @@ -71,6 +71,28 @@ def load(self):
raise Exception(
"`engine_build` and `engine_repository` can't be specified at the same time"
)
if "engine" in self._config["model_metadata"]:
import os
import shutil
import sys

sys.path.append("/app/baseten")
from build_engine import Engine, build_engine
from trtllm_utils import docker_tag_aware_file_cache

engine = Engine(**self._config["model_metadata"]["engine"])
engine.repo = tokenizer_repository
with docker_tag_aware_file_cache("/root/.cache/trtllm"):
built_engine = build_engine(engine, download_remote=True)

if not os.path.exists(self._data_dir):
os.makedirs(self._data_dir)

for filename in os.listdir(str(built_engine)):
source_file = os.path.join(str(built_engine), filename)
destination_file = os.path.join(self._data_dir, filename)
if not os.path.exists(destination_file):
shutil.copy(source_file, destination_file)

# Load Triton Server and model
env = {"triton_tokenizer_repository": tokenizer_repository}
Expand Down
10 changes: 6 additions & 4 deletions templates/generate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -184,15 +184,17 @@ mistral/mistral-7b-trt-llm-build-engine:
based_on: trt-llm
config:
base_image:
image: docker.io/baseten/triton_trt_llm:v2
image: baseten/trtllm-build-server:r23.12_baseten_v0.7.1_20240111
model_metadata:
example_model_input: {"messages": [{"role": "user", "content": "What is the mistral wind?"}]}
avatar_url: https://cdn.baseten.co/production/static/explore/mistral_logo.png
cover_image_url: https://cdn.baseten.co/production/static/explore/mistral.png
tokenizer_repository: "mistralai/Mistral-7B-Instruct-v0.2"
engine_build:
cmd: examples/llama/build.py
args: --remove_input_padding --use_gpt_attention_plugin float16 --enable_context_fmha --use_gemm_plugin float16 --max_batch_size 64 --use_inflight_batching --max_input_len 2000 --max_output_len 2000 --paged_kv_cache
engine:
args:
max_input_len: 2000
max_output_len: 2000
max_batch_size: 64
tensor_parallelism: 1
pipeline_parallelism: 1
tags:
Expand Down

0 comments on commit 1655512

Please sign in to comment.