Adding TRT-LLM + Triton truss #55
base: main
@@ -0,0 +1,88 @@
# TRTLLM

### Overview
This Truss adds support for TRT-LLM engines via Triton Inference Server. TRT-LLM is a highly performant language model runtime. We leverage the C++ runtime to take advantage of in-flight batching (also known as continuous batching).

### Prerequisites

To use this Truss, your engine must be built with in-flight batching support. Refer to your architecture-specific `build.py` for instructions on building with in-flight batching enabled.

### Config

This Truss is primarily config-driven: most settings you'll need to edit live in `config.yaml`, underneath the `model_metadata` key.

- `tensor_parallelism` (int): If you built your engine with tensor parallelism, set this to the same value used during the engine build step. It should also match the number of GPUs in the `resources` section.

*Pipeline parallelism is not supported in this version but will be added later. As noted by Nvidia, pipeline parallelism reduces the need for high-bandwidth communication but may incur load-balancing issues and may be less efficient in terms of GPU utilization.*

- `engine_repository` (str): We expect engines to be uploaded to Hugging Face with a flat directory structure (i.e., the engine and its associated files are not nested under a folder). This value is the full `{org_name}/{repo_name}` string. Engines can be private or public.

- `tokenizer_repository` (str): Engines do not come bundled with their own tokenizer. This is the Hugging Face repository where we can find a tokenizer. Tokenizers can be private or public.

If the engine and tokenizer repositories are private, you'll need to update the `secrets` section of the `config.yaml` as follows:

```
secrets:
  hf_access_token: "my_hf_api_key"
```
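For reference, the corresponding `model_metadata` section in the example `config.yaml` bundled with this Truss looks like the following (swap in your own engine and tokenizer repositories):

```
model_metadata:
  tensor_parallelism: 1
  engine_repository: "baseten/llama2-7b-chat-hf-fp16-tp1"
  tokenizer_repository: "NousResearch/Llama-2-7b-chat-hf"
```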

### Performance

TRT-LLM engines are designed to be highly performant. Once your Truss has been deployed, you may find that you're not fully utilizing the GPU. The following are levers to improve performance, but they require some trial and error to identify appropriate values. All of these values live inside the `config.pbtxt` for a given ensemble model.

#### Preprocessing / Postprocessing

```
instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
```
By default, we load 1 instance of the pre/post models. If you find that the tokenizer is a bottleneck, increasing the `count` value here will load more replicas of these models, and Triton will automatically load balance across model instances.

#### TensorRT-LLM
```
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "10000"
  }
}
```
By default, we set `max_tokens_in_paged_kv_cache` to 10000. For a 7B model on 1 A100 with a batch size of 8, we have over 60GB of GPU memory left over. We can comfortably increase this value to 100k and allow for more tokens in the KV cache. Your mileage will vary based on the size of your model and the hardware you're running on.
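For a rough sense of why 100k fits, here's a back-of-the-envelope estimate assuming a Llama-2-7B-style architecture (32 layers, hidden size 4096) with an fp16 KV cache; these constants are assumptions, so adjust them for your own model:

```python
# Approximate paged KV cache footprint, assuming Llama-2-7B-ish dimensions (not exact).
num_layers = 32
hidden_size = 4096        # num_heads * head_dim
bytes_per_elem = 2        # fp16
kv_bytes_per_token = 2 * num_layers * hidden_size * bytes_per_elem  # K and V per token

print(kv_bytes_per_token / 2**10)            # ~512 KiB per cached token
print(100_000 * kv_bytes_per_token / 2**30)  # ~48.8 GiB for 100k cached tokens
```

At roughly 0.5 MiB per cached token, 100k tokens comes out to about 49GB, which is consistent with the headroom described above.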

```
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "0.1"
  }
}
```
The `kv_cache_free_gpu_mem_fraction` parameter bounds the fraction of free GPU memory that may be preallocated for the KV cache. It overlaps with `max_tokens_in_paged_kv_cache`: if `max_tokens_in_paged_kv_cache` is not specified, a default fraction (reportedly around 85%) of the remaining free GPU memory is preallocated for the KV cache, so you generally only need to tune one of the two.

```
parameters: {
  key: "max_num_sequences"
  value: {
    string_value: "64"
  }
}
```
The `max_num_sequences` parameter is the maximum number of requests the inference server can maintain state for at any given time (state = KV cache + decoder state). If this value is greater than your max batch size, the server will alternate processing across `max_num_sequences // max_batch_size` batches (for example, 64 // 8 = 8 batches with a max batch size of 8). This assumes that `enable_trt_overlap` is set to `True` (as it is by default in this Truss). Setting this value higher allows for more parallel processing but uses more GPU memory.
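For reference, `enable_trt_overlap` lives alongside the other settings in the tensorrt_llm model's `config.pbtxt`; assuming it follows the same `parameters` block format shown above, it would look something like:

```
parameters: {
  key: "enable_trt_overlap"
  value: {
    string_value: "True"
  }
}
```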

### API

We expect requests with the following fields:

- `text_input` (str): The prompt you'd like to complete.
- `output_len` (int, default: 50): The maximum token count. This includes the tokens in your prompt, so if this value is smaller than your prompt, you'll just receive a truncated version of the prompt.
- `beam_width` (int, default: 1): The number of beams to compute. This must be 1 for this version of TRT-LLM; in-flight batching does not support beams > 1.
- `bad_words_list` (list, default: []): A list of words to exclude from the generated output.
- `stop_words_list` (list, default: []): A list of words that stop generation when encountered.
- `repetition_penalty` (float, default: 1.0): A repetition penalty to discourage repeating tokens.

This Truss streams responses back as buffered chunks of text.
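To make the API concrete, here's a sketch of a streaming request against a deployed instance of this Truss. The endpoint URL and auth header are placeholders (not values from this repo); substitute whatever your deployment uses:

```python
import requests

# Placeholder endpoint and credentials; replace with your deployment's values.
resp = requests.post(
    "https://<your-deployed-truss-endpoint>/predict",
    headers={"Authorization": "Api-Key <your-api-key>"},
    json={
        "text_input": "What is Triton Inference Server?",
        "output_len": 128,
        "beam_width": 1,
        "repetition_penalty": 1.0,
    },
    stream=True,  # responses come back as buffered chunks of text
)
for chunk in resp.iter_content(chunk_size=None):
    print(chunk.decode("utf-8"), end="", flush=True)
```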
@@ -0,0 +1,19 @@
base_image:
  image: docker.io/baseten/triton_trt_llm:v1
  python_executable_path: /usr/bin/python3
environment_variables: {}
external_package_dirs: []
model_metadata:
  tensor_parallelism: 1
  engine_repository: "baseten/llama2-7b-chat-hf-fp16-tp1"
  tokenizer_repository: "NousResearch/Llama-2-7b-chat-hf"
model_name: trtllm-triton
python_version: py311
requirements: []
resources:
  accelerator: A100
  use_gpu: true
secrets: {}
system_packages: []
runtime:
  predict_concurrency: 256
@@ -0,0 +1,92 @@
import numpy as np
from client import UserData, TritonClient
from threading import Thread
from utils import prepare_grpc_tensor, download_engine
from pathlib import Path
from itertools import count

TRITON_MODEL_REPOSITORY_PATH = Path("/packages/inflight_batcher_llm/")


class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._secrets = kwargs["secrets"]
        self._request_id_counter = count(start=1)
        self.triton_client = None

    def load(self):
        tensor_parallel_count = self._config["model_metadata"].get("tensor_parallelism", 1)
        is_hf_token = "hf_access_token" in self._secrets._base_secrets.keys()
        is_external_engine_repo = "engine_repository" in self._config["model_metadata"]

        # Instantiate TritonClient
        self.triton_client = TritonClient(
            data_dir=self._data_dir,
            model_repository_dir=TRITON_MODEL_REPOSITORY_PATH,
            tensor_parallel_count=tensor_parallel_count,
        )

        # Download model from Hugging Face Hub if specified
        if is_external_engine_repo:
            download_engine(
                engine_repository=self._config["model_metadata"]["engine_repository"],
                fp=self._data_dir,
                auth_token=self._secrets["hf_access_token"] if is_hf_token else None,
            )

        # Load Triton Server and model
        env = {
            "triton_tokenizer_repository": self._config["model_metadata"]["tokenizer_repository"],
        }
        if is_hf_token:
            env["HUGGING_FACE_HUB_TOKEN"] = self._secrets["hf_access_token"]

        self.triton_client.load_server_and_model(env=env)

[Review comments on `def predict`]
- "We should try to use async predict. Sync predict runs on a thread pool which has a limited number of threads and can limit concurrency. Plus, creating a new thread per request is not ideal. cc @squidarth who may know of examples of where we use async predict."
- "Ah, I missed the yield before. So this predict function is a generator, right?"
- "It's not obvious to me that there would be a big perf increase by switching to having this be async (it's true that doing things this way will produce another thread). It's a medium lift at least to switch it, since we'd have to change the TritonClient implementation to also be async."

    def predict(self, model_input):
        user_data = UserData()
        model_name = "ensemble"
        stream_uuid = str(next(self._request_id_counter))

        prompt = model_input.get("text_input")
        output_len = model_input.get("output_len", 50)
        beam_width = model_input.get("beam_width", 1)
        bad_words_list = model_input.get("bad_words_list", [""])
        stop_words_list = model_input.get("stop_words_list", [""])
        repetition_penalty = model_input.get("repetition_penalty", 1.0)

        input0 = [[prompt]]
        input0_data = np.array(input0).astype(object)
        output0_len = np.ones_like(input0).astype(np.uint32) * output_len
        bad_words_list = np.array([bad_words_list], dtype=object)
        stop_words_list = np.array([stop_words_list], dtype=object)
        streaming = [[True]]
        streaming_data = np.array(streaming, dtype=bool)
        beam_width = [[beam_width]]
        beam_width_data = np.array(beam_width, dtype=np.uint32)
        repetition_penalty_data = np.array([[repetition_penalty]], dtype=np.float32)

        inputs = [
            prepare_grpc_tensor("text_input", input0_data),
            prepare_grpc_tensor("max_tokens", output0_len),
            prepare_grpc_tensor("bad_words", bad_words_list),
            prepare_grpc_tensor("stop_words", stop_words_list),
            prepare_grpc_tensor("stream", streaming_data),
            prepare_grpc_tensor("beam_width", beam_width_data),
            prepare_grpc_tensor("repetition_penalty", repetition_penalty_data),
        ]

        # Start GRPC stream in a separate thread
        stream_thread = Thread(
            target=self.triton_client.start_grpc_stream,
            args=(user_data, model_name, inputs, stream_uuid),
        )
        stream_thread.start()

        # Yield results from the queue
        for i in TritonClient.stream_predict(user_data):
            yield i

        # Clean up GRPC stream and thread
        self.triton_client.stop_grpc_stream(stream_uuid, stream_thread)
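Note: `prepare_grpc_tensor` and `download_engine` are imported from a `utils` module that isn't shown in this excerpt of the diff. As context only, a minimal sketch of `prepare_grpc_tensor`, assuming it simply wraps a numpy array into a Triton gRPC `InferInput`, could look like:

```python
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype


def prepare_grpc_tensor(name: str, data: np.ndarray) -> grpcclient.InferInput:
    """Wrap a numpy array into a Triton gRPC InferInput tensor."""
    tensor = grpcclient.InferInput(name, list(data.shape), np_to_triton_dtype(data.dtype))
    tensor.set_data_from_numpy(data)
    return tensor
```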
@@ -0,0 +1,137 @@
import os
import json
import subprocess
import time
from functools import partial
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from pathlib import Path
from queue import Queue
from utils import prepare_model_repository
from tritonclient.utils import InferenceServerException
from threading import Thread


class UserData:
    def __init__(self):
        self._completed_requests = Queue()


def callback(user_data, result, error):
    # Push streaming results (or errors) onto the UserData queue for the consumer.
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)


class TritonClient:
    def __init__(self, data_dir: Path, model_repository_dir: Path, tensor_parallel_count=1):
        self._data_dir = data_dir
        self._model_repository_dir = model_repository_dir
        self._tensor_parallel_count = tensor_parallel_count
        self._http_client = None
        self._grpc_client_map = {}

    def start_grpc_stream(self, user_data, model_name, inputs, stream_uuid):
        """Starts a GRPC stream and sends a request to the Triton server."""
        grpc_client_instance = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)
        self._grpc_client_map[stream_uuid] = grpc_client_instance
        grpc_client_instance.start_stream(callback=partial(callback, user_data))
        grpc_client_instance.async_stream_infer(
            model_name,
            inputs,
            request_id=stream_uuid,
            enable_empty_final_response=True,
        )

    def stop_grpc_stream(self, stream_uuid, stream_thread: Thread):
        """Closes a GRPC stream and stops the associated thread."""
        triton_grpc_stream = self._grpc_client_map[stream_uuid]
        triton_grpc_stream.stop_stream()
        stream_thread.join()
        del self._grpc_client_map[stream_uuid]

    def start_server(
        self,
        mpi: int = 1,
        env: dict = {},
    ):
        """Triton Inference Server has different startup commands depending on
        whether it is running in a TP=1 or TP>1 configuration. This function
        starts the server with the appropriate command."""
        if mpi == 1:
            command = [
                "tritonserver",
                "--model-repository", str(self._model_repository_dir),
                "--grpc-port", "8001",
                "--http-port", "8003",
            ]
        else:
            # TP > 1: build a single mpirun command that launches one tritonserver
            # rank per tensor-parallel shard.
            command = [
                "mpirun",
                "--allow-run-as-root",
            ]
            for i in range(mpi):
                command += [
                    "-n",
                    "1",
                    "tritonserver",
                    "--model-repository", str(self._model_repository_dir),
                    "--grpc-port", "8001",
                    "--http-port", "8003",
                    "--disable-auto-complete-config",
                    f"--backend-config=python,shm-region-prefix-name=prefix{str(i)}_",
                    ":",
                ]
        return subprocess.Popen(
            command,
            env={**os.environ, **env},
        )

    def load_server_and_model(self, env: dict):
        """Loads the Triton server and the model."""
        prepare_model_repository(self._data_dir)
        self.start_server(mpi=self._tensor_parallel_count, env=env)

        self._http_client = httpclient.InferenceServerClient(url="localhost:8003", verbose=False)

        # Wait for the server to come up...
        is_server_up = False
        while not is_server_up:
            try:
                is_server_up = self._http_client.is_server_live()
            except ConnectionRefusedError:
                time.sleep(2)
                continue

        # ...and for the ensemble model to finish loading.
        while not self._http_client.is_model_ready(model_name="ensemble"):
            time.sleep(2)

    @staticmethod
    def stream_predict(user_data: UserData):
        """Static method to yield predictions or errors based on input and a streaming user_data queue."""

        def _is_final_response(result):
            """Check if the given result is a final response according to Triton's specification."""
            if isinstance(result, InferenceServerException):
                return True

            if result:
                final_response_param = result.get_response().parameters.get("triton_final_response")
                return final_response_param.bool_param if final_response_param else False
            return False

        result = None

        while not _is_final_response(result):
            try:
                result = user_data._completed_requests.get()
                if not isinstance(result, InferenceServerException):
                    res = result.as_numpy("text_output")
                    yield res[0].decode("utf-8")
                else:
                    yield json.dumps({
                        "status": "error",
                        "message": result.message()
                    })
            except Exception as e:
                yield json.dumps({
                    "status": "error",
                    "message": str(e)
                })
                break
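Similarly, `download_engine` (also from the `utils` module not shown here) presumably pulls the prebuilt engine down from the Hugging Face Hub. A sketch under that assumption, using `huggingface_hub.snapshot_download` and matching the keyword names used in `model.py`:

```python
from pathlib import Path
from huggingface_hub import snapshot_download


def download_engine(engine_repository: str, fp: Path, auth_token=None):
    """Download a prebuilt TRT-LLM engine repository into the Truss data directory."""
    snapshot_download(
        repo_id=engine_repository,
        local_dir=str(fp),
        token=auth_token,  # required for private engine repositories
    )
```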
[Review comment]
For your reference: it looks like using both `max_tokens_in_paged_kv_cache` and `kv_cache_free_gpu_mem_fraction` is redundant. I didn't find docs on `kv_cache_free_gpu_mem_fraction`, but it sounds like they preallocate 85% of the free GPU memory for the KV cache by default if `max_tokens_in_paged_kv_cache` is not specified.