From e4bb4a7342a5c8e22fa31c8756d0aff4d3bd6cdd Mon Sep 17 00:00:00 2001 From: Abu Qader <48742992+aspctu@users.noreply.github.com> Date: Wed, 27 Mar 2024 20:23:00 +0000 Subject: [PATCH 1/3] dolphin --- dolfo/config.yaml | 43 +++ dolfo/model/__init__.py | 0 dolfo/model/model.py | 139 ++++++++++ dolfo/packages/build_engine_utils.py | 24 ++ dolfo/packages/constants.py | 9 + dolfo/packages/schema.py | 155 +++++++++++ .../ensemble/config.pbtxt | 246 +++++++++++++++++ .../postprocessing/1/model.py | 205 ++++++++++++++ .../postprocessing/config.pbtxt | 64 +++++ .../preprocessing/1/model.py | 260 ++++++++++++++++++ .../preprocessing/config.pbtxt | 99 +++++++ .../tensorrt_llm/config.pbtxt | 208 ++++++++++++++ dolfo/packages/triton_client.py | 136 +++++++++ dolfo/packages/utils.py | 81 ++++++ 14 files changed, 1669 insertions(+) create mode 100644 dolfo/config.yaml create mode 100644 dolfo/model/__init__.py create mode 100644 dolfo/model/model.py create mode 100644 dolfo/packages/build_engine_utils.py create mode 100644 dolfo/packages/constants.py create mode 100644 dolfo/packages/schema.py create mode 100644 dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt create mode 100644 dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py create mode 100644 dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt create mode 100644 dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py create mode 100644 dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt create mode 100644 dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt create mode 100644 dolfo/packages/triton_client.py create mode 100644 dolfo/packages/utils.py diff --git a/dolfo/config.yaml b/dolfo/config.yaml new file mode 100644 index 00000000..caa655dd --- /dev/null +++ b/dolfo/config.yaml @@ -0,0 +1,43 @@ +apply_library_patches: true +base_image: + image: baseten/trtllm-build-server:r23.12_baseten_v0.9.0_20240305 + python_executable_path: /usr/bin/python3 +bundled_packages_dir: packages +data_dir: data +description: Generate text from a prompt with this seven billion parameter language + model. 
+environment_variables: {} +examples_filename: examples.yaml +external_data: null +external_package_dirs: [] +input_type: Any +live_reload: false +model_class_filename: model.py +model_class_name: Model +model_framework: custom +trt_llm: + serve: + engine_repository: baseten/dolphin_i6000_o1024_bs96_tp8-tllm_0.9.0.dev2024022000 + pipeline_parallel_count: 1 + tensor_parallel_count: 8 + tokenizer_repository: cognitivecomputations/dolphin-2.6-mixtral-8x7b +model_metadata: + engine_repository: baseten/dolphin_i6000_o1024_bs96_tp8-tllm_0.9.0.dev2024022000 + tags: + - text-generation + - openai-compatible +model_module_dir: model +model_name: Dolphin Mixtral TP2 - TP8 Num Workers 1 +model_type: Model +python_version: py311 +requirements: +- tritonclient[all] +- transformers +- jinja2 +resources: + accelerator: H100:8 + use_gpu: true +runtime: + num_workers: 1 + predict_concurrency: 1000 +secrets: {} diff --git a/dolfo/model/__init__.py b/dolfo/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dolfo/model/model.py b/dolfo/model/model.py new file mode 100644 index 00000000..961e0adb --- /dev/null +++ b/dolfo/model/model.py @@ -0,0 +1,139 @@ +import os +from itertools import count + +import build_engine_utils +from builder.types import TrussTRTLLMConfiguration +from constants import ( + GRPC_SERVICE_PORT, + HF_AUTH_KEY_CONSTANT, + HTTP_SERVICE_PORT, + TOKENIZER_KEY_CONSTANT, +) +from schema import ModelInput +from transformers import AutoTokenizer +from triton_client import TritonClient, TritonServer +from utils import execute_command, server_loaded_file_approach + +APPEND_ASSISTANT_TEMPLATE_TO_PROMPT = True +APPEND_ASSISTANT_TEMPLATE_TO_PROMPT_STR = "<|im_start|>assistant" +STOP_TOKEN = "<|im_end|>" + +class Model: + def __init__(self, data_dir, config, secrets): + self._data_dir = data_dir + self._config = config + self._secrets = secrets + self._request_id_counter = count(start=1) + self.triton_client = None + self.triton_server = None + self.tokenizer = None + self.uses_openai_api = None + + def load(self): + execute_command(["ldconfig"]) + trtllm_config = TrussTRTLLMConfiguration(**self._config.get("trt_llm", {})) + self.uses_openai_api = "openai-compatible" in self._config.get( + "model_metadata", {} + ).get("tags", []) + hf_access_token = None + if "hf_access_token" in self._secrets._base_secrets.keys(): + hf_access_token = self._secrets["hf_access_token"] + + # TODO(Abu): Move to pre-runtime + # if trtllm_config.requires_build: + # build_engine_utils.build_engine_from_config_args( + # truss_trtllm_configuration=trtllm_config, + # checkpoint_dir_path=None, + # dst=self._data_dir, + # ) + + self.triton_server = TritonServer( + grpc_port=GRPC_SERVICE_PORT, + http_port=HTTP_SERVICE_PORT, + ) + + if not trtllm_config.requires_build: + engine_repository_path = trtllm_config.serve.engine_repository + tokenizer_repository = trtllm_config.serve.tokenizer_repository + tensor_parallel_count = trtllm_config.serve.tensor_parallel_count + pipeline_parallel_count = trtllm_config.serve.pipeline_parallel_count + world_size = tensor_parallel_count * pipeline_parallel_count + else: + engine_repository_path = None + tokenizer_repository = trtllm_config.build.huggingface_ckpt_repository + tensor_parallel_count = trtllm_config.build.tensor_parallel_count + pipeline_parallel_count = trtllm_config.build.pipeline_parallel_count + world_size = tensor_parallel_count * pipeline_parallel_count + + if not server_loaded_file_approach(): + self.triton_server.create_model_repository( + 
truss_data_dir=self._data_dir, + engine_repository_path=engine_repository_path, + huggingface_auth_token=hf_access_token, + ) + + env = {} + if hf_access_token: + env[HF_AUTH_KEY_CONSTANT] = hf_access_token + env[TOKENIZER_KEY_CONSTANT] = tokenizer_repository + + self.triton_server.start( + world_size=world_size, + env=env, + ) + + self.triton_client = TritonClient( + grpc_service_port=GRPC_SERVICE_PORT, + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_repository, token=hf_access_token + ) + self.eos_token_id = self.tokenizer.eos_token_id + + async def predict(self, model_input): + if model_input.get("max_tokens") is None: + model_input["max_tokens"] = 500 + + if model_input.get("max_new_tokens") is None: + model_input["max_new_tokens"] = 500 + + model_input["request_id"] = str(os.getpid()) + str( + next(self._request_id_counter) + ) + model_input["eos_token_id"] = self.eos_token_id + messages = model_input.get("messages", []) + if "messages" in model_input: + del model_input["messages"] + prompt = model_input.get("prompt", None) + if not prompt and messages == []: + raise ValueError("Prompt or messages must be provided") + + if self.uses_openai_api and not prompt: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + ) + model_input["prompt"] = prompt + + if APPEND_ASSISTANT_TEMPLATE_TO_PROMPT: + model_input["prompt"] = f"{model_input['prompt']}{APPEND_ASSISTANT_TEMPLATE_TO_PROMPT_STR}" + + self.triton_client.start_grpc_stream() + model_input = ModelInput(**model_input) + result_iterator = self.triton_client.infer(model_input) + + async def generate(): + async for result in result_iterator: + if result != STOP_TOKEN: + yield result + else: + yield "" + + if model_input.stream: + return generate() + else: + if self.uses_openai_api: + return "".join(generate()) + else: + return {"text": "".join(generate())} diff --git a/dolfo/packages/build_engine_utils.py b/dolfo/packages/build_engine_utils.py new file mode 100644 index 00000000..900abae6 --- /dev/null +++ b/dolfo/packages/build_engine_utils.py @@ -0,0 +1,24 @@ +from pathlib import Path +from typing import Optional + +from builder.types import TrussTRTLLMConfiguration + + +def build_engine_from_config_args( + truss_trtllm_configuration: TrussTRTLLMConfiguration, + dst: Path, + checkpoint_dir_path: Optional[Path] = None, +): + # NOTE: These are provided by the underlying base image + # TODO(Abu): Remove this when we have a better way of handling this + from builder.main import build_engine + + build_engine( + engine_configuration=truss_trtllm_configuration, + engine_serialization_path=dst, + # If checkpoint_dir_path is provided, we'll look there for the + # weight files. If not, we will attempt to use the `huggingface_ckpt_repository` + # key in the `truss_trtllm_configuration` to download the weights. 
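+ # Note: this helper is currently not exercised at runtime; the call to it in
+ # dolfo/model/model.py's load() is commented out behind a TODO.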
+ checkpoint_dir_path=checkpoint_dir_path, + ) + return dst \ No newline at end of file diff --git a/dolfo/packages/constants.py b/dolfo/packages/constants.py new file mode 100644 index 00000000..1f19e806 --- /dev/null +++ b/dolfo/packages/constants.py @@ -0,0 +1,9 @@ +from pathlib import Path + +# If changing model repo path, please updated inside tensorrt_llm config.pbtxt as well +TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path("/packages/tensorrt_llm_model_repository/") +GRPC_SERVICE_PORT = 8001 +HTTP_SERVICE_PORT = 8003 +HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN" +TOKENIZER_KEY_CONSTANT = "TRITON_TOKENIZER_REPOSITORY" +ENTRYPOINT_MODEL_NAME = "ensemble" diff --git a/dolfo/packages/schema.py b/dolfo/packages/schema.py new file mode 100644 index 00000000..4847fcb0 --- /dev/null +++ b/dolfo/packages/schema.py @@ -0,0 +1,155 @@ +from typing import Optional + +import numpy as np +import tritonclient +import tritonclient.grpc.aio as grpcclient + + +class ModelInput: + def __init__( + self, + prompt: str, + request_id: int, + max_tokens: int = 50, + max_new_tokens: int = 50, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + beam_width: int = 1, + bad_words_list: Optional[list] = None, + stop_words_list: Optional[list] = None, + repetition_penalty: float = 1.0, + ignore_eos: bool = False, + stream: bool = True, + eos_token_id: int = None, # type: ignore + ) -> None: + self.stream = stream + self.request_id = request_id + self._prompt = prompt + self._max_tokens = max_tokens + self._beam_width = beam_width + self._bad_words_list = [""] if bad_words_list is None else bad_words_list + self._stop_words_list = [""] if stop_words_list is None else stop_words_list + self._repetition_penalty = repetition_penalty + self._eos_token_id = eos_token_id + self._ignore_eos = ignore_eos + # These variables are passed by OAI proxy but are unused + # TODO(Abu): Add support for these + self._max_new_tokens = max_new_tokens + self._temperature = temperature + self._top_p = top_p + self._top_k = top_k + + def _prepare_grpc_tensor( + self, name: str, input_data: np.ndarray + ) -> grpcclient.InferInput: + tensor = grpcclient.InferInput( + name, + input_data.shape, + tritonclient.utils.np_to_triton_dtype(input_data.dtype), + ) + tensor.set_data_from_numpy(input_data) + return tensor + + def to_tensors(self): + if self._eos_token_id is None and self._ignore_eos: + raise ValueError("eos_token_id is required when ignore_eos is True") + + prompt_data = np.array([[self._prompt]], dtype=object) + output_len_data = np.ones_like(prompt_data, dtype=np.uint32) * self._max_tokens + bad_words_data = np.array([self._bad_words_list], dtype=object) + stop_words_data = np.array([self._stop_words_list], dtype=object) + stream_data = np.array([[self.stream]], dtype=bool) + beam_width_data = np.array([[self._beam_width]], dtype=np.uint32) + repetition_penalty_data = np.array( + [[self._repetition_penalty]], dtype=np.float32 + ) + + inputs = [ + self._prepare_grpc_tensor("text_input", prompt_data), + self._prepare_grpc_tensor("max_tokens", output_len_data), + self._prepare_grpc_tensor("bad_words", bad_words_data), + self._prepare_grpc_tensor("stop_words", stop_words_data), + self._prepare_grpc_tensor("stream", stream_data), + self._prepare_grpc_tensor("beam_width", beam_width_data), + self._prepare_grpc_tensor("repetition_penalty", repetition_penalty_data), + ] + + if not self._ignore_eos: + end_id_data = np.array([[self._eos_token_id]], dtype=np.uint32) + inputs.append(self._prepare_grpc_tensor("end_id", 
end_id_data)) + + return inputs + + +# The following are duplicated from the underlying base image. +# We list them as a comment for posterity: +# +# class TRTLLMModelArchitecture(Enum): +# LLAMA: str = "llama" +# MISTRAL: str = "mistral" +# DEEPSEEK: str = "deepseek" + + +# class TRTLLMQuantizationType(Enum): +# NO_QUANT: str = "no_quant" +# WEIGHTS_ONLY_INT8: str = "weights_int8" +# WEIGHTS_KV_INT8: str = "weights_kv_int8" +# WEIGHTS_ONLY_INT4: str = "weights_int4" +# WEIGHTS_KV_INT4: str = "weights_kv_int4" +# SMOOTH_QUANT: str = "smooth_quant" +# FP8: str = "fp8" +# FP8_KV: str = "fp8_kv" + +# class TrussTRTLLMPluginConfiguration(BaseModel): +# multi_block_mode: bool = False +# paged_kv_cache: bool = True +# use_fused_mlp: bool = False + +# class TrussTRTLLMBuildConfiguration(BaseModel): +# base_model_architecture: TRTLLMModelArchitecture +# max_input_len: int +# max_output_len: int +# max_batch_size: int +# max_beam_width: int +# max_prompt_embedding_table_size: int = 0 +# huggingface_ckpt_repository: Optional[str] +# gather_all_token_logits: bool = False +# strongly_typed: bool = False +# quantization_type: TRTLLMQuantizationType = TRTLLMQuantizationType.NO_QUANT +# tensor_parallel_count: int = 1 +# pipeline_parallel_count: int = 1 +# plugin_configuration: TrussTRTLLMPluginConfiguration = TrussTRTLLMPluginConfiguration() + +# class TrussTRTLLMServingConfiguration(BaseModel): +# engine_repository: str +# tokenizer_repository: str +# tensor_parallel_count: int = 1 +# pipeline_parallel_count: int = 1 + +# class TrussTRTLLMConfiguration(BaseModel): +# serve: Optional[TrussTRTLLMServingConfiguration] = None +# build: Optional[TrussTRTLLMBuildConfiguration] = None + +# @model_validator(mode="after") +# def check_minimum_required_configuration(self): +# if not self.serve and not self.build: +# raise ValueError( +# "Either serve or build configurations must be provided" +# ) +# if self.serve and self.build: +# raise ValueError( +# "Both serve and build configurations cannot be provided" +# ) +# if self.serve is not None: +# if (self.serve.engine_repository is None) ^ (self.serve.tokenizer_repository is None): +# raise ValueError( +# "Both engine_repository and tokenizer_repository must be provided" +# ) +# return self + +# @property +# def requires_build(self): +# if self.build is not None: +# return True +# return False \ No newline at end of file diff --git a/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt new file mode 100644 index 00000000..618098de --- /dev/null +++ b/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt @@ -0,0 +1,246 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +max_batch_size: 2048 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "max_tokens" + } + input_map { + key: "BAD_WORDS_DICT" + value: "bad_words" + } + input_map { + key: "STOP_WORDS_DICT" + value: "stop_words" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "top_k" + } + input_map { + key: "runtime_top_p" + value: "top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "length_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: 
"presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "stream" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "text_output" + } + } + ] +} diff --git a/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py b/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py new file mode 100644 index 00000000..ff7ab4ad --- /dev/null +++ b/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py @@ -0,0 +1,205 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import os +from collections import OrderedDict + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + +# https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/strip.rs#L8 +INVALID_UNICODE_CHAR = "�" + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + # NOTE: Keep this in sync with the truss model.py variable + tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left" + ) + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + self.state_dict = OrderedDict() + # TODO(pankaj) This should come from the batch size + self.cache_size = 2048 + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get request ID + request_id = request.request_id() + + # Get input tensors + tokens_batch = ( + pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH") + .as_numpy() + .flatten() + ) + + if len(tokens_batch) == 0: + continue + + # Postprocess output data + prev_token = self._get_var(request_id, "prev_token") + token_buffer = self._get_var(request_id, "token_buffer") + token_buffer = token_buffer if token_buffer is not None else [] + current_tokens = np.concatenate( + (np.array(token_buffer, dtype=int), tokens_batch), dtype=int + ) + current_tokens_decoded = self.tokenizer.decode(current_tokens) + + if len(current_tokens_decoded) == 0: + responses.append(pb_utils.InferenceResponse()) + continue + + if current_tokens_decoded[-1] == INVALID_UNICODE_CHAR: + # If the last token is invalid, we need to keep it in the buffer + # for the next request to see if this is a multi-token unicode + # character. 
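+ # (The tokenizer emits the U+FFFD replacement character while a code point is
+ # still split across tokens; once the remaining tokens arrive, the buffered
+ # sequence decodes cleanly and the delta can be emitted.)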
+ self._store_var(request_id, "token_buffer", current_tokens) + responses.append(pb_utils.InferenceResponse()) + continue + + if prev_token is None: + delta = current_tokens_decoded + else: + # TODO(pankaj) Figure out how to make tokenizer.decode not + # ignore initial whitespace so we can avoid this hack. + # Get string with and without previous token and diff. This hack + # is needed because tokenizer.decode strips initial whitespace. + old_string = self.tokenizer.decode(prev_token) + with_prev_token = np.concatenate((prev_token, current_tokens)) + new_string = self.tokenizer.decode(with_prev_token) + delta = self._compute_delta(old_string, new_string) + + # The previous token is the last character of the decoded sequence + # which includes the multi-token unicode character. + self._store_var(request_id, "prev_token", current_tokens) + self._store_var(request_id, "token_buffer", None) + + # Create output tensor + output_tensor = pb_utils.Tensor( + "OUTPUT", np.array([delta]).astype(self.output_dtype) + ) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor] + ) + responses.append(inference_response) + + return responses + + def finalize(self): + print("Cleaning up...") + + def _store_var(self, request_id, var_name, var): + if request_id in self.state_dict: + self.state_dict[request_id][var_name] = var + self.state_dict.move_to_end(request_id) + else: + if len(self.state_dict) > self.cache_size: + self.state_dict.popitem(last=False) + self.state_dict[request_id] = {"prev_token": None, "token_buffer": None} + self.state_dict[request_id][var_name] = var + + def _get_var(self, request_id, var_name): + if request_id in self.state_dict: + return self.state_dict[request_id][var_name] + return None + + def _compute_delta(self, prev_str, new_str): + delta = "".join( + [ + char + for index, char in enumerate(new_str) + if index >= len(prev_str) or char != prev_str[index] + ] + ) + return delta + + def _postprocessing(self, tokens): + decoded_tokens = self.tokenizer.decode(tokens) + return decoded_tokens \ No newline at end of file diff --git a/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt new file mode 100644 index 00000000..854ef960 --- /dev/null +++ b/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt @@ -0,0 +1,64 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "postprocessing" +backend: "python" +max_batch_size: 2048 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "NousResearch/Llama-2-7b-hf" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "auto" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py b/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py new file mode 100644 index 00000000..fa4dcc2c --- /dev/null +++ b/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py @@ -0,0 +1,260 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import csv +import json +import os +from typing import List + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. 
+ Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + # NOTE: Keep this in sync with the truss model.py variable + tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + self.add_special_tokens = model_config["parameters"].get( + "add_special_tokens", {"string_value": "false"} + )["string_value"].lower() in ["true", "1", "t", "y", "yes"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left" + ) + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.pad_id = self.tokenizer.encode( + self.tokenizer.pad_token, add_special_tokens=False + )[0] + + # Parse model output configs and convert Triton types to numpy types + input_names = [ + "INPUT_ID", + "REQUEST_INPUT_LEN", + "BAD_WORDS_IDS", + "STOP_WORDS_IDS", + ] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, input_name)[ + "data_type" + ] + ), + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, "REQUEST_OUTPUT_LEN" + ).as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name( + request, "BAD_WORDS_DICT" + ).as_numpy() + stop_words_dict = pb_utils.get_input_tensor_by_name( + request, "STOP_WORDS_DICT" + ).as_numpy() + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = self._to_word_list_format(bad_words_dict) + stop_words = self._to_word_list_format(stop_words_dict) + + # Create output tensors. 
You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + input_id_tensor = pb_utils.Tensor( + "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) + ) + request_input_len_tensor = pb_utils.Tensor( + "REQUEST_INPUT_LEN", + np.array(request_input_len).astype(self.request_input_len_dtype), + ) + request_output_len_tensor = pb_utils.Tensor( + "REQUEST_OUTPUT_LEN", request_output_len + ) + bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) + stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + input_id_tensor, + bad_words_ids_tensor, + stop_words_ids_tensor, + request_input_len_tensor, + request_output_len_tensor, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [ + np.array( + self.tokenizer.encode( + s[0].decode(), add_special_tokens=self.add_special_tokens + ) + ).astype(int) + for s in query + ] + start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) + + max_len = 0 + for seq in start_ids: + max_len = max(max_len, seq.shape[0]) + start_ids = np.stack( + [ + np.pad( + seq, + (0, max_len - seq.shape[0]), + "constant", + constant_values=(0, self.pad_id), + ) + for seq in start_ids + ] + ) + + return start_ids, start_lengths + + def _to_word_list_format(self, word_dict: List[List[str]]): + """ + format of word_dict + len(word_dict) should be same to batch_size + word_dict[i] means the words for batch i + len(word_dict[i]) must be 1, which means it only contains 1 string + This string can contains several sentences and split by ",". + For example, if word_dict[2] = " I am happy, I am sad", then this function will return + the ids for two short sentences " I am happy" and " I am sad". 
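+ The return value is an int32 numpy array of shape [batch_size, 2, pad_to]:
+ row 0 holds the flattened token ids and row 1 their cumulative offsets,
+ padded with 0 and -1 respectively.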
+ """ + assert self.tokenizer is not None, "need to set tokenizer" + + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + if isinstance(word_dict_item[0], bytes): + word_dict_item = [word_dict_item[0].decode()] + + words = list(csv.reader(word_dict_item))[0] + for word in words: + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) diff --git a/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt new file mode 100644 index 00000000..1fb88012 --- /dev/null +++ b/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt @@ -0,0 +1,99 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
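+# Note: the tokenizer is loaded from the TRITON_TOKENIZER_REPOSITORY environment
+# variable set by the truss model (see packages/constants.py); the tokenizer_dir
+# parameter below is kept for reference but is not read at runtime.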
+ +name: "preprocessing" +backend: "python" +max_batch_size: 2048 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "NousResearch/Llama-2-7b-hf" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "auto" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..75cb6718 --- /dev/null +++ b/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt @@ -0,0 +1,208 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
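+# Note: decoupled transaction mode lets the backend stream multiple responses
+# (token deltas) per request. gpt_model_path below must stay in sync with
+# TENSORRT_LLM_MODEL_REPOSITORY_PATH in packages/constants.py; engine files are
+# moved into tensorrt_llm/1 by prepare_model_repository in packages/utils.py.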
+ +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 2048 + +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "/packages/tensorrt_llm_model_repository/tensorrt_llm/1" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "100000" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "max_utilization" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "0.9" + } +} +parameters: { + key: "max_num_sequences" + value: { + string_value: "2048" + } +} +parameters: { + key: "enable_trt_overlap" + value: { + string_value: "False" + } +} diff --git a/dolfo/packages/triton_client.py b/dolfo/packages/triton_client.py new file mode 100644 index 00000000..80f7a6a3 --- /dev/null +++ b/dolfo/packages/triton_client.py @@ -0,0 +1,136 @@ +import json +import os +import subprocess +import time +from pathlib import Path +from typing import AsyncGenerator, Optional + +import tritonclient.grpc.aio as grpcclient +import tritonclient.http as httpclient +from constants import ( + ENTRYPOINT_MODEL_NAME, + GRPC_SERVICE_PORT, + TENSORRT_LLM_MODEL_REPOSITORY_PATH, +) +from schema import ModelInput +from utils import download_engine, prepare_model_repository + + +class TritonServer: + def __init__(self, grpc_port: int = 8001, http_port: int = 8003): + self.grpc_port = grpc_port + self.http_port = http_port + 
self._server_process = None + + def create_model_repository( + self, + truss_data_dir: Path, + engine_repository_path: Optional[str] = None, + huggingface_auth_token: Optional[str] = None, + ) -> None: + if engine_repository_path: + download_engine( + engine_repository=engine_repository_path, + fp=truss_data_dir, + auth_token=huggingface_auth_token, + ) + prepare_model_repository(truss_data_dir) + return + + def start(self, world_size: int = 1, env: dict = {}) -> None: + mpirun_command = ["mpirun", "--allow-run-as-root"] + mpi_commands = [] + for i in range(world_size): + mpi_command = [ + "-n", + "1", + "tritonserver", + f"--model-repository={TENSORRT_LLM_MODEL_REPOSITORY_PATH}", + f"--grpc-port={str(self.grpc_port)}", + f"--http-port={str(self.http_port)}", + "--disable-auto-complete-config", + f"--backend-config=python,shm-region-prefix-name=prefix{i}_", + ":", + ] + + mpi_commands.extend(mpi_command) + command = mpirun_command + mpi_commands + + self._server_process = subprocess.Popen( # type: ignore + command, + env={**os.environ, **env}, + ) + while not self.is_alive and not self.is_ready: + time.sleep(2) + return + + def stop(self): + if self._server_process: + if self.is_ready: + self._server_process.kill() + self._server_process = None + return + + @property + def is_alive(self) -> bool: + try: + http_client = httpclient.InferenceServerClient( + url=f"localhost:{self.http_port}", verbose=False + ) + return http_client.is_server_live() + except ConnectionRefusedError: + return False + + @property + def is_ready(self) -> bool: + try: + http_client = httpclient.InferenceServerClient( + url=f"localhost:{self.http_port}", verbose=False + ) + return http_client.is_model_ready(model_name=ENTRYPOINT_MODEL_NAME) + except ConnectionRefusedError: + return False + + +class TritonClient: + def __init__(self, grpc_service_port: int = GRPC_SERVICE_PORT): + self.grpc_service_port = grpc_service_port + self._grpc_client = None + + def start_grpc_stream(self) -> grpcclient.InferenceServerClient: + if self._grpc_client: + return self._grpc_client + + self._grpc_client = grpcclient.InferenceServerClient( + url=f"localhost:{self.grpc_service_port}", verbose=False + ) + return self._grpc_client + + async def infer( + self, model_input: ModelInput, model_name="ensemble" + ) -> AsyncGenerator[str, None]: + grpc_client_instance = self.start_grpc_stream() + inputs = model_input.to_tensors() + + async def input_generator(): + yield { + "model_name": model_name, + "inputs": inputs, + "request_id": model_input.request_id, + } + + response_iterator = grpc_client_instance.stream_infer( + inputs_iterator=input_generator(), + ) + + try: + async for response in response_iterator: + result, error = response + if result: + result = result.as_numpy("text_output") + yield result[0].decode("utf-8") + else: + yield json.dumps({"status": "error", "message": error.message()}) + + except grpcclient.InferenceServerException as e: + print(f"InferenceServerException: {e}") \ No newline at end of file diff --git a/dolfo/packages/utils.py b/dolfo/packages/utils.py new file mode 100644 index 00000000..ee3554c4 --- /dev/null +++ b/dolfo/packages/utils.py @@ -0,0 +1,81 @@ +import subprocess +from pathlib import Path + +from constants import TENSORRT_LLM_MODEL_REPOSITORY_PATH, GRPC_SERVICE_PORT, HTTP_SERVICE_PORT +from huggingface_hub import snapshot_download + +import socket +def move_all_files(src: Path, dest: Path) -> None: + """ + Moves all files from `src` to `dest` recursively. 
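+ Used by prepare_model_repository to move the downloaded engine files from the
+ truss data directory into the Triton model repository.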
+ """ + for item in src.iterdir(): + dest_item = dest / item.name + if item.is_dir(): + dest_item.mkdir(parents=True, exist_ok=True) + move_all_files(item, dest_item) + else: + item.rename(dest_item) + + +def prepare_model_repository(data_dir: Path) -> None: + # Ensure the destination directory exists + dest_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "tensorrt_llm" / "1" + dest_dir.mkdir(parents=True, exist_ok=True) + + # Ensure empty version directory for `ensemble` model exists + ensemble_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "ensemble" / "1" + ensemble_dir.mkdir(parents=True, exist_ok=True) + + # Move all files and directories from data_dir to dest_dir + move_all_files(data_dir, dest_dir) + + +def download_engine(engine_repository: str, fp: Path, auth_token=None): + """ + Downloads the specified engine from Hugging Face Hub. + """ + snapshot_download( + engine_repository, + local_dir=fp, + local_dir_use_symlinks=False, + max_workers=4, + **({"use_auth_token": auth_token} if auth_token is not None else {}), + ) + + +def execute_command(command) -> None: + try: + process = subprocess.run(command, capture_output=True, text=True, check=True) + print("Standard Output:\n", process.stdout) + except FileNotFoundError: + raise FileNotFoundError( + f"The command '{command[0]}' is not found. Make sure it is installed and in your PATH." + ) + + + +def server_loaded(): + def port_is_available(port): + available = False + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + try: + sock.bind(("0.0.0.0", port)) + available = True + except: + pass + return available + + return not port_is_available(GRPC_SERVICE_PORT) or not port_is_available( + HTTP_SERVICE_PORT + ) + + +def server_loaded_file_approach(): + FILE_LOC = "/packages/worker.txt" + if Path(FILE_LOC).exists(): + return True + else: + Path(FILE_LOC).touch() + return False + \ No newline at end of file From 985689b3c241e921eb4e6d75194e793fde046ca9 Mon Sep 17 00:00:00 2001 From: Abu Qader <48742992+aspctu@users.noreply.github.com> Date: Wed, 27 Mar 2024 20:25:44 +0000 Subject: [PATCH 2/3] fix --- mistral/dolfo/config.yaml | 43 +++ mistral/dolfo/model/__init__.py | 0 mistral/dolfo/model/model.py | 135 +++++++++ mistral/dolfo/packages/constants.py | 9 + mistral/dolfo/packages/schema.py | 155 +++++++++++ .../ensemble/config.pbtxt | 246 +++++++++++++++++ .../postprocessing/1/model.py | 205 ++++++++++++++ .../postprocessing/config.pbtxt | 64 +++++ .../preprocessing/1/model.py | 260 ++++++++++++++++++ .../preprocessing/config.pbtxt | 99 +++++++ .../tensorrt_llm/config.pbtxt | 208 ++++++++++++++ mistral/dolfo/packages/triton_client.py | 136 +++++++++ mistral/dolfo/packages/utils.py | 55 ++++ 13 files changed, 1615 insertions(+) create mode 100644 mistral/dolfo/config.yaml create mode 100644 mistral/dolfo/model/__init__.py create mode 100644 mistral/dolfo/model/model.py create mode 100644 mistral/dolfo/packages/constants.py create mode 100644 mistral/dolfo/packages/schema.py create mode 100644 mistral/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt create mode 100644 mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py create mode 100644 mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt create mode 100644 mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py create mode 100644 mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt create mode 100644 
mistral/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt create mode 100644 mistral/dolfo/packages/triton_client.py create mode 100644 mistral/dolfo/packages/utils.py diff --git a/mistral/dolfo/config.yaml b/mistral/dolfo/config.yaml new file mode 100644 index 00000000..eb59ab85 --- /dev/null +++ b/mistral/dolfo/config.yaml @@ -0,0 +1,43 @@ +apply_library_patches: true +base_image: + image: nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3 + python_executable_path: /usr/bin/python3 +bundled_packages_dir: packages +data_dir: data +description: Generate text from a prompt with this seven billion parameter language + model. +environment_variables: {} +examples_filename: examples.yaml +external_data: null +external_package_dirs: [] +input_type: Any +live_reload: false +model_class_filename: model.py +model_class_name: Model +model_framework: custom +trt_llm: + serve: + engine_repository: baseten/dolphin_i6000_o1024_bs96_tp8-tllm_0.9.0.dev2024032600 + pipeline_parallel_count: 1 + tensor_parallel_count: 8 + tokenizer_repository: cognitivecomputations/dolphin-2.6-mixtral-8x7b +model_metadata: + engine_repository: baseten/dolphin_i6000_o1024_bs96_tp8-tllm_0.9.0.dev2024032600 + tags: + - text-generation + - openai-compatible +model_module_dir: model +model_name: dolfo-new +model_type: Model +python_version: py311 +requirements: +- tritonclient[all] +- transformers +- jinja2 +resources: + accelerator: H100:8 + use_gpu: true +runtime: + num_workers: 1 + predict_concurrency: 1000 +secrets: {} diff --git a/mistral/dolfo/model/__init__.py b/mistral/dolfo/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mistral/dolfo/model/model.py b/mistral/dolfo/model/model.py new file mode 100644 index 00000000..2ddfaec9 --- /dev/null +++ b/mistral/dolfo/model/model.py @@ -0,0 +1,135 @@ +import os +from itertools import count + +from constants import ( + GRPC_SERVICE_PORT, + HF_AUTH_KEY_CONSTANT, + HTTP_SERVICE_PORT, + TOKENIZER_KEY_CONSTANT, +) +from schema import ModelInput +from transformers import AutoTokenizer +from triton_client import TritonClient, TritonServer +from utils import execute_command + +APPEND_ASSISTANT_TEMPLATE_TO_PROMPT = True +APPEND_ASSISTANT_TEMPLATE_TO_PROMPT_STR = "<|im_start|>assistant" +STOP_TOKEN = "<|im_end|>" + +class Model: + def __init__(self, data_dir, config, secrets): + self._data_dir = data_dir + self._config = config + self._secrets = secrets + self._request_id_counter = count(start=1) + self.triton_client = None + self.triton_server = None + self.tokenizer = None + self.uses_openai_api = None + + def load(self): + execute_command(["ldconfig"]) + # trtllm_config = TrussTRTLLMConfiguration(**self._config.get("trt_llm", {})) + trtllm_config = self._config.get("trt_llm", {}) + self.uses_openai_api = "openai-compatible" in self._config.get( + "model_metadata", {} + ).get("tags", []) + hf_access_token = None + if "hf_access_token" in self._secrets._base_secrets.keys(): + hf_access_token = self._secrets["hf_access_token"] + + self.triton_server = TritonServer( + grpc_port=GRPC_SERVICE_PORT, + http_port=HTTP_SERVICE_PORT, + ) + + engine_repository_path = trtllm_config["serve"]["engine_repository"] + tokenizer_repository = trtllm_config["serve"]["tokenizer_repository"] + tensor_parallel_count = trtllm_config["serve"]["tensor_parallel_count"] + pipeline_parallel_count = trtllm_config["serve"]["pipeline_parallel_count"] + world_size = tensor_parallel_count * pipeline_parallel_count + + # if not trtllm_config.requires_build: + # 
engine_repository_path = trtllm_config.serve.engine_repository + # tokenizer_repository = trtllm_config.serve.tokenizer_repository + # tensor_parallel_count = trtllm_config.serve.tensor_parallel_count + # pipeline_parallel_count = trtllm_config.serve.pipeline_parallel_count + # world_size = tensor_parallel_count * pipeline_parallel_count + # else: + # engine_repository_path = None + # tokenizer_repository = trtllm_config.build.huggingface_ckpt_repository + # tensor_parallel_count = trtllm_config.build.tensor_parallel_count + # pipeline_parallel_count = trtllm_config.build.pipeline_parallel_count + # world_size = tensor_parallel_count * pipeline_parallel_count + + self.triton_server.create_model_repository( + truss_data_dir=self._data_dir, + engine_repository_path=engine_repository_path, + huggingface_auth_token=hf_access_token, + ) + + env = {} + if hf_access_token: + env[HF_AUTH_KEY_CONSTANT] = hf_access_token + env[TOKENIZER_KEY_CONSTANT] = tokenizer_repository + + self.triton_server.start( + world_size=world_size, + env=env, + ) + + self.triton_client = TritonClient( + grpc_service_port=GRPC_SERVICE_PORT, + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_repository, token=hf_access_token + ) + self.eos_token_id = self.tokenizer.eos_token_id + + async def predict(self, model_input): + if model_input.get("max_tokens") is None: + model_input["max_tokens"] = 500 + + if model_input.get("max_new_tokens") is None: + model_input["max_new_tokens"] = 500 + + model_input["request_id"] = str(os.getpid()) + str( + next(self._request_id_counter) + ) + model_input["eos_token_id"] = self.eos_token_id + messages = model_input.get("messages", []) + if "messages" in model_input: + del model_input["messages"] + prompt = model_input.get("prompt", None) + if not prompt and messages == []: + raise ValueError("Prompt or messages must be provided") + + if self.uses_openai_api and not prompt: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + ) + model_input["prompt"] = prompt + + if APPEND_ASSISTANT_TEMPLATE_TO_PROMPT: + model_input["prompt"] = f"{model_input['prompt']}{APPEND_ASSISTANT_TEMPLATE_TO_PROMPT_STR}" + + self.triton_client.start_grpc_stream() + model_input = ModelInput(**model_input) + result_iterator = self.triton_client.infer(model_input) + + async def generate(): + async for result in result_iterator: + if result != STOP_TOKEN: + yield result + else: + yield "" + + if model_input.stream: + return generate() + else: + if self.uses_openai_api: + return "".join(generate()) + else: + return {"text": "".join(generate())} \ No newline at end of file diff --git a/mistral/dolfo/packages/constants.py b/mistral/dolfo/packages/constants.py new file mode 100644 index 00000000..1f19e806 --- /dev/null +++ b/mistral/dolfo/packages/constants.py @@ -0,0 +1,9 @@ +from pathlib import Path + +# If changing model repo path, please updated inside tensorrt_llm config.pbtxt as well +TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path("/packages/tensorrt_llm_model_repository/") +GRPC_SERVICE_PORT = 8001 +HTTP_SERVICE_PORT = 8003 +HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN" +TOKENIZER_KEY_CONSTANT = "TRITON_TOKENIZER_REPOSITORY" +ENTRYPOINT_MODEL_NAME = "ensemble" diff --git a/mistral/dolfo/packages/schema.py b/mistral/dolfo/packages/schema.py new file mode 100644 index 00000000..4847fcb0 --- /dev/null +++ b/mistral/dolfo/packages/schema.py @@ -0,0 +1,155 @@ +from typing import Optional + +import numpy as np +import tritonclient +import tritonclient.grpc.aio as grpcclient + + +class 
ModelInput: + def __init__( + self, + prompt: str, + request_id: int, + max_tokens: int = 50, + max_new_tokens: int = 50, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + beam_width: int = 1, + bad_words_list: Optional[list] = None, + stop_words_list: Optional[list] = None, + repetition_penalty: float = 1.0, + ignore_eos: bool = False, + stream: bool = True, + eos_token_id: int = None, # type: ignore + ) -> None: + self.stream = stream + self.request_id = request_id + self._prompt = prompt + self._max_tokens = max_tokens + self._beam_width = beam_width + self._bad_words_list = [""] if bad_words_list is None else bad_words_list + self._stop_words_list = [""] if stop_words_list is None else stop_words_list + self._repetition_penalty = repetition_penalty + self._eos_token_id = eos_token_id + self._ignore_eos = ignore_eos + # These variables are passed by OAI proxy but are unused + # TODO(Abu): Add support for these + self._max_new_tokens = max_new_tokens + self._temperature = temperature + self._top_p = top_p + self._top_k = top_k + + def _prepare_grpc_tensor( + self, name: str, input_data: np.ndarray + ) -> grpcclient.InferInput: + tensor = grpcclient.InferInput( + name, + input_data.shape, + tritonclient.utils.np_to_triton_dtype(input_data.dtype), + ) + tensor.set_data_from_numpy(input_data) + return tensor + + def to_tensors(self): + if self._eos_token_id is None and self._ignore_eos: + raise ValueError("eos_token_id is required when ignore_eos is True") + + prompt_data = np.array([[self._prompt]], dtype=object) + output_len_data = np.ones_like(prompt_data, dtype=np.uint32) * self._max_tokens + bad_words_data = np.array([self._bad_words_list], dtype=object) + stop_words_data = np.array([self._stop_words_list], dtype=object) + stream_data = np.array([[self.stream]], dtype=bool) + beam_width_data = np.array([[self._beam_width]], dtype=np.uint32) + repetition_penalty_data = np.array( + [[self._repetition_penalty]], dtype=np.float32 + ) + + inputs = [ + self._prepare_grpc_tensor("text_input", prompt_data), + self._prepare_grpc_tensor("max_tokens", output_len_data), + self._prepare_grpc_tensor("bad_words", bad_words_data), + self._prepare_grpc_tensor("stop_words", stop_words_data), + self._prepare_grpc_tensor("stream", stream_data), + self._prepare_grpc_tensor("beam_width", beam_width_data), + self._prepare_grpc_tensor("repetition_penalty", repetition_penalty_data), + ] + + if not self._ignore_eos: + end_id_data = np.array([[self._eos_token_id]], dtype=np.uint32) + inputs.append(self._prepare_grpc_tensor("end_id", end_id_data)) + + return inputs + + +# The following are duplicated from the underlying base image. 
+# We list them as a comment for posterity: +# +# class TRTLLMModelArchitecture(Enum): +# LLAMA: str = "llama" +# MISTRAL: str = "mistral" +# DEEPSEEK: str = "deepseek" + + +# class TRTLLMQuantizationType(Enum): +# NO_QUANT: str = "no_quant" +# WEIGHTS_ONLY_INT8: str = "weights_int8" +# WEIGHTS_KV_INT8: str = "weights_kv_int8" +# WEIGHTS_ONLY_INT4: str = "weights_int4" +# WEIGHTS_KV_INT4: str = "weights_kv_int4" +# SMOOTH_QUANT: str = "smooth_quant" +# FP8: str = "fp8" +# FP8_KV: str = "fp8_kv" + +# class TrussTRTLLMPluginConfiguration(BaseModel): +# multi_block_mode: bool = False +# paged_kv_cache: bool = True +# use_fused_mlp: bool = False + +# class TrussTRTLLMBuildConfiguration(BaseModel): +# base_model_architecture: TRTLLMModelArchitecture +# max_input_len: int +# max_output_len: int +# max_batch_size: int +# max_beam_width: int +# max_prompt_embedding_table_size: int = 0 +# huggingface_ckpt_repository: Optional[str] +# gather_all_token_logits: bool = False +# strongly_typed: bool = False +# quantization_type: TRTLLMQuantizationType = TRTLLMQuantizationType.NO_QUANT +# tensor_parallel_count: int = 1 +# pipeline_parallel_count: int = 1 +# plugin_configuration: TrussTRTLLMPluginConfiguration = TrussTRTLLMPluginConfiguration() + +# class TrussTRTLLMServingConfiguration(BaseModel): +# engine_repository: str +# tokenizer_repository: str +# tensor_parallel_count: int = 1 +# pipeline_parallel_count: int = 1 + +# class TrussTRTLLMConfiguration(BaseModel): +# serve: Optional[TrussTRTLLMServingConfiguration] = None +# build: Optional[TrussTRTLLMBuildConfiguration] = None + +# @model_validator(mode="after") +# def check_minimum_required_configuration(self): +# if not self.serve and not self.build: +# raise ValueError( +# "Either serve or build configurations must be provided" +# ) +# if self.serve and self.build: +# raise ValueError( +# "Both serve and build configurations cannot be provided" +# ) +# if self.serve is not None: +# if (self.serve.engine_repository is None) ^ (self.serve.tokenizer_repository is None): +# raise ValueError( +# "Both engine_repository and tokenizer_repository must be provided" +# ) +# return self + +# @property +# def requires_build(self): +# if self.build is not None: +# return True +# return False \ No newline at end of file diff --git a/mistral/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt b/mistral/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt new file mode 100644 index 00000000..618098de --- /dev/null +++ b/mistral/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt @@ -0,0 +1,246 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +max_batch_size: 2048 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "max_tokens" + } + input_map { + key: "BAD_WORDS_DICT" + value: "bad_words" + } + input_map { + key: "STOP_WORDS_DICT" + value: "stop_words" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "top_k" + } + input_map { + key: "runtime_top_p" + value: "top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "length_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: 
"presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "stream" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "text_output" + } + } + ] +} diff --git a/mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py b/mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py new file mode 100644 index 00000000..ff7ab4ad --- /dev/null +++ b/mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py @@ -0,0 +1,205 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import os +from collections import OrderedDict + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + +# https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/strip.rs#L8 +INVALID_UNICODE_CHAR = "�" + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + # NOTE: Keep this in sync with the truss model.py variable + tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left" + ) + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + self.state_dict = OrderedDict() + # TODO(pankaj) This should come from the batch size + self.cache_size = 2048 + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get request ID + request_id = request.request_id() + + # Get input tensors + tokens_batch = ( + pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH") + .as_numpy() + .flatten() + ) + + if len(tokens_batch) == 0: + continue + + # Postprocess output data + prev_token = self._get_var(request_id, "prev_token") + token_buffer = self._get_var(request_id, "token_buffer") + token_buffer = token_buffer if token_buffer is not None else [] + current_tokens = np.concatenate( + (np.array(token_buffer, dtype=int), tokens_batch), dtype=int + ) + current_tokens_decoded = self.tokenizer.decode(current_tokens) + + if len(current_tokens_decoded) == 0: + responses.append(pb_utils.InferenceResponse()) + continue + + if current_tokens_decoded[-1] == INVALID_UNICODE_CHAR: + # If the last token is invalid, we need to keep it in the buffer + # for the next request to see if this is a multi-token unicode + # character. 
+ self._store_var(request_id, "token_buffer", current_tokens) + responses.append(pb_utils.InferenceResponse()) + continue + + if prev_token is None: + delta = current_tokens_decoded + else: + # TODO(pankaj) Figure out how to make tokenizer.decode not + # ignore initial whitespace so we can avoid this hack. + # Get string with and without previous token and diff. This hack + # is needed because tokenizer.decode strips initial whitespace. + old_string = self.tokenizer.decode(prev_token) + with_prev_token = np.concatenate((prev_token, current_tokens)) + new_string = self.tokenizer.decode(with_prev_token) + delta = self._compute_delta(old_string, new_string) + + # The previous token is the last character of the decoded sequence + # which includes the multi-token unicode character. + self._store_var(request_id, "prev_token", current_tokens) + self._store_var(request_id, "token_buffer", None) + + # Create output tensor + output_tensor = pb_utils.Tensor( + "OUTPUT", np.array([delta]).astype(self.output_dtype) + ) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor] + ) + responses.append(inference_response) + + return responses + + def finalize(self): + print("Cleaning up...") + + def _store_var(self, request_id, var_name, var): + if request_id in self.state_dict: + self.state_dict[request_id][var_name] = var + self.state_dict.move_to_end(request_id) + else: + if len(self.state_dict) > self.cache_size: + self.state_dict.popitem(last=False) + self.state_dict[request_id] = {"prev_token": None, "token_buffer": None} + self.state_dict[request_id][var_name] = var + + def _get_var(self, request_id, var_name): + if request_id in self.state_dict: + return self.state_dict[request_id][var_name] + return None + + def _compute_delta(self, prev_str, new_str): + delta = "".join( + [ + char + for index, char in enumerate(new_str) + if index >= len(prev_str) or char != prev_str[index] + ] + ) + return delta + + def _postprocessing(self, tokens): + decoded_tokens = self.tokenizer.decode(tokens) + return decoded_tokens \ No newline at end of file diff --git a/mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt b/mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt new file mode 100644 index 00000000..854ef960 --- /dev/null +++ b/mistral/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt @@ -0,0 +1,64 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "postprocessing" +backend: "python" +max_batch_size: 2048 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "NousResearch/Llama-2-7b-hf" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "auto" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py b/mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py new file mode 100644 index 00000000..fa4dcc2c --- /dev/null +++ b/mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py @@ -0,0 +1,260 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import csv +import json +import os +from typing import List + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. 
+ Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + # NOTE: Keep this in sync with the truss model.py variable + tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + self.add_special_tokens = model_config["parameters"].get( + "add_special_tokens", {"string_value": "false"} + )["string_value"].lower() in ["true", "1", "t", "y", "yes"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left" + ) + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.pad_id = self.tokenizer.encode( + self.tokenizer.pad_token, add_special_tokens=False + )[0] + + # Parse model output configs and convert Triton types to numpy types + input_names = [ + "INPUT_ID", + "REQUEST_INPUT_LEN", + "BAD_WORDS_IDS", + "STOP_WORDS_IDS", + ] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, input_name)[ + "data_type" + ] + ), + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, "REQUEST_OUTPUT_LEN" + ).as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name( + request, "BAD_WORDS_DICT" + ).as_numpy() + stop_words_dict = pb_utils.get_input_tensor_by_name( + request, "STOP_WORDS_DICT" + ).as_numpy() + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = self._to_word_list_format(bad_words_dict) + stop_words = self._to_word_list_format(stop_words_dict) + + # Create output tensors. 
You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + input_id_tensor = pb_utils.Tensor( + "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) + ) + request_input_len_tensor = pb_utils.Tensor( + "REQUEST_INPUT_LEN", + np.array(request_input_len).astype(self.request_input_len_dtype), + ) + request_output_len_tensor = pb_utils.Tensor( + "REQUEST_OUTPUT_LEN", request_output_len + ) + bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) + stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + input_id_tensor, + bad_words_ids_tensor, + stop_words_ids_tensor, + request_input_len_tensor, + request_output_len_tensor, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [ + np.array( + self.tokenizer.encode( + s[0].decode(), add_special_tokens=self.add_special_tokens + ) + ).astype(int) + for s in query + ] + start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) + + max_len = 0 + for seq in start_ids: + max_len = max(max_len, seq.shape[0]) + start_ids = np.stack( + [ + np.pad( + seq, + (0, max_len - seq.shape[0]), + "constant", + constant_values=(0, self.pad_id), + ) + for seq in start_ids + ] + ) + + return start_ids, start_lengths + + def _to_word_list_format(self, word_dict: List[List[str]]): + """ + format of word_dict + len(word_dict) should be same to batch_size + word_dict[i] means the words for batch i + len(word_dict[i]) must be 1, which means it only contains 1 string + This string can contains several sentences and split by ",". + For example, if word_dict[2] = " I am happy, I am sad", then this function will return + the ids for two short sentences " I am happy" and " I am sad". 
+ """ + assert self.tokenizer is not None, "need to set tokenizer" + + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + if isinstance(word_dict_item[0], bytes): + word_dict_item = [word_dict_item[0].decode()] + + words = list(csv.reader(word_dict_item))[0] + for word in words: + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) diff --git a/mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt b/mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt new file mode 100644 index 00000000..1fb88012 --- /dev/null +++ b/mistral/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt @@ -0,0 +1,99 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
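# Note (added for clarity; not part of the upstream NVIDIA template): this config wires up the
# "preprocessing" step of the ensemble. The Python backend in preprocessing/1/model.py tokenizes
# QUERY into INPUT_ID / REQUEST_INPUT_LEN and converts the bad_words / stop_words strings into
# the word-list tensors consumed by the tensorrt_llm model. The tokenizer_dir parameter further
# down is effectively a placeholder: the backend reads the tokenizer repository from the
# TRITON_TOKENIZER_REPOSITORY environment variable, which TritonServer.start() receives from the
# Truss config's tokenizer_repository via TOKENIZER_KEY_CONSTANT in model.py.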
+ +name: "preprocessing" +backend: "python" +max_batch_size: 2048 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "NousResearch/Llama-2-7b-hf" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "auto" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/mistral/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt b/mistral/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..2a77b769 --- /dev/null +++ b/mistral/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt @@ -0,0 +1,208 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
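# Note (added for clarity; not part of the upstream NVIDIA template): key runtime settings below,
# as I understand the tensorrtllm backend:
#   * model_transaction_policy.decoupled: True  -- lets the backend stream tokens back one at a
#     time when the "streaming" input is set, which is what TritonClient.infer relies on.
#   * gpt_model_type: "inflight_fused_batching" -- enables in-flight (continuous) batching of
#     concurrent requests.
#   * gpt_model_path -- must stay in sync with TENSORRT_LLM_MODEL_REPOSITORY_PATH in constants.py
#     (see the comment there).
#   * kv_cache_free_gpu_mem_fraction: "0.9" -- reserves 90% of free GPU memory for the KV cache.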
+ +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 2048 + +model_transaction_policy { + decoupled: True +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "/packages/tensorrt_llm_model_repository/tensorrt_llm/1" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "100000" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "max_utilization" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "0.9" + } +} +parameters: { + key: "max_num_sequences" + value: { + string_value: "8092" + } +} +parameters: { + key: "enable_trt_overlap" + value: { + string_value: "True" + } +} diff --git a/mistral/dolfo/packages/triton_client.py b/mistral/dolfo/packages/triton_client.py new file mode 100644 index 00000000..80f7a6a3 --- /dev/null +++ b/mistral/dolfo/packages/triton_client.py @@ -0,0 +1,136 @@ +import json +import os +import subprocess +import time +from pathlib import Path +from typing import AsyncGenerator, Optional + +import tritonclient.grpc.aio as grpcclient +import tritonclient.http as httpclient +from constants import ( + ENTRYPOINT_MODEL_NAME, + GRPC_SERVICE_PORT, + TENSORRT_LLM_MODEL_REPOSITORY_PATH, +) +from schema import ModelInput +from utils import download_engine, prepare_model_repository + + +class TritonServer: + def __init__(self, grpc_port: int = 8001, http_port: int = 8003): + self.grpc_port = grpc_port + self.http_port = 
http_port + self._server_process = None + + def create_model_repository( + self, + truss_data_dir: Path, + engine_repository_path: Optional[str] = None, + huggingface_auth_token: Optional[str] = None, + ) -> None: + if engine_repository_path: + download_engine( + engine_repository=engine_repository_path, + fp=truss_data_dir, + auth_token=huggingface_auth_token, + ) + prepare_model_repository(truss_data_dir) + return + + def start(self, world_size: int = 1, env: dict = {}) -> None: + mpirun_command = ["mpirun", "--allow-run-as-root"] + mpi_commands = [] + for i in range(world_size): + mpi_command = [ + "-n", + "1", + "tritonserver", + f"--model-repository={TENSORRT_LLM_MODEL_REPOSITORY_PATH}", + f"--grpc-port={str(self.grpc_port)}", + f"--http-port={str(self.http_port)}", + "--disable-auto-complete-config", + f"--backend-config=python,shm-region-prefix-name=prefix{i}_", + ":", + ] + + mpi_commands.extend(mpi_command) + command = mpirun_command + mpi_commands + + self._server_process = subprocess.Popen( # type: ignore + command, + env={**os.environ, **env}, + ) + while not self.is_alive and not self.is_ready: + time.sleep(2) + return + + def stop(self): + if self._server_process: + if self.is_ready: + self._server_process.kill() + self._server_process = None + return + + @property + def is_alive(self) -> bool: + try: + http_client = httpclient.InferenceServerClient( + url=f"localhost:{self.http_port}", verbose=False + ) + return http_client.is_server_live() + except ConnectionRefusedError: + return False + + @property + def is_ready(self) -> bool: + try: + http_client = httpclient.InferenceServerClient( + url=f"localhost:{self.http_port}", verbose=False + ) + return http_client.is_model_ready(model_name=ENTRYPOINT_MODEL_NAME) + except ConnectionRefusedError: + return False + + +class TritonClient: + def __init__(self, grpc_service_port: int = GRPC_SERVICE_PORT): + self.grpc_service_port = grpc_service_port + self._grpc_client = None + + def start_grpc_stream(self) -> grpcclient.InferenceServerClient: + if self._grpc_client: + return self._grpc_client + + self._grpc_client = grpcclient.InferenceServerClient( + url=f"localhost:{self.grpc_service_port}", verbose=False + ) + return self._grpc_client + + async def infer( + self, model_input: ModelInput, model_name="ensemble" + ) -> AsyncGenerator[str, None]: + grpc_client_instance = self.start_grpc_stream() + inputs = model_input.to_tensors() + + async def input_generator(): + yield { + "model_name": model_name, + "inputs": inputs, + "request_id": model_input.request_id, + } + + response_iterator = grpc_client_instance.stream_infer( + inputs_iterator=input_generator(), + ) + + try: + async for response in response_iterator: + result, error = response + if result: + result = result.as_numpy("text_output") + yield result[0].decode("utf-8") + else: + yield json.dumps({"status": "error", "message": error.message()}) + + except grpcclient.InferenceServerException as e: + print(f"InferenceServerException: {e}") \ No newline at end of file diff --git a/mistral/dolfo/packages/utils.py b/mistral/dolfo/packages/utils.py new file mode 100644 index 00000000..eda10338 --- /dev/null +++ b/mistral/dolfo/packages/utils.py @@ -0,0 +1,55 @@ +import subprocess +from pathlib import Path + +from constants import TENSORRT_LLM_MODEL_REPOSITORY_PATH +from huggingface_hub import snapshot_download + + +def move_all_files(src: Path, dest: Path) -> None: + """ + Moves all files from `src` to `dest` recursively. 
+ """ + for item in src.iterdir(): + dest_item = dest / item.name + if item.is_dir(): + dest_item.mkdir(parents=True, exist_ok=True) + move_all_files(item, dest_item) + else: + item.rename(dest_item) + + +def prepare_model_repository(data_dir: Path) -> None: + # Ensure the destination directory exists + dest_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "tensorrt_llm" / "1" + dest_dir.mkdir(parents=True, exist_ok=True) + + # Ensure empty version directory for `ensemble` model exists + ensemble_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "ensemble" / "1" + ensemble_dir.mkdir(parents=True, exist_ok=True) + + # Move all files and directories from data_dir to dest_dir + move_all_files(data_dir, dest_dir) + + +def download_engine(engine_repository: str, fp: Path, auth_token=None): + """ + Downloads the specified engine from Hugging Face Hub. + """ + snapshot_download( + engine_repository, + local_dir=fp, + local_dir_use_symlinks=False, + max_workers=4, + **({"use_auth_token": auth_token} if auth_token is not None else {}), + ) + + +def execute_command(command) -> None: + try: + process = subprocess.run(command, capture_output=True, text=True, check=True) + print("Standard Output:\n", process.stdout) + except FileNotFoundError: + raise FileNotFoundError( + f"The command '{command[0]}' is not found. Make sure it is installed and in your PATH." + ) + From c5bfe147d4aeb430184914d2e8763a9619bda2ef Mon Sep 17 00:00:00 2001 From: Abu Qader <48742992+aspctu@users.noreply.github.com> Date: Wed, 27 Mar 2024 20:29:22 +0000 Subject: [PATCH 3/3] rm --- dolfo/config.yaml | 43 --- dolfo/model/__init__.py | 0 dolfo/model/model.py | 139 ---------- dolfo/packages/build_engine_utils.py | 24 -- dolfo/packages/constants.py | 9 - dolfo/packages/schema.py | 155 ----------- .../ensemble/config.pbtxt | 246 ----------------- .../postprocessing/1/model.py | 205 -------------- .../postprocessing/config.pbtxt | 64 ----- .../preprocessing/1/model.py | 260 ------------------ .../preprocessing/config.pbtxt | 99 ------- .../tensorrt_llm/config.pbtxt | 208 -------------- dolfo/packages/triton_client.py | 136 --------- dolfo/packages/utils.py | 81 ------ 14 files changed, 1669 deletions(-) delete mode 100644 dolfo/config.yaml delete mode 100644 dolfo/model/__init__.py delete mode 100644 dolfo/model/model.py delete mode 100644 dolfo/packages/build_engine_utils.py delete mode 100644 dolfo/packages/constants.py delete mode 100644 dolfo/packages/schema.py delete mode 100644 dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt delete mode 100644 dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py delete mode 100644 dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt delete mode 100644 dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py delete mode 100644 dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt delete mode 100644 dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt delete mode 100644 dolfo/packages/triton_client.py delete mode 100644 dolfo/packages/utils.py diff --git a/dolfo/config.yaml b/dolfo/config.yaml deleted file mode 100644 index caa655dd..00000000 --- a/dolfo/config.yaml +++ /dev/null @@ -1,43 +0,0 @@ -apply_library_patches: true -base_image: - image: baseten/trtllm-build-server:r23.12_baseten_v0.9.0_20240305 - python_executable_path: /usr/bin/python3 -bundled_packages_dir: packages -data_dir: data -description: Generate text from a prompt with this seven billion parameter language - model. 
-environment_variables: {} -examples_filename: examples.yaml -external_data: null -external_package_dirs: [] -input_type: Any -live_reload: false -model_class_filename: model.py -model_class_name: Model -model_framework: custom -trt_llm: - serve: - engine_repository: baseten/dolphin_i6000_o1024_bs96_tp8-tllm_0.9.0.dev2024022000 - pipeline_parallel_count: 1 - tensor_parallel_count: 8 - tokenizer_repository: cognitivecomputations/dolphin-2.6-mixtral-8x7b -model_metadata: - engine_repository: baseten/dolphin_i6000_o1024_bs96_tp8-tllm_0.9.0.dev2024022000 - tags: - - text-generation - - openai-compatible -model_module_dir: model -model_name: Dolphin Mixtral TP2 - TP8 Num Workers 1 -model_type: Model -python_version: py311 -requirements: -- tritonclient[all] -- transformers -- jinja2 -resources: - accelerator: H100:8 - use_gpu: true -runtime: - num_workers: 1 - predict_concurrency: 1000 -secrets: {} diff --git a/dolfo/model/__init__.py b/dolfo/model/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/dolfo/model/model.py b/dolfo/model/model.py deleted file mode 100644 index 961e0adb..00000000 --- a/dolfo/model/model.py +++ /dev/null @@ -1,139 +0,0 @@ -import os -from itertools import count - -import build_engine_utils -from builder.types import TrussTRTLLMConfiguration -from constants import ( - GRPC_SERVICE_PORT, - HF_AUTH_KEY_CONSTANT, - HTTP_SERVICE_PORT, - TOKENIZER_KEY_CONSTANT, -) -from schema import ModelInput -from transformers import AutoTokenizer -from triton_client import TritonClient, TritonServer -from utils import execute_command, server_loaded_file_approach - -APPEND_ASSISTANT_TEMPLATE_TO_PROMPT = True -APPEND_ASSISTANT_TEMPLATE_TO_PROMPT_STR = "<|im_start|>assistant" -STOP_TOKEN = "<|im_end|>" - -class Model: - def __init__(self, data_dir, config, secrets): - self._data_dir = data_dir - self._config = config - self._secrets = secrets - self._request_id_counter = count(start=1) - self.triton_client = None - self.triton_server = None - self.tokenizer = None - self.uses_openai_api = None - - def load(self): - execute_command(["ldconfig"]) - trtllm_config = TrussTRTLLMConfiguration(**self._config.get("trt_llm", {})) - self.uses_openai_api = "openai-compatible" in self._config.get( - "model_metadata", {} - ).get("tags", []) - hf_access_token = None - if "hf_access_token" in self._secrets._base_secrets.keys(): - hf_access_token = self._secrets["hf_access_token"] - - # TODO(Abu): Move to pre-runtime - # if trtllm_config.requires_build: - # build_engine_utils.build_engine_from_config_args( - # truss_trtllm_configuration=trtllm_config, - # checkpoint_dir_path=None, - # dst=self._data_dir, - # ) - - self.triton_server = TritonServer( - grpc_port=GRPC_SERVICE_PORT, - http_port=HTTP_SERVICE_PORT, - ) - - if not trtllm_config.requires_build: - engine_repository_path = trtllm_config.serve.engine_repository - tokenizer_repository = trtllm_config.serve.tokenizer_repository - tensor_parallel_count = trtllm_config.serve.tensor_parallel_count - pipeline_parallel_count = trtllm_config.serve.pipeline_parallel_count - world_size = tensor_parallel_count * pipeline_parallel_count - else: - engine_repository_path = None - tokenizer_repository = trtllm_config.build.huggingface_ckpt_repository - tensor_parallel_count = trtllm_config.build.tensor_parallel_count - pipeline_parallel_count = trtllm_config.build.pipeline_parallel_count - world_size = tensor_parallel_count * pipeline_parallel_count - - if not server_loaded_file_approach(): - self.triton_server.create_model_repository( - 
truss_data_dir=self._data_dir, - engine_repository_path=engine_repository_path, - huggingface_auth_token=hf_access_token, - ) - - env = {} - if hf_access_token: - env[HF_AUTH_KEY_CONSTANT] = hf_access_token - env[TOKENIZER_KEY_CONSTANT] = tokenizer_repository - - self.triton_server.start( - world_size=world_size, - env=env, - ) - - self.triton_client = TritonClient( - grpc_service_port=GRPC_SERVICE_PORT, - ) - - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_repository, token=hf_access_token - ) - self.eos_token_id = self.tokenizer.eos_token_id - - async def predict(self, model_input): - if model_input.get("max_tokens") is None: - model_input["max_tokens"] = 500 - - if model_input.get("max_new_tokens") is None: - model_input["max_new_tokens"] = 500 - - model_input["request_id"] = str(os.getpid()) + str( - next(self._request_id_counter) - ) - model_input["eos_token_id"] = self.eos_token_id - messages = model_input.get("messages", []) - if "messages" in model_input: - del model_input["messages"] - prompt = model_input.get("prompt", None) - if not prompt and messages == []: - raise ValueError("Prompt or messages must be provided") - - if self.uses_openai_api and not prompt: - prompt = self.tokenizer.apply_chat_template( - messages, - tokenize=False, - ) - model_input["prompt"] = prompt - - if APPEND_ASSISTANT_TEMPLATE_TO_PROMPT: - model_input["prompt"] = f"{model_input['prompt']}{APPEND_ASSISTANT_TEMPLATE_TO_PROMPT_STR}" - - self.triton_client.start_grpc_stream() - model_input = ModelInput(**model_input) - result_iterator = self.triton_client.infer(model_input) - - async def generate(): - async for result in result_iterator: - if result != STOP_TOKEN: - yield result - else: - yield "" - - if model_input.stream: - return generate() - else: - if self.uses_openai_api: - return "".join(generate()) - else: - return {"text": "".join(generate())} diff --git a/dolfo/packages/build_engine_utils.py b/dolfo/packages/build_engine_utils.py deleted file mode 100644 index 900abae6..00000000 --- a/dolfo/packages/build_engine_utils.py +++ /dev/null @@ -1,24 +0,0 @@ -from pathlib import Path -from typing import Optional - -from builder.types import TrussTRTLLMConfiguration - - -def build_engine_from_config_args( - truss_trtllm_configuration: TrussTRTLLMConfiguration, - dst: Path, - checkpoint_dir_path: Optional[Path] = None, -): - # NOTE: These are provided by the underlying base image - # TODO(Abu): Remove this when we have a better way of handling this - from builder.main import build_engine - - build_engine( - engine_configuration=truss_trtllm_configuration, - engine_serialization_path=dst, - # If checkpoint_dir_path is provided, we'll look there for the - # weight files. If not, we will attempt to use the `huggingface_ckpt_repository` - # key in the `truss_trtllm_configuration` to download the weights. 
- checkpoint_dir_path=checkpoint_dir_path, - ) - return dst \ No newline at end of file diff --git a/dolfo/packages/constants.py b/dolfo/packages/constants.py deleted file mode 100644 index 1f19e806..00000000 --- a/dolfo/packages/constants.py +++ /dev/null @@ -1,9 +0,0 @@ -from pathlib import Path - -# If changing model repo path, please updated inside tensorrt_llm config.pbtxt as well -TENSORRT_LLM_MODEL_REPOSITORY_PATH = Path("/packages/tensorrt_llm_model_repository/") -GRPC_SERVICE_PORT = 8001 -HTTP_SERVICE_PORT = 8003 -HF_AUTH_KEY_CONSTANT = "HUGGING_FACE_HUB_TOKEN" -TOKENIZER_KEY_CONSTANT = "TRITON_TOKENIZER_REPOSITORY" -ENTRYPOINT_MODEL_NAME = "ensemble" diff --git a/dolfo/packages/schema.py b/dolfo/packages/schema.py deleted file mode 100644 index 4847fcb0..00000000 --- a/dolfo/packages/schema.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Optional - -import numpy as np -import tritonclient -import tritonclient.grpc.aio as grpcclient - - -class ModelInput: - def __init__( - self, - prompt: str, - request_id: int, - max_tokens: int = 50, - max_new_tokens: int = 50, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = 50, - beam_width: int = 1, - bad_words_list: Optional[list] = None, - stop_words_list: Optional[list] = None, - repetition_penalty: float = 1.0, - ignore_eos: bool = False, - stream: bool = True, - eos_token_id: int = None, # type: ignore - ) -> None: - self.stream = stream - self.request_id = request_id - self._prompt = prompt - self._max_tokens = max_tokens - self._beam_width = beam_width - self._bad_words_list = [""] if bad_words_list is None else bad_words_list - self._stop_words_list = [""] if stop_words_list is None else stop_words_list - self._repetition_penalty = repetition_penalty - self._eos_token_id = eos_token_id - self._ignore_eos = ignore_eos - # These variables are passed by OAI proxy but are unused - # TODO(Abu): Add support for these - self._max_new_tokens = max_new_tokens - self._temperature = temperature - self._top_p = top_p - self._top_k = top_k - - def _prepare_grpc_tensor( - self, name: str, input_data: np.ndarray - ) -> grpcclient.InferInput: - tensor = grpcclient.InferInput( - name, - input_data.shape, - tritonclient.utils.np_to_triton_dtype(input_data.dtype), - ) - tensor.set_data_from_numpy(input_data) - return tensor - - def to_tensors(self): - if self._eos_token_id is None and self._ignore_eos: - raise ValueError("eos_token_id is required when ignore_eos is True") - - prompt_data = np.array([[self._prompt]], dtype=object) - output_len_data = np.ones_like(prompt_data, dtype=np.uint32) * self._max_tokens - bad_words_data = np.array([self._bad_words_list], dtype=object) - stop_words_data = np.array([self._stop_words_list], dtype=object) - stream_data = np.array([[self.stream]], dtype=bool) - beam_width_data = np.array([[self._beam_width]], dtype=np.uint32) - repetition_penalty_data = np.array( - [[self._repetition_penalty]], dtype=np.float32 - ) - - inputs = [ - self._prepare_grpc_tensor("text_input", prompt_data), - self._prepare_grpc_tensor("max_tokens", output_len_data), - self._prepare_grpc_tensor("bad_words", bad_words_data), - self._prepare_grpc_tensor("stop_words", stop_words_data), - self._prepare_grpc_tensor("stream", stream_data), - self._prepare_grpc_tensor("beam_width", beam_width_data), - self._prepare_grpc_tensor("repetition_penalty", repetition_penalty_data), - ] - - if not self._ignore_eos: - end_id_data = np.array([[self._eos_token_id]], dtype=np.uint32) - inputs.append(self._prepare_grpc_tensor("end_id", 
end_id_data)) - - return inputs - - -# The following are duplicated from the underlying base image. -# We list them as a comment for posterity: -# -# class TRTLLMModelArchitecture(Enum): -# LLAMA: str = "llama" -# MISTRAL: str = "mistral" -# DEEPSEEK: str = "deepseek" - - -# class TRTLLMQuantizationType(Enum): -# NO_QUANT: str = "no_quant" -# WEIGHTS_ONLY_INT8: str = "weights_int8" -# WEIGHTS_KV_INT8: str = "weights_kv_int8" -# WEIGHTS_ONLY_INT4: str = "weights_int4" -# WEIGHTS_KV_INT4: str = "weights_kv_int4" -# SMOOTH_QUANT: str = "smooth_quant" -# FP8: str = "fp8" -# FP8_KV: str = "fp8_kv" - -# class TrussTRTLLMPluginConfiguration(BaseModel): -# multi_block_mode: bool = False -# paged_kv_cache: bool = True -# use_fused_mlp: bool = False - -# class TrussTRTLLMBuildConfiguration(BaseModel): -# base_model_architecture: TRTLLMModelArchitecture -# max_input_len: int -# max_output_len: int -# max_batch_size: int -# max_beam_width: int -# max_prompt_embedding_table_size: int = 0 -# huggingface_ckpt_repository: Optional[str] -# gather_all_token_logits: bool = False -# strongly_typed: bool = False -# quantization_type: TRTLLMQuantizationType = TRTLLMQuantizationType.NO_QUANT -# tensor_parallel_count: int = 1 -# pipeline_parallel_count: int = 1 -# plugin_configuration: TrussTRTLLMPluginConfiguration = TrussTRTLLMPluginConfiguration() - -# class TrussTRTLLMServingConfiguration(BaseModel): -# engine_repository: str -# tokenizer_repository: str -# tensor_parallel_count: int = 1 -# pipeline_parallel_count: int = 1 - -# class TrussTRTLLMConfiguration(BaseModel): -# serve: Optional[TrussTRTLLMServingConfiguration] = None -# build: Optional[TrussTRTLLMBuildConfiguration] = None - -# @model_validator(mode="after") -# def check_minimum_required_configuration(self): -# if not self.serve and not self.build: -# raise ValueError( -# "Either serve or build configurations must be provided" -# ) -# if self.serve and self.build: -# raise ValueError( -# "Both serve and build configurations cannot be provided" -# ) -# if self.serve is not None: -# if (self.serve.engine_repository is None) ^ (self.serve.tokenizer_repository is None): -# raise ValueError( -# "Both engine_repository and tokenizer_repository must be provided" -# ) -# return self - -# @property -# def requires_build(self): -# if self.build is not None: -# return True -# return False \ No newline at end of file diff --git a/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt deleted file mode 100644 index 618098de..00000000 --- a/dolfo/packages/tensorrt_llm_model_repository/ensemble/config.pbtxt +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. 
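# Illustrative usage of ModelInput from schema.py above: each request is flattened
# into named gRPC tensors that the "ensemble" entrypoint expects. The prompt,
# request_id and eos_token_id values here are arbitrary examples.
from schema import ModelInput

example = ModelInput(
    prompt="<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant",
    request_id="42-1",
    max_tokens=64,
    eos_token_id=32000,
    stream=True,
)
for tensor in example.to_tensors():
    # Prints text_input, max_tokens, bad_words, stop_words, stream, beam_width,
    # repetition_penalty and (because ignore_eos is False) end_id.
    print(tensor.name(), tensor.shape())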
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -name: "ensemble" -platform: "ensemble" -max_batch_size: 2048 -input [ - { - name: "text_input" - data_type: TYPE_STRING - dims: [ -1 ] - }, - { - name: "max_tokens" - data_type: TYPE_UINT32 - dims: [ -1 ] - }, - { - name: "bad_words" - data_type: TYPE_STRING - dims: [ -1 ] - }, - { - name: "stop_words" - data_type: TYPE_STRING - dims: [ -1 ] - }, - { - name: "end_id" - data_type: TYPE_UINT32 - dims: [ 1 ] - optional: true - }, - { - name: "pad_id" - data_type: TYPE_UINT32 - dims: [ 1 ] - optional: true - }, - { - name: "top_k" - data_type: TYPE_UINT32 - dims: [ 1 ] - optional: true - }, - { - name: "top_p" - data_type: TYPE_FP32 - dims: [ 1 ] - optional: true - }, - { - name: "temperature" - data_type: TYPE_FP32 - dims: [ 1 ] - optional: true - }, - { - name: "length_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - optional: true - }, - { - name: "repetition_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - optional: true - }, - { - name: "min_length" - data_type: TYPE_UINT32 - dims: [ 1 ] - optional: true - }, - { - name: "presence_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - optional: true - }, - { - name: "random_seed" - data_type: TYPE_UINT64 - dims: [ 1 ] - optional: true - }, - { - name: "beam_width" - data_type: TYPE_UINT32 - dims: [ 1 ] - optional: true - }, - { - name: "stream" - data_type: TYPE_BOOL - dims: [ 1 ] - optional: true - } -] -output [ - { - name: "text_output" - data_type: TYPE_STRING - dims: [ -1, -1 ] - } -] -ensemble_scheduling { - step [ - { - model_name: "preprocessing" - model_version: -1 - input_map { - key: "QUERY" - value: "text_input" - } - input_map { - key: "REQUEST_OUTPUT_LEN" - value: "max_tokens" - } - input_map { - key: "BAD_WORDS_DICT" - value: "bad_words" - } - input_map { - key: "STOP_WORDS_DICT" - value: "stop_words" - } - output_map { - key: "REQUEST_INPUT_LEN" - value: "_REQUEST_INPUT_LEN" - } - output_map { - key: "INPUT_ID" - value: "_INPUT_ID" - } - output_map { - key: "REQUEST_OUTPUT_LEN" - value: "_REQUEST_OUTPUT_LEN" - } - }, - { - model_name: "tensorrt_llm" - model_version: -1 - input_map { - key: "input_ids" - value: "_INPUT_ID" - } - input_map { - key: "input_lengths" - value: "_REQUEST_INPUT_LEN" - } - input_map { - key: "request_output_len" - value: "_REQUEST_OUTPUT_LEN" - } - input_map { - key: "end_id" - value: "end_id" - } - input_map { - key: "pad_id" - value: "pad_id" - } - input_map { - key: "runtime_top_k" - value: "top_k" - } - input_map { - key: "runtime_top_p" - value: "top_p" - } - input_map { - key: "temperature" - value: "temperature" - } - input_map { - key: "len_penalty" - value: "length_penalty" - } - input_map { - key: "repetition_penalty" - value: "repetition_penalty" - } - input_map { - key: "min_length" - value: "min_length" - } - input_map { - key: 
"presence_penalty" - value: "presence_penalty" - } - input_map { - key: "random_seed" - value: "random_seed" - } - input_map { - key: "beam_width" - value: "beam_width" - } - input_map { - key: "streaming" - value: "stream" - } - output_map { - key: "output_ids" - value: "_TOKENS_BATCH" - } - }, - { - model_name: "postprocessing" - model_version: -1 - input_map { - key: "TOKENS_BATCH" - value: "_TOKENS_BATCH" - } - output_map { - key: "OUTPUT" - value: "text_output" - } - } - ] -} diff --git a/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py b/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py deleted file mode 100644 index ff7ab4ad..00000000 --- a/dolfo/packages/tensorrt_llm_model_repository/postprocessing/1/model.py +++ /dev/null @@ -1,205 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import json -import os -from collections import OrderedDict - -import numpy as np -import triton_python_backend_utils as pb_utils -from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer - -# https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/strip.rs#L8 -INVALID_UNICODE_CHAR = "�" - - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. - Parameters - ---------- - args : dict - Both keys and values are strings. 
The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # Parse model configs - model_config = json.loads(args["model_config"]) - # NOTE: Keep this in sync with the truss model.py variable - tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"] - tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] - - if tokenizer_type == "t5": - self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") - elif tokenizer_type == "auto": - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_dir, padding_side="left" - ) - elif tokenizer_type == "llama": - self.tokenizer = LlamaTokenizer.from_pretrained( - tokenizer_dir, legacy=False, padding_side="left" - ) - else: - raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") - self.tokenizer.pad_token = self.tokenizer.eos_token - - # Parse model output configs - output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") - # Convert Triton types to numpy types - self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) - - self.state_dict = OrderedDict() - # TODO(pankaj) This should come from the batch size - self.cache_size = 2048 - - def execute(self, requests): - """`execute` must be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference is requested - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse. - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - # Every Python backend must iterate over everyone of the requests - # and create a pb_utils.InferenceResponse for each of them. - for idx, request in enumerate(requests): - # Get request ID - request_id = request.request_id() - - # Get input tensors - tokens_batch = ( - pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH") - .as_numpy() - .flatten() - ) - - if len(tokens_batch) == 0: - continue - - # Postprocess output data - prev_token = self._get_var(request_id, "prev_token") - token_buffer = self._get_var(request_id, "token_buffer") - token_buffer = token_buffer if token_buffer is not None else [] - current_tokens = np.concatenate( - (np.array(token_buffer, dtype=int), tokens_batch), dtype=int - ) - current_tokens_decoded = self.tokenizer.decode(current_tokens) - - if len(current_tokens_decoded) == 0: - responses.append(pb_utils.InferenceResponse()) - continue - - if current_tokens_decoded[-1] == INVALID_UNICODE_CHAR: - # If the last token is invalid, we need to keep it in the buffer - # for the next request to see if this is a multi-token unicode - # character. 
- self._store_var(request_id, "token_buffer", current_tokens) - responses.append(pb_utils.InferenceResponse()) - continue - - if prev_token is None: - delta = current_tokens_decoded - else: - # TODO(pankaj) Figure out how to make tokenizer.decode not - # ignore initial whitespace so we can avoid this hack. - # Get string with and without previous token and diff. This hack - # is needed because tokenizer.decode strips initial whitespace. - old_string = self.tokenizer.decode(prev_token) - with_prev_token = np.concatenate((prev_token, current_tokens)) - new_string = self.tokenizer.decode(with_prev_token) - delta = self._compute_delta(old_string, new_string) - - # The previous token is the last character of the decoded sequence - # which includes the multi-token unicode character. - self._store_var(request_id, "prev_token", current_tokens) - self._store_var(request_id, "token_buffer", None) - - # Create output tensor - output_tensor = pb_utils.Tensor( - "OUTPUT", np.array([delta]).astype(self.output_dtype) - ) - inference_response = pb_utils.InferenceResponse( - output_tensors=[output_tensor] - ) - responses.append(inference_response) - - return responses - - def finalize(self): - print("Cleaning up...") - - def _store_var(self, request_id, var_name, var): - if request_id in self.state_dict: - self.state_dict[request_id][var_name] = var - self.state_dict.move_to_end(request_id) - else: - if len(self.state_dict) > self.cache_size: - self.state_dict.popitem(last=False) - self.state_dict[request_id] = {"prev_token": None, "token_buffer": None} - self.state_dict[request_id][var_name] = var - - def _get_var(self, request_id, var_name): - if request_id in self.state_dict: - return self.state_dict[request_id][var_name] - return None - - def _compute_delta(self, prev_str, new_str): - delta = "".join( - [ - char - for index, char in enumerate(new_str) - if index >= len(prev_str) or char != prev_str[index] - ] - ) - return delta - - def _postprocessing(self, tokens): - decoded_tokens = self.tokenizer.decode(tokens) - return decoded_tokens \ No newline at end of file diff --git a/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt deleted file mode 100644 index 854ef960..00000000 --- a/dolfo/packages/tensorrt_llm_model_repository/postprocessing/config.pbtxt +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. 
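# Illustrative sketch of the streaming-detokenization trick used by
# postprocessing/1/model.py above: decode the sequence with and without the
# previously emitted tokens and keep only the suffix, so leading whitespace that
# tokenizer.decode would otherwise strip is preserved. Works with any tokenizer
# exposing decode(); the token ids passed in are arbitrary.
def compute_delta(prev_str: str, new_str: str) -> str:
    # Same logic as _compute_delta: keep characters past the end of prev_str
    # (or that differ from it).
    return "".join(
        char
        for index, char in enumerate(new_str)
        if index >= len(prev_str) or char != prev_str[index]
    )

def next_chunk(tokenizer, prev_token_ids, new_token_ids):
    old_string = tokenizer.decode(prev_token_ids)
    new_string = tokenizer.decode(list(prev_token_ids) + list(new_token_ids))
    return compute_delta(old_string, new_string)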
IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -name: "postprocessing" -backend: "python" -max_batch_size: 2048 -input [ - { - name: "TOKENS_BATCH" - data_type: TYPE_INT32 - dims: [ -1, -1 ] - } -] -output [ - { - name: "OUTPUT" - data_type: TYPE_STRING - dims: [ -1, -1 ] - } -] - -parameters { - key: "tokenizer_dir" - value: { - string_value: "NousResearch/Llama-2-7b-hf" - } -} - -parameters { - key: "tokenizer_type" - value: { - string_value: "auto" - } -} - -instance_group [ - { - count: 1 - kind: KIND_CPU - } -] diff --git a/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py b/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py deleted file mode 100644 index fa4dcc2c..00000000 --- a/dolfo/packages/tensorrt_llm_model_repository/preprocessing/1/model.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import csv -import json -import os -from typing import List - -import numpy as np -import triton_python_backend_utils as pb_utils -from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer - - -class TritonPythonModel: - """Your Python model must use the same class name. Every Python model - that is created must have "TritonPythonModel" as the class name. - """ - - def initialize(self, args): - """`initialize` is called only once when the model is being loaded. - Implementing `initialize` function is optional. This function allows - the model to initialize any state associated with this model. 
- Parameters - ---------- - args : dict - Both keys and values are strings. The dictionary keys and values are: - * model_config: A JSON string containing the model configuration - * model_instance_kind: A string containing model instance kind - * model_instance_device_id: A string containing model instance device ID - * model_repository: Model repository path - * model_version: Model version - * model_name: Model name - """ - # Parse model configs - model_config = json.loads(args["model_config"]) - # NOTE: Keep this in sync with the truss model.py variable - tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"] - tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] - self.add_special_tokens = model_config["parameters"].get( - "add_special_tokens", {"string_value": "false"} - )["string_value"].lower() in ["true", "1", "t", "y", "yes"] - - if tokenizer_type == "t5": - self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") - elif tokenizer_type == "auto": - self.tokenizer = AutoTokenizer.from_pretrained( - tokenizer_dir, padding_side="left" - ) - elif tokenizer_type == "llama": - self.tokenizer = LlamaTokenizer.from_pretrained( - tokenizer_dir, legacy=False, padding_side="left" - ) - else: - raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") - self.tokenizer.pad_token = self.tokenizer.eos_token - - self.pad_id = self.tokenizer.encode( - self.tokenizer.pad_token, add_special_tokens=False - )[0] - - # Parse model output configs and convert Triton types to numpy types - input_names = [ - "INPUT_ID", - "REQUEST_INPUT_LEN", - "BAD_WORDS_IDS", - "STOP_WORDS_IDS", - ] - for input_name in input_names: - setattr( - self, - input_name.lower() + "_dtype", - pb_utils.triton_string_to_numpy( - pb_utils.get_output_config_by_name(model_config, input_name)[ - "data_type" - ] - ), - ) - - def execute(self, requests): - """`execute` must be implemented in every Python model. `execute` - function receives a list of pb_utils.InferenceRequest as the only - argument. This function is called when an inference is requested - for this model. Depending on the batching configuration (e.g. Dynamic - Batching) used, `requests` may contain multiple requests. Every - Python model, must create one pb_utils.InferenceResponse for every - pb_utils.InferenceRequest in `requests`. If there is an error, you can - set the error argument when creating a pb_utils.InferenceResponse. - Parameters - ---------- - requests : list - A list of pb_utils.InferenceRequest - Returns - ------- - list - A list of pb_utils.InferenceResponse. The length of this list must - be the same as `requests` - """ - - responses = [] - - # Every Python backend must iterate over everyone of the requests - # and create a pb_utils.InferenceResponse for each of them. - for idx, request in enumerate(requests): - # Get input tensors - query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() - request_output_len = pb_utils.get_input_tensor_by_name( - request, "REQUEST_OUTPUT_LEN" - ).as_numpy() - - bad_words_dict = pb_utils.get_input_tensor_by_name( - request, "BAD_WORDS_DICT" - ).as_numpy() - stop_words_dict = pb_utils.get_input_tensor_by_name( - request, "STOP_WORDS_DICT" - ).as_numpy() - - # Preprocessing input data. - input_id, request_input_len = self._create_request(query) - bad_words = self._to_word_list_format(bad_words_dict) - stop_words = self._to_word_list_format(stop_words_dict) - - # Create output tensors. 
You need pb_utils.Tensor - # objects to create pb_utils.InferenceResponse. - input_id_tensor = pb_utils.Tensor( - "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) - ) - request_input_len_tensor = pb_utils.Tensor( - "REQUEST_INPUT_LEN", - np.array(request_input_len).astype(self.request_input_len_dtype), - ) - request_output_len_tensor = pb_utils.Tensor( - "REQUEST_OUTPUT_LEN", request_output_len - ) - bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) - stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) - - # Create InferenceResponse. You can set an error here in case - # there was a problem with handling this inference request. - # Below is an example of how you can set errors in inference - # response: - # - # pb_utils.InferenceResponse( - # output_tensors=..., TritonError("An error occurred")) - inference_response = pb_utils.InferenceResponse( - output_tensors=[ - input_id_tensor, - bad_words_ids_tensor, - stop_words_ids_tensor, - request_input_len_tensor, - request_output_len_tensor, - ] - ) - responses.append(inference_response) - - # You should return a list of pb_utils.InferenceResponse. Length - # of this list must match the length of `requests` list. - return responses - - def finalize(self): - """`finalize` is called only once when the model is being unloaded. - Implementing `finalize` function is optional. This function allows - the model to perform any necessary clean ups before exit. - """ - print("Cleaning up...") - - def _create_request(self, query): - """ - query : batch string (2D numpy array) - """ - start_ids = [ - np.array( - self.tokenizer.encode( - s[0].decode(), add_special_tokens=self.add_special_tokens - ) - ).astype(int) - for s in query - ] - start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) - - max_len = 0 - for seq in start_ids: - max_len = max(max_len, seq.shape[0]) - start_ids = np.stack( - [ - np.pad( - seq, - (0, max_len - seq.shape[0]), - "constant", - constant_values=(0, self.pad_id), - ) - for seq in start_ids - ] - ) - - return start_ids, start_lengths - - def _to_word_list_format(self, word_dict: List[List[str]]): - """ - format of word_dict - len(word_dict) should be same to batch_size - word_dict[i] means the words for batch i - len(word_dict[i]) must be 1, which means it only contains 1 string - This string can contains several sentences and split by ",". - For example, if word_dict[2] = " I am happy, I am sad", then this function will return - the ids for two short sentences " I am happy" and " I am sad". 
- """ - assert self.tokenizer is not None, "need to set tokenizer" - - flat_ids = [] - offsets = [] - for word_dict_item in word_dict: - item_flat_ids = [] - item_offsets = [] - - if isinstance(word_dict_item[0], bytes): - word_dict_item = [word_dict_item[0].decode()] - - words = list(csv.reader(word_dict_item))[0] - for word in words: - ids = self.tokenizer.encode(word) - - if len(ids) == 0: - continue - - item_flat_ids += ids - item_offsets.append(len(ids)) - - flat_ids.append(np.array(item_flat_ids)) - offsets.append(np.cumsum(np.array(item_offsets))) - - pad_to = max(1, max(len(ids) for ids in flat_ids)) - - for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): - flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) - offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) - - return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) diff --git a/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt deleted file mode 100644 index 1fb88012..00000000 --- a/dolfo/packages/tensorrt_llm_model_repository/preprocessing/config.pbtxt +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -name: "preprocessing" -backend: "python" -max_batch_size: 2048 -input [ - { - name: "QUERY" - data_type: TYPE_STRING - dims: [ -1 ] - }, - { - name: "BAD_WORDS_DICT" - data_type: TYPE_STRING - dims: [ -1 ] - }, - { - name: "STOP_WORDS_DICT" - data_type: TYPE_STRING - dims: [ -1 ] - }, - { - name: "REQUEST_OUTPUT_LEN" - data_type: TYPE_UINT32 - dims: [ -1 ] - } -] -output [ - { - name: "INPUT_ID" - data_type: TYPE_INT32 - dims: [ -1 ] - }, - { - name: "REQUEST_INPUT_LEN" - data_type: TYPE_INT32 - dims: [ 1 ] - }, - { - name: "BAD_WORDS_IDS" - data_type: TYPE_INT32 - dims: [ 2, -1 ] - }, - { - name: "STOP_WORDS_IDS" - data_type: TYPE_INT32 - dims: [ 2, -1 ] - }, - { - name: "REQUEST_OUTPUT_LEN" - data_type: TYPE_UINT32 - dims: [ -1 ] - } -] - -parameters { - key: "tokenizer_dir" - value: { - string_value: "NousResearch/Llama-2-7b-hf" - } -} - -parameters { - key: "tokenizer_type" - value: { - string_value: "auto" - } -} - -instance_group [ - { - count: 1 - kind: KIND_CPU - } -] diff --git a/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt b/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt deleted file mode 100644 index 75cb6718..00000000 --- a/dolfo/packages/tensorrt_llm_model_repository/tensorrt_llm/config.pbtxt +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * Neither the name of NVIDIA CORPORATION nor the names of its -# contributors may be used to endorse or promote products derived -# from this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY -# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY -# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -name: "tensorrt_llm" -backend: "tensorrtllm" -max_batch_size: 2048 - -model_transaction_policy { - decoupled: True -} - -input [ - { - name: "input_ids" - data_type: TYPE_INT32 - dims: [ -1 ] - }, - { - name: "input_lengths" - data_type: TYPE_INT32 - dims: [ 1 ] - reshape: { shape: [ ] } - }, - { - name: "request_output_len" - data_type: TYPE_UINT32 - dims: [ 1 ] - }, - { - name: "end_id" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "pad_id" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "beam_width" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "temperature" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "runtime_top_k" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "runtime_top_p" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "len_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "repetition_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "min_length" - data_type: TYPE_UINT32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "presence_penalty" - data_type: TYPE_FP32 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "random_seed" - data_type: TYPE_UINT64 - dims: [ 1 ] - reshape: { shape: [ ] } - optional: true - }, - { - name: "stop" - data_type: TYPE_BOOL - dims: [ 1 ] - optional: true - }, - { - name: "streaming" - data_type: TYPE_BOOL - dims: [ 1 ] - optional: true - } -] -output [ - { - name: "output_ids" - data_type: TYPE_INT32 - dims: [ -1, -1 ] - } -] -instance_group [ - { - count: 1 - kind : KIND_CPU - } -] -parameters: { - key: "max_beam_width" - value: { - string_value: "1" - } -} -parameters: { - key: "FORCE_CPU_ONLY_INPUT_TENSORS" - value: { - string_value: "no" - } -} -parameters: { - key: "gpt_model_type" - value: { - string_value: "inflight_fused_batching" - } -} -parameters: { - key: "gpt_model_path" - value: { - string_value: "/packages/tensorrt_llm_model_repository/tensorrt_llm/1" - } -} -parameters: { - key: "max_tokens_in_paged_kv_cache" - value: { - string_value: "100000" - } -} -parameters: { - key: "batch_scheduler_policy" - value: { - string_value: "max_utilization" - } -} -parameters: { - key: "kv_cache_free_gpu_mem_fraction" - value: { - string_value: "0.9" - } -} -parameters: { - key: "max_num_sequences" - value: { - string_value: "2048" - } -} -parameters: { - key: "enable_trt_overlap" - value: { - string_value: "False" - } -} diff --git a/dolfo/packages/triton_client.py b/dolfo/packages/triton_client.py deleted file mode 100644 index 80f7a6a3..00000000 --- a/dolfo/packages/triton_client.py +++ /dev/null @@ -1,136 +0,0 @@ -import json -import os -import subprocess -import time -from pathlib import Path -from typing import AsyncGenerator, Optional - -import tritonclient.grpc.aio as grpcclient -import tritonclient.http as httpclient -from constants import ( - ENTRYPOINT_MODEL_NAME, - GRPC_SERVICE_PORT, - TENSORRT_LLM_MODEL_REPOSITORY_PATH, -) -from schema import ModelInput -from utils import download_engine, prepare_model_repository - - -class TritonServer: - def __init__(self, grpc_port: int = 8001, http_port: int = 8003): - self.grpc_port = grpc_port - self.http_port = http_port - 
self._server_process = None - - def create_model_repository( - self, - truss_data_dir: Path, - engine_repository_path: Optional[str] = None, - huggingface_auth_token: Optional[str] = None, - ) -> None: - if engine_repository_path: - download_engine( - engine_repository=engine_repository_path, - fp=truss_data_dir, - auth_token=huggingface_auth_token, - ) - prepare_model_repository(truss_data_dir) - return - - def start(self, world_size: int = 1, env: dict = {}) -> None: - mpirun_command = ["mpirun", "--allow-run-as-root"] - mpi_commands = [] - for i in range(world_size): - mpi_command = [ - "-n", - "1", - "tritonserver", - f"--model-repository={TENSORRT_LLM_MODEL_REPOSITORY_PATH}", - f"--grpc-port={str(self.grpc_port)}", - f"--http-port={str(self.http_port)}", - "--disable-auto-complete-config", - f"--backend-config=python,shm-region-prefix-name=prefix{i}_", - ":", - ] - - mpi_commands.extend(mpi_command) - command = mpirun_command + mpi_commands - - self._server_process = subprocess.Popen( # type: ignore - command, - env={**os.environ, **env}, - ) - while not self.is_alive and not self.is_ready: - time.sleep(2) - return - - def stop(self): - if self._server_process: - if self.is_ready: - self._server_process.kill() - self._server_process = None - return - - @property - def is_alive(self) -> bool: - try: - http_client = httpclient.InferenceServerClient( - url=f"localhost:{self.http_port}", verbose=False - ) - return http_client.is_server_live() - except ConnectionRefusedError: - return False - - @property - def is_ready(self) -> bool: - try: - http_client = httpclient.InferenceServerClient( - url=f"localhost:{self.http_port}", verbose=False - ) - return http_client.is_model_ready(model_name=ENTRYPOINT_MODEL_NAME) - except ConnectionRefusedError: - return False - - -class TritonClient: - def __init__(self, grpc_service_port: int = GRPC_SERVICE_PORT): - self.grpc_service_port = grpc_service_port - self._grpc_client = None - - def start_grpc_stream(self) -> grpcclient.InferenceServerClient: - if self._grpc_client: - return self._grpc_client - - self._grpc_client = grpcclient.InferenceServerClient( - url=f"localhost:{self.grpc_service_port}", verbose=False - ) - return self._grpc_client - - async def infer( - self, model_input: ModelInput, model_name="ensemble" - ) -> AsyncGenerator[str, None]: - grpc_client_instance = self.start_grpc_stream() - inputs = model_input.to_tensors() - - async def input_generator(): - yield { - "model_name": model_name, - "inputs": inputs, - "request_id": model_input.request_id, - } - - response_iterator = grpc_client_instance.stream_infer( - inputs_iterator=input_generator(), - ) - - try: - async for response in response_iterator: - result, error = response - if result: - result = result.as_numpy("text_output") - yield result[0].decode("utf-8") - else: - yield json.dumps({"status": "error", "message": error.message()}) - - except grpcclient.InferenceServerException as e: - print(f"InferenceServerException: {e}") \ No newline at end of file diff --git a/dolfo/packages/utils.py b/dolfo/packages/utils.py deleted file mode 100644 index ee3554c4..00000000 --- a/dolfo/packages/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import subprocess -from pathlib import Path - -from constants import TENSORRT_LLM_MODEL_REPOSITORY_PATH, GRPC_SERVICE_PORT, HTTP_SERVICE_PORT -from huggingface_hub import snapshot_download - -import socket -def move_all_files(src: Path, dest: Path) -> None: - """ - Moves all files from `src` to `dest` recursively. 
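# Illustrative sketch of the launch command TritonServer.start() assembles above:
# one tritonserver rank per world_size (tensor_parallel_count * pipeline_parallel_count),
# all joined under a single mpirun invocation. The world_size and ports below are
# example values.
def build_launch_command(world_size=2, grpc_port=8001, http_port=8003):
    command = ["mpirun", "--allow-run-as-root"]
    for i in range(world_size):
        command += [
            "-n", "1", "tritonserver",
            "--model-repository=/packages/tensorrt_llm_model_repository/",
            f"--grpc-port={grpc_port}",
            f"--http-port={http_port}",
            "--disable-auto-complete-config",
            f"--backend-config=python,shm-region-prefix-name=prefix{i}_",
            ":",
        ]
    return command

# build_launch_command(2)[:5] == ['mpirun', '--allow-run-as-root', '-n', '1', 'tritonserver']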
- """ - for item in src.iterdir(): - dest_item = dest / item.name - if item.is_dir(): - dest_item.mkdir(parents=True, exist_ok=True) - move_all_files(item, dest_item) - else: - item.rename(dest_item) - - -def prepare_model_repository(data_dir: Path) -> None: - # Ensure the destination directory exists - dest_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "tensorrt_llm" / "1" - dest_dir.mkdir(parents=True, exist_ok=True) - - # Ensure empty version directory for `ensemble` model exists - ensemble_dir = TENSORRT_LLM_MODEL_REPOSITORY_PATH / "ensemble" / "1" - ensemble_dir.mkdir(parents=True, exist_ok=True) - - # Move all files and directories from data_dir to dest_dir - move_all_files(data_dir, dest_dir) - - -def download_engine(engine_repository: str, fp: Path, auth_token=None): - """ - Downloads the specified engine from Hugging Face Hub. - """ - snapshot_download( - engine_repository, - local_dir=fp, - local_dir_use_symlinks=False, - max_workers=4, - **({"use_auth_token": auth_token} if auth_token is not None else {}), - ) - - -def execute_command(command) -> None: - try: - process = subprocess.run(command, capture_output=True, text=True, check=True) - print("Standard Output:\n", process.stdout) - except FileNotFoundError: - raise FileNotFoundError( - f"The command '{command[0]}' is not found. Make sure it is installed and in your PATH." - ) - - - -def server_loaded(): - def port_is_available(port): - available = False - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - try: - sock.bind(("0.0.0.0", port)) - available = True - except: - pass - return available - - return not port_is_available(GRPC_SERVICE_PORT) or not port_is_available( - HTTP_SERVICE_PORT - ) - - -def server_loaded_file_approach(): - FILE_LOC = "/packages/worker.txt" - if Path(FILE_LOC).exists(): - return True - else: - Path(FILE_LOC).touch() - return False - \ No newline at end of file