Adding TRT-LLM + Triton truss #55

Open · wants to merge 20 commits into main
88 changes: 88 additions & 0 deletions trtllm-truss/README.md
@@ -0,0 +1,88 @@
# TRTLLM

### Overview
This Truss adds support for TRT-LLM engines via Triton Inference Server. TRT-LLM is a highly performant language model runtime. We leverage the C++ runtime to take advantage of in-flight batching (also known as continuous batching).

### Prerequisites

To use this Truss, your engine must be built with in-flight batching support. Refer to your architecture-specific `build.py` for instructions on building with in-flight batching support.

### Config

This Truss is primarily config-driven, which means that most settings you'll need to edit are located in `config.yaml`. These settings all live underneath the `model_metadata` key.

- `tensor_parallelism` (int): If you built your model with tensor parallelism support, set this to the same value used during the engine build step. It should match the number of GPUs in the `resources` section.

*Pipeline parallelism is not supported in this version but will be added later. As noted by NVIDIA, pipeline parallelism reduces the need for high-bandwidth communication but may incur load-balancing issues and may be less efficient in terms of GPU utilization.*

- `engine_repository` (str): We expect engines to be uploaded to Hugging Face with a flat directory structure (i.e., the engine and associated files are not nested under a subdirectory). This value is the full `{org_name}/{repo_name}` string. Engines can be private or public.

- `tokenizer_repository` (str): Engines do not come bundled with their own tokenizer. This is the Hugging Face repository where we can find a tokenizer. Tokenizers can be private or public.

If the engine or tokenizer repositories are private, you'll need to update the `secrets` section of the `config.yaml` as follows:

```
secrets:
  hf_access_token: "my_hf_api_key"
```
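
For reference, here is a rough sketch of pulling a private engine from the Hugging Face Hub with that token (the repository name and target directory below are placeholders; the Truss itself handles this download for you at load time):

```
from pathlib import Path

from huggingface_hub import snapshot_download

# Placeholder values; substitute your own engine repository and target directory.
ENGINE_REPO = "my-org/my-trtllm-engine"
ENGINE_DIR = Path("/tmp/engine")

# Pull the flat engine directory from the Hub. `token` is only required for
# private repositories and corresponds to the `hf_access_token` secret above.
snapshot_download(
    repo_id=ENGINE_REPO,
    local_dir=ENGINE_DIR,
    token="my_hf_api_key",
)
```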

### Performance

TRT-LLM engines are designed to be highly performant. Once your Truss has been deployed, you may find that you're not fully utilizing the GPU. The following are levers to improve performance, but they require some trial and error to find appropriate values. All of these values live inside the `config.pbtxt` for a given ensemble model.

#### Preprocessing / Postprocessing

```
instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
```
By default, we load 1 instance of the pre/post models. If you find that the tokenizer is a bottleneck, increasing the `count` variable here will load more replicas of these models and Triton will automatically load balance across model instances.
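
If you prefer to script that change, here is a rough helper sketch (not part of this Truss; the config path is a placeholder) that bumps the `count` field of an `instance_group` block in a `config.pbtxt`:

```
import re
from pathlib import Path

# Placeholder path; adjust to the preprocessing/postprocessing model in your repository.
CONFIG_PATH = Path("inflight_batcher_llm/preprocessing/config.pbtxt")

def set_instance_count(config_path: Path, count: int) -> None:
    """Rewrite the `count` entry of the first instance_group block in a config.pbtxt."""
    text = config_path.read_text()
    # Assumes a single instance_group block like the one shown above.
    patched = re.sub(r"(instance_group\s*\[\s*\{\s*count:\s*)\d+", rf"\g<1>{count}", text)
    config_path.write_text(patched)

set_instance_count(CONFIG_PATH, 4)  # load 4 replicas of the tokenizer model
```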

### TensorRT-LLM
```
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "10000"
  }
}
```
By default, we set the `max_tokens_in_paged_kv_cache` to 10000. For a 7B model on 1 A100 with a batch size of 8, we have over 60GB of GPU memory left over. We can increase this value to 100k comfortably and allow for more tokens in the KV cache. Your mileage will vary based on the size of your model and the hardware you're running on.
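
As a back-of-envelope sanity check (an estimate using assumed Llama-2-7B dimensions, not a figure measured with this Truss), each cached token in fp16 costs roughly half a megabyte, so 10k tokens is about 5 GB and 100k tokens is about 50 GB:

```
# Rough KV cache sizing for a Llama-2-7B-style model in fp16.
# The architecture numbers below are assumptions for illustration only.
num_layers = 32
num_kv_heads = 32        # Llama-2-7B uses full multi-head attention
head_dim = 128
bytes_per_value = 2      # fp16

# Each cached token stores one key and one value vector per layer.
bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
print(bytes_per_token / 2**20, "MiB per token")   # 0.5 MiB

for max_tokens in (10_000, 100_000):
    print(max_tokens, "tokens ->", round(max_tokens * bytes_per_token / 2**30, 1), "GiB")
# 10,000 tokens  -> ~4.9 GiB; 100,000 tokens -> ~48.8 GiB
```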

```
parameters: {
  key: "kv_cache_free_gpu_mem_fraction"
  value: {
    string_value: "0.1"
  }
}
```
TODO(Abu): __fill__
Contributor

for your reference: it looks like using both max_tokens_in_paged_kv_cache and kv_cache_free_gpu_mem_fraction is redundant. I didn't find docs on kv_cache_free_gpu_mem_fraction, but it sounds like it preallocates 85% of the free GPU memory for the KV cache by default if max_tokens_in_paged_kv_cache is not specified


```
parameters: {
  key: "max_num_sequences"
  value: {
    string_value: "64"
  }
}
```
The `max_num_sequences` param is the maximum number of requests that the inference server can maintain state for at a given time (state = KV cache + decoder state). If this value is greater than your max batch size, the server will ping-pong processing across `max_num_sequences // max_batch_size` batches (for example, 64 // 8 = 8 batches with a max batch size of 8). This assumes that `enable_trt_overlap` is set to `True` (as it is by default in this Truss). Setting this value higher allows for more parallel processing but uses more GPU memory.

### API

We expect requests with the following information:

- `text_input` (str): The prompt you'd like to complete.
- `output_len` (int, default: 50): The maximum token count. This includes the number of tokens in your prompt, so if this value is smaller than your prompt length you'll just receive a truncated version of the prompt.
- `beam_width` (int, default: 1): The number of beams to compute. This must be 1 for this version of TRT-LLM; in-flight batching does not support beam widths greater than 1.
- `bad_words_list` (list, default: `[]`): A list of words to exclude from the generated output.
- `stop_words_list` (list, default: `[]`): A list of words that stop generation when encountered.
- `repetition_penalty` (float, default: 1.0): A repetition penalty to discourage repeating tokens.

This Truss will stream responses back. Responses will be buffered chunks of text.
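
For illustration, a minimal streaming client sketch in Python (the endpoint URL, auth header, and API key are placeholders; adapt them to wherever your Truss is deployed):

```
import requests

# Placeholder endpoint and credentials; substitute your deployment's values.
INVOKE_URL = "https://example.com/model/predict"
API_KEY = "my_api_key"

payload = {
    "text_input": "What is the capital of France?",
    "output_len": 128,
    "beam_width": 1,
    "bad_words_list": [],
    "stop_words_list": [],
    "repetition_penalty": 1.0,
}

# The Truss streams buffered chunks of text, so read the response incrementally.
with requests.post(
    INVOKE_URL,
    headers={"Authorization": f"Api-Key {API_KEY}"},  # auth scheme depends on your deployment
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```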
19 changes: 19 additions & 0 deletions trtllm-truss/config.yaml
@@ -0,0 +1,19 @@
base_image:
  image: docker.io/baseten/triton_trt_llm:v1
  python_executable_path: /usr/bin/python3
environment_variables: {}
external_package_dirs: []
model_metadata:
  tensor_parallelism: 1
  engine_repository: "baseten/llama2-7b-chat-hf-fp16-tp1"
  tokenizer_repository: "NousResearch/Llama-2-7b-chat-hf"
model_name: trtllm-triton
python_version: py311
requirements: []
resources:
  accelerator: A100
  use_gpu: true
secrets: {}
system_packages: []
runtime:
  predict_concurrency: 256
Empty file added trtllm-truss/model/__init__.py
Empty file.
92 changes: 92 additions & 0 deletions trtllm-truss/model/model.py
@@ -0,0 +1,92 @@
import numpy as np
from client import UserData, TritonClient
from threading import Thread
from utils import prepare_grpc_tensor, download_engine
from pathlib import Path
from itertools import count

TRITON_MODEL_REPOSITORY_PATH = Path("/packages/inflight_batcher_llm/")

class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._secrets = kwargs["secrets"]
        self._request_id_counter = count(start=1)
        self.triton_client = None

    def load(self):
        tensor_parallel_count = self._config["model_metadata"].get("tensor_parallelism", 1)
        is_hf_token = "hf_access_token" in self._secrets._base_secrets.keys()
        is_external_engine_repo = "engine_repository" in self._config["model_metadata"]

        # Instantiate TritonClient
        self.triton_client = TritonClient(
            data_dir=self._data_dir,
            model_repository_dir=TRITON_MODEL_REPOSITORY_PATH,
            tensor_parallel_count=tensor_parallel_count,
        )

        # Download model from Hugging Face Hub if specified
        if is_external_engine_repo:
            download_engine(
                engine_repository=self._config["model_metadata"]["engine_repository"],
                fp=self._data_dir,
                auth_token=self._secrets["hf_access_token"] if is_hf_token else None
            )

        # Load Triton Server and model
        env = {
            "triton_tokenizer_repository": self._config["model_metadata"]["tokenizer_repository"],
        }
        if is_hf_token:
            env["HUGGING_FACE_HUB_TOKEN"] = self._secrets["hf_access_token"]

        self.triton_client.load_server_and_model(env=env)

    def predict(self, model_input):
Contributor

We should try to use async predict. Sync predict runs on a thread pool, which has a limited number of threads and can limit concurrency. Plus, creating a new thread per request is not ideal. cc @squidarth who may know of examples of where we use async predict.

Contributor

Ah, I missed the yield before. So this predict function is a generator, right?

Contributor

it's not obvious to me that there would be a big perf increase from switching this to async (it's true that doing things this way will spawn another thread). It's a medium lift at least to switch it, since we'd have to change the TritonClient implementation to also be async

        user_data = UserData()
        model_name = "ensemble"
        stream_uuid = str(next(self._request_id_counter))

        prompt = model_input.get("text_input")
        output_len = model_input.get("output_len", 50)
        beam_width = model_input.get("beam_width", 1)
        bad_words_list = model_input.get("bad_words_list", [""])
        stop_words_list = model_input.get("stop_words_list", [""])
        repetition_penalty = model_input.get("repetition_penalty", 1.0)

        input0 = [[prompt]]
        input0_data = np.array(input0).astype(object)
        output0_len = np.ones_like(input0).astype(np.uint32) * output_len
        bad_words_list = np.array([bad_words_list], dtype=object)
        stop_words_list = np.array([stop_words_list], dtype=object)
        streaming = [[True]]
        streaming_data = np.array(streaming, dtype=bool)
        beam_width = [[beam_width]]
        beam_width_data = np.array(beam_width, dtype=np.uint32)
        repetition_penalty_data = np.array([[repetition_penalty]], dtype=np.float32)

        inputs = [
            prepare_grpc_tensor("text_input", input0_data),
            prepare_grpc_tensor("max_tokens", output0_len),
            prepare_grpc_tensor("bad_words", bad_words_list),
            prepare_grpc_tensor("stop_words", stop_words_list),
            prepare_grpc_tensor("stream", streaming_data),
            prepare_grpc_tensor("beam_width", beam_width_data),
            prepare_grpc_tensor("repetition_penalty", repetition_penalty_data)
        ]

        # Start GRPC stream in a separate thread
        stream_thread = Thread(
            target=self.triton_client.start_grpc_stream,
            args=(user_data, model_name, inputs, stream_uuid)
        )
        stream_thread.start()

        # Yield results from the queue
        for i in TritonClient.stream_predict(user_data):
            yield i

        # Clean up GRPC stream and thread
        self.triton_client.stop_grpc_stream(stream_uuid, stream_thread)
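
For reference, a rough sketch of exercising the model above by hand (the `_Secrets` stand-in, paths, and prompt are illustrative; in a deployment the Truss runtime constructs `Model` and calls `load`/`predict` for you):

```
# Illustrative only: requires the Triton container environment this Truss runs in.
class _Secrets(dict):
    """Minimal stand-in for the secrets object the Truss runtime injects."""
    @property
    def _base_secrets(self):
        return self

config = {
    "model_metadata": {
        "tensor_parallelism": 1,
        "engine_repository": "baseten/llama2-7b-chat-hf-fp16-tp1",
        "tokenizer_repository": "NousResearch/Llama-2-7b-chat-hf",
    }
}

model = Model(
    data_dir="/tmp/engine",  # placeholder engine directory
    config=config,
    secrets=_Secrets(hf_access_token="my_hf_api_key"),  # only needed for private repos
)
model.load()  # downloads the engine, starts Triton, waits for the ensemble to be ready

# predict() is a generator; iterate over it to stream the completion.
for text_chunk in model.predict({"text_input": "What is the capital of France?"}):
    print(text_chunk, end="", flush=True)
```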
137 changes: 137 additions & 0 deletions trtllm-truss/packages/client.py
@@ -0,0 +1,137 @@
import os
import json
import subprocess
import time
from functools import partial
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from pathlib import Path
from queue import Queue
from utils import prepare_model_repository
from tritonclient.utils import InferenceServerException
from threading import Thread

class UserData:
    def __init__(self):
        self._completed_requests = Queue()

def callback(user_data, result, error):
    if error:
        user_data._completed_requests.put(error)
    else:
        user_data._completed_requests.put(result)

class TritonClient:
    def __init__(self, data_dir: Path, model_repository_dir: Path, tensor_parallel_count=1):
        self._data_dir = data_dir
        self._model_repository_dir = model_repository_dir
        self._tensor_parallel_count = tensor_parallel_count
        self._http_client = None
        self._grpc_client_map = {}

    def start_grpc_stream(self, user_data, model_name, inputs, stream_uuid):
        """Starts a GRPC stream and sends a request to the Triton server."""
        grpc_client_instance = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)
        self._grpc_client_map[stream_uuid] = grpc_client_instance
        grpc_client_instance.start_stream(callback=partial(callback, user_data))
        grpc_client_instance.async_stream_infer(
            model_name,
            inputs,
            request_id=stream_uuid,
            enable_empty_final_response=True,
        )

    def stop_grpc_stream(self, stream_uuid, stream_thread: Thread):
        """Closes a GRPC stream and stops the associated thread."""
        triton_grpc_stream = self._grpc_client_map[stream_uuid]
        triton_grpc_stream.stop_stream()
        stream_thread.join()
        del self._grpc_client_map[stream_uuid]

    def start_server(
        self,
        mpi: int = 1,
        env: dict = {},
    ):
        """Triton Inference Server has different startup commands depending on
        whether it is running in a TP=1 or TP>1 configuration. This function
        starts the server with the appropriate command."""
        if mpi == 1:
            command = [
                "tritonserver",
                "--model-repository", str(self._model_repository_dir),
                "--grpc-port", "8001",
                "--http-port", "8003"
            ]
        else:
            # TP > 1: launch one tritonserver rank per GPU under mpirun.
            command = [
                "mpirun",
                "--allow-run-as-root",
            ]
            for i in range(mpi):
                command += [
                    "-n",
                    "1",
                    "tritonserver",
                    "--model-repository", str(self._model_repository_dir),
                    "--grpc-port", "8001",
                    "--http-port", "8003",
                    "--disable-auto-complete-config",
                    f"--backend-config=python,shm-region-prefix-name=prefix{str(i)}_",
                    ":"
                ]
        return subprocess.Popen(
            command,
            env={**os.environ, **env},
        )

    def load_server_and_model(self, env: dict):
        """Loads the Triton server and the model."""
        prepare_model_repository(self._data_dir)
        self.start_server(mpi=self._tensor_parallel_count, env=env)

        self._http_client = httpclient.InferenceServerClient(url="localhost:8003", verbose=False)
        is_server_up = False
        while not is_server_up:
            try:
                is_server_up = self._http_client.is_server_live()
            except ConnectionRefusedError:
                time.sleep(2)
                continue

        while self._http_client.is_model_ready(model_name="ensemble") == False:
            time.sleep(2)
            continue

    @staticmethod
    def stream_predict(user_data: UserData):
        """Static method to yield predictions or errors based on input and a streaming user_data queue."""

        def _is_final_response(result):
            """Check if the given result is a final response according to Triton's specification."""
            if isinstance(result, InferenceServerException):
                return True

            if result:
                final_response_param = result.get_response().parameters.get("triton_final_response")
                return final_response_param.bool_param if final_response_param else False
            return False

        result = None

        while not _is_final_response(result):
            try:
                result = user_data._completed_requests.get()
                if not isinstance(result, InferenceServerException):
                    res = result.as_numpy('text_output')
                    yield res[0].decode("utf-8")
                else:
                    yield json.dumps({
                        "status": "error",
                        "message": result.message()
                    })
            except Exception as e:
                yield json.dumps({
                    "status": "error",
                    "message": str(e)
                })
                break