rename truss and add warning in README (#64)
joostinyi authored Nov 10, 2023
1 parent ff2f097 commit ede84a5
Showing 14 changed files with 207 additions and 162 deletions.
22 changes: 12 additions & 10 deletions llama/llama-7b-trt/README.md → llama/llama-2-7b-trt-llm/README.md
@@ -1,28 +1,30 @@
[![Deploy to Baseten](https://user-images.githubusercontent.com/2389286/236301770-16f46d4f-4e23-4db5-9462-f578ec31e751.svg)](https://app.baseten.co/explore/llama)

# LLaMA-7B-Chat Truss
# LLaMA2-7B-Chat Truss

This is a [Truss](https://truss.baseten.co/) for an int8 SmoothQuant version of LLaMA-7B-Chat. Llama is a family of language models released by Meta. This README will walk you through how to deploy this Truss on Baseten to get your own instance of LLaMA-7B-Chat.
This is a [Truss](https://truss.baseten.co/) for an int8 SmoothQuant version of LLaMA2-7B-Chat. Llama is a family of language models released by Meta. This README will walk you through how to deploy this Truss on Baseten to get your own instance of LLaMA2-7B-Chat.

**Warning: This example is only intended for use on a single A100; changing the resource type for this deployment will result in unsupported behavior.**

## Truss

Truss is an open-source model serving framework developed by Baseten. It allows you to develop and deploy machine learning models onto Baseten (and other platforms like [AWS](https://truss.baseten.co/deploy/aws) or [GCP](https://truss.baseten.co/deploy/gcp)). Using Truss, you can develop a GPU model using [live-reload](https://baseten.co/blog/technical-deep-dive-truss-live-reload), package models and their associated code, create Docker containers, and deploy on Baseten.

## Deploying LLaMA-7B
## Deploying LLaMA2-7B-Chat

First, clone this repository:

```sh
git clone https://github.com/basetenlabs/truss-examples/
cd llama/llama-7b-trt
cd llama/llama-2-7b-trt-llm
```

Before deployment:

1. Make sure you have a [Baseten account](https://app.baseten.co/signup) and [API key](https://app.baseten.co/settings/account/api_keys).
2. Install the latest version of Truss: `pip install --upgrade truss`

With `llama-7b-trt` as your working directory, you can deploy the model with:
With `llama-2-7b-trt-llm` as your working directory, you can deploy the model with:

```sh
truss push
@@ -32,8 +34,8 @@ Paste your Baseten API key if prompted.

For more information, see [Truss documentation](https://truss.baseten.co).

## LLaMA-7B API documentation
This section provides an overview of the LLaMA-7B API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided instruction.
## LLaMA2-7B API documentation
This section provides an overview of the LLaMA2-7B API, its parameters, and how to use it. The API consists of a single route named `predict`, which you can invoke to generate text based on the provided instruction.

### API route: `predict`

@@ -42,12 +44,12 @@ We expect requests with the following information:

- ```text_input``` (str): The prompt you'd like to complete
- ```output_len``` (int, default: 50): The max token count. This includes the number of tokens in your prompt, so if this value is less than the length of your prompt, you'll just receive a truncated version of the prompt.
- ```beam_width``` (int, default: 50): The number of beams to compute. This must be 1 for this version of TRT-LLM. In-flight batching does not support beams > 1.
- ```bad_words_list``` (list, default:[]): A list of words to not include in generated output.
- ```stop_words_list``` (list, default:[]): A list of words to stop generation upon encountering.
- ```repetition_penalty``` (float, default: 1.0): A repetition penalty to discourage repeating tokens.

This Truss will stream responses back. Responses will be buffered chunks of text.

## Example usage
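
A minimal sketch of calling the deployed model and consuming the streamed output. The invoke URL format and the `Api-Key` authorization header are assumptions — substitute the exact invoke URL and API key shown for your deployment in your Baseten workspace:

```python
# Sketch only: MODEL_INVOKE_URL and the "Api-Key" auth scheme are placeholders/assumptions.
import requests

MODEL_INVOKE_URL = "https://app.baseten.co/model_versions/YOUR_VERSION_ID/predict"  # placeholder
BASETEN_API_KEY = "YOUR_API_KEY"  # placeholder

payload = {
    "text_input": "What is the difference between a llama and an alpaca?",
    "output_len": 128,
    "beam_width": 1,  # must be 1 for this version of TRT-LLM
    "bad_words_list": [],
    "stop_words_list": [],
    "repetition_penalty": 1.0,
}

with requests.post(
    MODEL_INVOKE_URL,
    headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    # The Truss streams back buffered chunks of text.
    for chunk in resp.iter_content(chunk_size=None):
        print(chunk.decode("utf-8"), end="", flush=True)
```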

@@ -1,19 +1,19 @@


# TRTLLM

### Overview
This Truss adds support for TRT-LLM engines via Triton Inference Server. TRT-LLM is a highly performant language model runtime. We leverage the C++ runtime to take advantage of in-flight batching (aka continuous batching).

### Prerequisites

To use this Truss, your engine must be built with in-flight batching support. Refer to your architecture-specific `build.py` for how to build with in-flight batching support.

### Config

This Truss is primarily config-driven, meaning that most settings you'll need to edit are located in `config.yaml`. These settings all live underneath the `model_metadata` key.

- `tensor_parallelism` (int): If you built your model with tensor parallelism support, set this to the same value used during the engine build step. It should match the number of GPUs in the `resources` section (see the illustrative fragment below).

*Pipeline parallelism is not supported in this version but will be added later. As Nvidia notes, pipeline parallelism reduces the need for high-bandwidth communication but may incur load-balancing issues and may be less efficient in terms of GPU utilization.*
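
For orientation, here is an illustrative `config.yaml` fragment showing where these keys live. The values and the `A100:2` accelerator syntax are assumptions, not shipped defaults — match `tensor_parallelism` to your engine build and to the GPU count in `resources`:

```yaml
# Illustrative sketch only — values are assumptions, not defaults.
model_metadata:
  tensor_parallelism: 2  # must match the value used when building the engine
  tokenizer_repository: meta-llama/Llama-2-7b-chat-hf  # example Hugging Face tokenizer repo
  # engine_repository: your-org/your-built-engine      # optional: pull a prebuilt engine from Hugging Face
resources:
  accelerator: A100:2  # GPU count should equal tensor_parallelism
  use_gpu: true
```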

@@ -28,9 +28,9 @@ secrets:
hf_access_token: "my_hf_api_key"
```

### Performance

TRT-LLM engines are designed to be highly performant. Once your Truss has been deployed, you may find that you're not fully utilizing the GPU. The following are levers to improve performance, but they require trial and error to identify appropriate values. All of these values live inside the `config.pbtxt` for a given ensemble model.

#### Preprocessing / Postprocessing

@@ -42,7 +42,7 @@ instance_group [
}
]
```
By default, we load 1 instance of the pre/post models. If you find that the tokenizer is a bottleneck, increasing the `count` variable here will load more replicas of these models and Triton will automatically load balance across model instances.
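
As a sketch, raising the tokenizer instance count to four would look like the following (the `KIND_CPU` value is an assumption about how the pre/post models are configured — keep whatever `kind` your generated `config.pbtxt` already uses):

```
instance_group [
  {
    count: 4
    kind: KIND_CPU
  }
]
```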

### TensorRT-LLM
```
@@ -53,7 +53,7 @@ parameters: {
}
}
```
By default, we set the `max_tokens_in_paged_kv_cache` to 10000. For a 7B model on 1 A100 with a batch size of 8, we have over 60GB of GPU memory left over. We can increase this value to 100k comfortably and allow for more tokens in the KV cache. Your mileage will vary based on the size of your model and the hardware you're running on.
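
For example, raising the cache to 100k tokens amounts to editing the string value in that parameters block. The sketch below uses the standard Triton `parameters` syntax; the exact surrounding block comes from your generated `config.pbtxt`:

```
parameters: {
  key: "max_tokens_in_paged_kv_cache"
  value: {
    string_value: "100000"
  }
}
```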

```
parameters: {
@@ -73,7 +73,7 @@ parameters: {
}
}
```
The `max_num_sequences` param is the maximum number of requests that the inference server can maintain state for at a given time (state = KV cache + decoder state). If this value is greater than your max batch size, we'll try to ping-pong processing between `max_num_sequences // max_batch_size` batches; for example, with a max batch size of 8 and `max_num_sequences` of 32, the server alternates across 4 in-flight batches. This assumes that `enable_trt_overlap` is set to `True` (as it is by default in this Truss). Setting this value higher allows for more parallel processing but uses more GPU memory.

### API

@@ -82,9 +82,9 @@ We expect requests with the following information:

- ```text_input``` (str): The prompt you'd like to complete
- ```output_len``` (int, default: 50): The max token count. This includes the number of tokens in your prompt, so if this value is less than the length of your prompt, you'll just receive a truncated version of the prompt.
- ```beam_width``` (int, default: 50): The number of beams to compute. This must be 1 for this version of TRT-LLM. In-flight batching does not support beams > 1.
- ```bad_words_list``` (list, default:[]): A list of words to not include in generated output.
- ```stop_words_list``` (list, default:[]): A list of words to stop generation upon encountering.
- ```repetition_penalty``` (float, default: 1.0): A repetition penalty to discourage repeating tokens.

This Truss will stream responses back. Responses will be buffered chunks of text.
@@ -23,4 +23,4 @@ resources:
secrets: {}
system_packages: []
runtime:
predict_concurrency: 256
File renamed without changes.
File renamed without changes.
@@ -1,12 +1,14 @@
import numpy as np
from client import UserData, TritonClient
from threading import Thread
from utils import prepare_grpc_tensor, download_engine
from pathlib import Path
from itertools import count
from pathlib import Path
from threading import Thread

import numpy as np
from client import TritonClient, UserData
from utils import download_engine, prepare_grpc_tensor

TRITON_MODEL_REPOSITORY_PATH = Path("/packages/inflight_batcher_llm/")


class Model:
def __init__(self, **kwargs):
self._data_dir = kwargs["data_dir"]
@@ -16,7 +18,9 @@ def __init__(self, **kwargs):
self.triton_client = None

def load(self):
tensor_parallel_count = self._config["model_metadata"].get("tensor_parallelism", 1)
tensor_parallel_count = self._config["model_metadata"].get(
"tensor_parallelism", 1
)
is_hf_token = "hf_access_token" in self._secrets._base_secrets.keys()
is_external_engine_repo = "engine_repository" in self._config["model_metadata"]

@@ -26,18 +30,20 @@ def load(self):
model_repository_dir=TRITON_MODEL_REPOSITORY_PATH,
tensor_parallel_count=tensor_parallel_count,
)

# Download model from Hugging Face Hub if specified
if is_external_engine_repo:
download_engine(
engine_repository=self._config["model_metadata"]["engine_repository"],
fp=self._data_dir,
auth_token=self._secrets["hf_access_token"] if is_hf_token else None
auth_token=self._secrets["hf_access_token"] if is_hf_token else None,
)

# Load Triton Server and model
env = {
"triton_tokenizer_repository": self._config["model_metadata"]["tokenizer_repository"],
"triton_tokenizer_repository": self._config["model_metadata"][
"tokenizer_repository"
],
}
if is_hf_token:
env["HUGGING_FACE_HUB_TOKEN"] = self._secrets["hf_access_token"]
@@ -74,13 +80,13 @@ def predict(self, model_input):
prepare_grpc_tensor("stop_words", stop_words_list),
prepare_grpc_tensor("stream", streaming_data),
prepare_grpc_tensor("beam_width", beam_width_data),
prepare_grpc_tensor("repetition_penalty", repetition_penalty_data)
prepare_grpc_tensor("repetition_penalty", repetition_penalty_data),
]

# Start GRPC stream in a separate thread
stream_thread = Thread(
target=self.triton_client.start_grpc_stream,
args=(user_data, model_name, inputs, stream_uuid)
args=(user_data, model_name, inputs, stream_uuid),
)
stream_thread.start()

@@ -89,4 +95,4 @@ def predict(self, model_input):
yield i

# Clean up GRPC stream and thread
self.triton_client.stop_grpc_stream(stream_uuid, stream_thread)
@@ -1,28 +1,34 @@
import os
import json
import os
import subprocess
import time
from functools import partial
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from pathlib import Path
from queue import Queue
from utils import prepare_model_repository
from tritonclient.utils import InferenceServerException
from threading import Thread

import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException
from utils import prepare_model_repository


class UserData:
def __init__(self):
self._completed_requests = Queue()


def callback(user_data, result, error):
if error:
user_data._completed_requests.put(error)
else:
user_data._completed_requests.put(result)


class TritonClient:
def __init__(self, data_dir: Path, model_repository_dir: Path, tensor_parallel_count=1):
def __init__(
self, data_dir: Path, model_repository_dir: Path, tensor_parallel_count=1
):
self._data_dir = data_dir
self._model_repository_dir = model_repository_dir
self._tensor_parallel_count = tensor_parallel_count
@@ -31,7 +37,9 @@ def __init__(self, data_dir: Path, model_repository_dir: Path, tensor_parallel_c

def start_grpc_stream(self, user_data, model_name, inputs, stream_uuid):
"""Starts a GRPC stream and sends a request to the Triton server."""
grpc_client_instance = grpcclient.InferenceServerClient(url="localhost:8001", verbose=False)
grpc_client_instance = grpcclient.InferenceServerClient(
url="localhost:8001", verbose=False
)
self._grpc_client_map[stream_uuid] = grpc_client_instance
grpc_client_instance.start_stream(callback=partial(callback, user_data))
grpc_client_instance.async_stream_infer(
@@ -59,9 +67,12 @@ def start_server(
if mpi == 1:
command = [
"tritonserver",
"--model-repository", str(self._model_repository_dir),
"--grpc-port", "8001",
"--http-port", "8003"
"--model-repository",
str(self._model_repository_dir),
"--grpc-port",
"8001",
"--http-port",
"8003",
]
command = [
"mpirun",
@@ -72,12 +83,15 @@
"-n",
"1",
"tritonserver",
"--model-repository", str(self._model_repository_dir),
"--grpc-port", "8001",
"--http-port", "8003",
"--model-repository",
str(self._model_repository_dir),
"--grpc-port",
"8001",
"--http-port",
"8003",
"--disable-auto-complete-config",
f"--backend-config=python,shm-region-prefix-name=prefix{str(i)}_",
":"
":",
]
return subprocess.Popen(
command,
@@ -89,7 +103,9 @@ def load_server_and_model(self, env: dict):
prepare_model_repository(self._data_dir)
self.start_server(mpi=self._tensor_parallel_count, env=env)

self._http_client = httpclient.InferenceServerClient(url="localhost:8003", verbose=False)
self._http_client = httpclient.InferenceServerClient(
url="localhost:8003", verbose=False
)
is_server_up = False
while not is_server_up:
try:
@@ -112,8 +128,12 @@ def _is_final_response(result):
return True

if result:
final_response_param = result.get_response().parameters.get("triton_final_response")
return final_response_param.bool_param if final_response_param else False
final_response_param = result.get_response().parameters.get(
"triton_final_response"
)
return (
final_response_param.bool_param if final_response_param else False
)
return False

result = None
@@ -122,16 +142,10 @@ def _is_final_response(result):
try:
result = user_data._completed_requests.get()
if not isinstance(result, InferenceServerException):
res = result.as_numpy('text_output')
res = result.as_numpy("text_output")
yield res[0].decode("utf-8")
else:
yield json.dumps({
"status": "error",
"message": result.message()
})
yield json.dumps({"status": "error", "message": result.message()})
except Exception as e:
yield json.dumps({
"status": "error",
"message": str(e)
})
break
yield json.dumps({"status": "error", "message": str(e)})
break