From e1455e0cb17e0c09ba6a5c049b12b98bc849c65f Mon Sep 17 00:00:00 2001 From: Matthew Kotila Date: Fri, 9 Aug 2024 16:42:39 -0700 Subject: [PATCH] TensorRT-LLM in-process benchmarking support (#35) * Add tensorrtllm_engine option to service-kind and update testing (#700) (#762) * Add tensorrtllm_engine option to service-kind and update testing * Add output format check for tensorrtllm_engine Co-authored-by: Elias Bermudez <6505145+debermudez@users.noreply.github.com> * Support input payload generation for tensorrtllm engine (#767) * Add functionality for async requests and output retrieval with Triton C API (#25) * Support 1-d array data in profile exporter (#28) * support array of data in profile exporter * add some tests * run formatting * fix pre-commit * remove duplicate argparser arguments * Fix Triton C API mode missing infer requested output datatype bug --------- Co-authored-by: Matthew Kotila * Support profile data parsing for tensorrtllm engine service kind (#33) * support parsing tensorrtllm engine profile response * add test * refactor the test * update types and names * fix pre-commit * run PA with triton c api * more clean up on the tests * fix codeql * address feedback * Add functionality to continue benchmarking in Triton C API mode if server logging support is disabled (#34) --------- Co-authored-by: Hyunjae Woo <107147848+nv-hwoo@users.noreply.github.com> Co-authored-by: Elias Bermudez <6505145+debermudez@users.noreply.github.com> --- .../genai_perf/llm_inputs/llm_inputs.py | 97 +++++++ genai-perf/genai_perf/metrics/llm_metrics.py | 2 +- genai-perf/genai_perf/parser.py | 17 +- .../llm_profile_data_parser.py | 16 ++ .../profile_data_parser.py | 2 + genai-perf/genai_perf/wrapper.py | 10 +- genai-perf/tests/test_cli.py | 7 +- genai-perf/tests/test_llm_inputs.py | 102 +++++++ .../tests/test_llm_profile_data_parser.py | 269 +++++++++++++++++- .../triton_c_api/CMakeLists.txt | 4 +- .../triton_c_api/alloc_payload.h | 67 +++++ .../triton_c_api/c_api_infer_results.h | 63 +++- .../triton_c_api/response_output.h | 48 ++++ .../triton_c_api/triton_c_api_backend.cc | 94 +++++- .../triton_c_api/triton_c_api_backend.h | 28 +- .../triton_c_api/triton_loader.cc | 264 ++++++++++++----- .../triton_c_api/triton_loader.h | 40 ++- src/command_line_parser.cc | 13 +- src/mock_profile_data_exporter.h | 13 + src/perf_analyzer.cc | 7 + src/perf_analyzer_exception.h | 6 +- src/perf_utils.cc | 36 +++ src/perf_utils.h | 6 +- src/profile_data_exporter.cc | 110 +++++-- src/profile_data_exporter.h | 6 + src/request_record.h | 2 +- src/test_command_line_parser.cc | 4 +- src/test_profile_data_collector.cc | 24 +- src/test_profile_data_exporter.cc | 141 ++++++++- 29 files changed, 1335 insertions(+), 163 deletions(-) create mode 100644 src/client_backend/triton_c_api/alloc_payload.h create mode 100644 src/client_backend/triton_c_api/response_output.h diff --git a/genai-perf/genai_perf/llm_inputs/llm_inputs.py b/genai-perf/genai_perf/llm_inputs/llm_inputs.py index 057c3356..2fb0d9df 100644 --- a/genai-perf/genai_perf/llm_inputs/llm_inputs.py +++ b/genai-perf/genai_perf/llm_inputs/llm_inputs.py @@ -53,6 +53,7 @@ class OutputFormat(Enum): RANKINGS = auto() TENSORRTLLM = auto() VLLM = auto() + TENSORRTLLM_ENGINE = auto() def to_lowercase(self): return self.name.lower() @@ -216,6 +217,7 @@ def create_llm_inputs( json_in_pa_format = cls._convert_generic_json_to_output_format( output_format, + tokenizer, generic_dataset_json, add_model_name, add_stream, @@ -688,6 +690,7 @@ def _encode_images_in_input_dataset(cls, 
input_file_dataset: Dict) -> Dict: def _convert_generic_json_to_output_format( cls, output_format: OutputFormat, + tokenizer: Tokenizer, generic_dataset: Dict, add_model_name: bool, add_stream: bool, @@ -763,6 +766,16 @@ def _convert_generic_json_to_output_format( model_name, model_selection_strategy, ) + elif output_format == OutputFormat.TENSORRTLLM_ENGINE: + output_json = cls._convert_generic_json_to_trtllm_engine_format( + generic_dataset, + tokenizer, + add_stream, + extra_inputs, + output_tokens_mean, + output_tokens_stddev, + output_tokens_deterministic, + ) else: raise GenAIPerfException( f"Output format {output_format} is not currently supported" @@ -1010,6 +1023,28 @@ def _convert_generic_json_to_trtllm_format( return pa_json + @classmethod + def _convert_generic_json_to_trtllm_engine_format( + cls, + dataset_json: Dict, + tokenizer: Tokenizer, + add_stream: bool, + extra_inputs: Dict, + output_tokens_mean: int, + output_tokens_stddev: int, + output_tokens_deterministic: bool, + ) -> Dict: + pa_json = cls._populate_trtllm_engine_output_json( + dataset_json, + tokenizer, + add_stream, + extra_inputs, + output_tokens_mean, + output_tokens_stddev, + output_tokens_deterministic, + ) + return pa_json + @classmethod def _write_json_to_file(cls, json_in_pa_format: Dict, output_dir: Path) -> None: filename = output_dir / DEFAULT_INPUT_DATA_JSON @@ -1261,6 +1296,43 @@ def _populate_trtllm_output_json( return pa_json + @classmethod + def _populate_trtllm_engine_output_json( + cls, + dataset_json: Dict, + tokenizer: Tokenizer, + add_stream: bool, + extra_inputs: Dict, + output_tokens_mean: int, + output_tokens_stddev: int, + output_tokens_deterministic: bool, + ) -> Dict: + pa_json = cls._create_empty_trtllm_pa_json() + + for index, entry in enumerate(dataset_json["rows"]): + token_ids = tokenizer.encode(entry["text_input"]) + pa_json["data"].append( + { + "input_ids": { + "content": token_ids, + "shape": [len(token_ids)], + }, + "input_lengths": [len(token_ids)], + "request_output_len": [cls.DEFAULT_TENSORRTLLM_MAX_TOKENS], + } + ) + + pa_json = cls._add_optional_tags_to_trtllm_engine_json( + pa_json, + index, + add_stream, + extra_inputs, + output_tokens_mean, + output_tokens_stddev, + output_tokens_deterministic, + ) + return pa_json + @classmethod def _create_empty_openai_pa_json(cls) -> Dict: empty_pa_json = deepcopy(cls.EMPTY_JSON_IN_OPENAI_PA_FORMAT) @@ -1477,6 +1549,31 @@ def _add_optional_tags_to_trtllm_json( return pa_json + @classmethod + def _add_optional_tags_to_trtllm_engine_json( + cls, + pa_json: Dict, + index: int, + add_stream: bool, + extra_inputs: Dict, + output_tokens_mean: int, + output_tokens_stddev: int, + output_tokens_deterministic: bool, + ) -> Dict: + row = pa_json["data"][index] + if add_stream: + row["streaming"] = [True] + if output_tokens_mean != cls.DEFAULT_OUTPUT_TOKENS_MEAN: + num_tokens = int(random.gauss(output_tokens_mean, output_tokens_stddev)) + row["request_output_len"] = [num_tokens] + if output_tokens_deterministic: + row["min_length"] = [num_tokens] + + for key, value in extra_inputs.items(): + row[key] = [value] + + return pa_json + @classmethod def _add_required_tags_to_trtllm_json( cls, diff --git a/genai-perf/genai_perf/metrics/llm_metrics.py b/genai-perf/genai_perf/metrics/llm_metrics.py index 13dff8a6..7dd00ba7 100755 --- a/genai-perf/genai_perf/metrics/llm_metrics.py +++ b/genai-perf/genai_perf/metrics/llm_metrics.py @@ -54,7 +54,7 @@ def __init__( time_to_first_tokens: List[int] = [], inter_token_latencies: List[int] = [], 
output_token_throughputs: List[float] = [], - output_token_throughputs_per_request: List[int] = [], + output_token_throughputs_per_request: List[float] = [], output_sequence_lengths: List[int] = [], input_sequence_lengths: List[int] = [], chunked_inter_token_latencies: List[List[int]] = [[]], diff --git a/genai-perf/genai_perf/parser.py b/genai-perf/genai_perf/parser.py index 60994f29..c135313d 100644 --- a/genai-perf/genai_perf/parser.py +++ b/genai-perf/genai_perf/parser.py @@ -176,6 +176,9 @@ def _check_conditional_args( args = _convert_str_to_enum_entry(args, "backend", OutputFormat) args.output_format = args.backend + if args.service_kind == "tensorrtllm_engine": + args.output_format = OutputFormat.TENSORRTLLM_ENGINE + # Output token distribution checks if args.output_tokens_mean == LlmInputs.DEFAULT_OUTPUT_TOKENS_MEAN: if args.output_tokens_stddev != LlmInputs.DEFAULT_OUTPUT_TOKENS_STDDEV: @@ -187,10 +190,11 @@ def _check_conditional_args( "The --output-tokens-mean option is required when using --output-tokens-mean-deterministic." ) - if args.service_kind != "triton": + if args.service_kind not in ["triton", "tensorrtllm_engine"]: if args.output_tokens_mean_deterministic: parser.error( - "The --output-tokens-mean-deterministic option is only supported with the Triton service-kind." + "The --output-tokens-mean-deterministic option is only supported " + "with the Triton and TensorRT-LLM Engine service-kind." ) _check_conditional_args_embeddings_rankings(parser, args) @@ -267,6 +271,8 @@ def _set_artifact_paths(args: argparse.Namespace) -> argparse.Namespace: name += [f"{args.service_kind}-{args.endpoint_type}"] elif args.service_kind == "triton": name += [f"{args.service_kind}-{args.backend.to_lowercase()}"] + elif args.service_kind == "tensorrtllm_engine": + name += [f"{args.service_kind}"] else: raise ValueError(f"Unknown service kind '{args.service_kind}'.") @@ -578,7 +584,7 @@ def _add_endpoint_args(parser): endpoint_group.add_argument( "--service-kind", type=str, - choices=["triton", "openai"], + choices=["triton", "openai", "tensorrtllm_engine"], default="triton", required=False, help="The kind of service perf_analyzer will " @@ -625,9 +631,8 @@ def _add_output_args(parser): default=Path("profile_export.json"), help="The path where the perf_analyzer profile export will be " "generated. By default, the profile export will be to profile_export.json. " - "The genai-perf files will be exported to _genai_perf.json and " - "_genai_perf.csv. " - "For example, if the profile export file is profile_export.json, the genai-perf CSV file will be " + "The genai-perf file will be exported to _genai_perf.csv. 
" + "For example, if the profile export file is profile_export.json, the genai-perf file will be " "exported to profile_export_genai_perf.csv.", ) diff --git a/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py b/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py index 9136ab1f..b877b578 100755 --- a/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py +++ b/genai-perf/genai_perf/profile_data_parser/llm_profile_data_parser.py @@ -224,6 +224,8 @@ def _get_input_token_count(self, req_inputs: dict) -> int: """Deserialize the request input and return tokenized inputs.""" if self._service_kind == "triton": input_text = req_inputs["text_input"] + elif self._service_kind == "triton_c_api": + return len(req_inputs["input_ids"]) # no tokenizer required elif self._service_kind == "openai": input_text = self._get_openai_input_text(req_inputs) else: @@ -252,6 +254,9 @@ def _get_output_token_counts( """Return response-level token counts and total token count.""" if self._service_kind == "triton": output_texts = self._get_triton_output_tokens(res_outputs) + elif self._service_kind == "triton_c_api": + # No tokenizer is need to get the token counts. + return self._get_tensorrtllm_engine_token_counts(res_outputs) elif self._service_kind == "openai": output_texts = self._get_openai_output_tokens(res_outputs) else: @@ -263,6 +268,17 @@ def _get_output_token_counts( output_token_counts = list(map(len, output_tokens)) return output_token_counts, full_text_token_count + def _get_tensorrtllm_engine_token_counts( + self, res_outputs: List[Dict] + ) -> Tuple[List[int], int]: + token_ids = [] + for r in res_outputs: + if isinstance(r["output_ids"], list): + token_ids += r["output_ids"] + else: + token_ids.append(r["output_ids"]) + return token_ids, len(token_ids) + def _get_triton_output_tokens(self, res_outputs: List[Dict]) -> List[str]: """Return a list of Triton response texts.""" return [r["text_output"] for r in res_outputs] diff --git a/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py b/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py index 74eb48a2..245afb2c 100755 --- a/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py +++ b/genai-perf/genai_perf/profile_data_parser/profile_data_parser.py @@ -98,6 +98,8 @@ def _get_profile_metadata(self, data: dict) -> None: elif self._service_kind == "triton": self._response_format = ResponseFormat.TRITON + elif self._service_kind == "triton_c_api": + pass # ignore else: raise ValueError(f"Unknown service kind: {self._service_kind}") diff --git a/genai-perf/genai_perf/wrapper.py b/genai-perf/genai_perf/wrapper.py index 76ef3e32..c7b27a6b 100644 --- a/genai-perf/genai_perf/wrapper.py +++ b/genai-perf/genai_perf/wrapper.py @@ -110,6 +110,9 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s f"--input-data", f"{args.artifact_dir / DEFAULT_INPUT_DATA_JSON}", ] + cmd += Profiler.add_protocol_args(args) + cmd += Profiler.add_inference_load_args(args) + for arg, value in vars(args).items(): if arg in skip_args: pass @@ -122,6 +125,10 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s cmd += [f"-{arg}"] else: cmd += [f"--{arg}"] + # GAP needs to call PA using triton_c_api service kind when running + # against tensorrtllm engine. 
+ elif arg == "service_kind" and value == "tensorrtllm_engine": + cmd += ["--service-kind", "triton_c_api", "--streaming"] else: if len(arg) == 1: cmd += [f"-{arg}", f"{value}"] @@ -129,9 +136,6 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s arg = utils.convert_option_name(arg) cmd += [f"--{arg}", f"{value}"] - cmd += Profiler.add_protocol_args(args) - cmd += Profiler.add_inference_load_args(args) - if extra_args is not None: for arg in extra_args: cmd += [f"{arg}"] diff --git a/genai-perf/tests/test_cli.py b/genai-perf/tests/test_cli.py index dcf637ed..db120f7f 100644 --- a/genai-perf/tests/test_cli.py +++ b/genai-perf/tests/test_cli.py @@ -203,6 +203,10 @@ def test_help_version_arguments_output_and_exit( (["--request-rate", "9.0"], {"request_rate": 9.0}), (["-s", "99.5"], {"stability_percentage": 99.5}), (["--service-kind", "triton"], {"service_kind": "triton"}), + ( + ["--service-kind", "tensorrtllm_engine"], + {"service_kind": "tensorrtllm_engine"}, + ), ( ["--service-kind", "openai", "--endpoint-type", "chat"], {"service_kind": "openai", "endpoint": "v1/chat/completions"}, @@ -530,7 +534,7 @@ def test_unrecognized_arg(self, monkeypatch, capsys): "100", "--output-tokens-mean-deterministic", ], - "The --output-tokens-mean-deterministic option is only supported with the Triton service-kind", + "The --output-tokens-mean-deterministic option is only supported with the Triton and TensorRT-LLM Engine service-kind", ), ( [ @@ -642,6 +646,7 @@ def test_conditional_errors(self, args, expected_output, monkeypatch, capsys): OutputFormat.TENSORRTLLM, ), (["--service-kind", "triton", "--backend", "vllm"], OutputFormat.VLLM), + (["--service-kind", "tensorrtllm_engine"], OutputFormat.TENSORRTLLM_ENGINE), ], ) def test_inferred_output_format(self, monkeypatch, args, expected_format): diff --git a/genai-perf/tests/test_llm_inputs.py b/genai-perf/tests/test_llm_inputs.py index 028e7284..e2acd7ae 100644 --- a/genai-perf/tests/test_llm_inputs.py +++ b/genai-perf/tests/test_llm_inputs.py @@ -554,6 +554,107 @@ def test_llm_inputs_with_defaults(self, default_configured_url): # else: # assert False, f"Unsupported output format: {output_format}" + @pytest.mark.parametrize( + "generic_json, add_stream, output_tokens_mean, output_tokens_deterministic, expected_json", + [ + ( + # generic_json + { + "rows": [ + {"text_input": "test input one"}, + {"text_input": "test input two"}, + ] + }, + False, + -1, + False, + # expected_json + { + "data": [ + { + "input_ids": { + "content": [1243, 1881, 697], + "shape": [3], + }, + "input_lengths": [3], + "request_output_len": [ + LlmInputs.DEFAULT_TENSORRTLLM_MAX_TOKENS + ], + }, + { + "input_ids": { + "content": [1243, 1881, 1023], + "shape": [3], + }, + "input_lengths": [3], + "request_output_len": [ + LlmInputs.DEFAULT_TENSORRTLLM_MAX_TOKENS + ], + }, + ], + }, + ), + ( + # generic_json + { + "rows": [ + {"text_input": "test input one"}, + {"text_input": "test input two"}, + ] + }, + True, + 999, + True, + # expected_json + { + "data": [ + { + "input_ids": { + "content": [1243, 1881, 697], + "shape": [3], + }, + "input_lengths": [3], + "request_output_len": [999], + "min_length": [999], + "streaming": [True], + }, + { + "input_ids": { + "content": [1243, 1881, 1023], + "shape": [3], + }, + "input_lengths": [3], + "request_output_len": [999], + "min_length": [999], + "streaming": [True], + }, + ], + }, + ), + ], + ) + def test_generic_json_to_trtllm_engine_format( + self, + generic_json, + add_stream, + output_tokens_mean, + 
output_tokens_deterministic, + expected_json, + ) -> None: + trtllm_json = LlmInputs._convert_generic_json_to_output_format( + output_format=OutputFormat.TENSORRTLLM_ENGINE, + tokenizer=get_tokenizer(DEFAULT_TOKENIZER), + generic_dataset=generic_json, + add_model_name=False, + add_stream=add_stream, + extra_inputs={}, + output_tokens_mean=output_tokens_mean, + output_tokens_stddev=0, + output_tokens_deterministic=output_tokens_deterministic, + ) + + assert trtllm_json == expected_json + def test_add_image_inputs_openai_vision(self) -> None: generic_json = { "rows": [ @@ -606,6 +707,7 @@ def test_add_image_inputs_openai_vision(self) -> None: OutputFormat.OPENAI_VISION, OutputFormat.VLLM, OutputFormat.TENSORRTLLM, + OutputFormat.TENSORRTLLM_ENGINE, ], ) def test_get_input_dataset_from_synthetic( diff --git a/genai-perf/tests/test_llm_profile_data_parser.py b/genai-perf/tests/test_llm_profile_data_parser.py index ceb9c8ea..88b1b98d 100644 --- a/genai-perf/tests/test_llm_profile_data_parser.py +++ b/genai-perf/tests/test_llm_profile_data_parser.py @@ -27,11 +27,12 @@ import json from io import StringIO from pathlib import Path -from typing import Any, List, Union +from typing import Any, List, Union, cast import numpy as np import pytest from genai_perf.metrics import LLMMetrics +from genai_perf.metrics.statistics import Statistics from genai_perf.profile_data_parser import LLMProfileDataParser from genai_perf.tokenizer import DEFAULT_TOKENIZER, get_tokenizer @@ -41,6 +42,28 @@ def ns_to_sec(ns: int) -> Union[int, float]: return ns / 1e9 +def check_statistics(s1: Statistics, s2: Statistics) -> None: + s1_dict = s1.stats_dict + s2_dict = s2.stats_dict + for metric in s1_dict.keys(): + for stat_name, value in s1_dict[metric].items(): + if stat_name != "unit": + assert s2_dict[metric][stat_name] == pytest.approx(value) + + +def check_llm_metrics(m1: LLMMetrics, m2: LLMMetrics) -> None: + assert m1.request_latencies == m2.request_latencies + assert m1.request_throughputs == pytest.approx(m2.request_throughputs) + assert m1.time_to_first_tokens == m2.time_to_first_tokens + assert m1.inter_token_latencies == m2.inter_token_latencies + assert m1.output_token_throughputs_per_request == pytest.approx( + m2.output_token_throughputs_per_request + ) + assert m1.output_token_throughputs == pytest.approx(m2.output_token_throughputs) + assert m1.output_sequence_lengths == m2.output_sequence_lengths + assert m1.input_sequence_lengths == m2.input_sequence_lengths + + class TestLLMProfileDataParser: @pytest.fixture def mock_read_write(self, monkeypatch: pytest.MonkeyPatch) -> List[str]: @@ -74,6 +97,9 @@ def write(self: Any, content: str) -> int: elif filename == "openai_vlm_profile_export.json": tmp_file = StringIO(json.dumps(self.openai_vlm_profile_data)) return tmp_file + elif filename == "tensorrtllm_engine_profile_export.json": + tmp_file = StringIO(json.dumps(self.tensorrtllm_engine_profile_data)) + return tmp_file elif filename == "empty_profile_export.json": tmp_file = StringIO(json.dumps(self.empty_profile_data)) return tmp_file @@ -413,6 +439,103 @@ def test_openai_vlm_profile_data(self, mock_read_write: pytest.MonkeyPatch) -> N with pytest.raises(KeyError): pd.get_statistics(infer_mode="concurrency", load_level="40") + @pytest.mark.parametrize( + "infer_mode, load_level, expected_metrics", + [ + ( + "concurrency", + "10", + { + "request_latencies": [7, 9], + "request_throughputs": [1 / ns_to_sec(5)], + "time_to_first_tokens": [2, 2], + "inter_token_latencies": [2, 4], + 
"output_token_throughputs_per_request": [ + 3 / ns_to_sec(7), + 1 / ns_to_sec(3), + ], + "output_token_throughputs": [3 / ns_to_sec(5)], + "output_sequence_lengths": [3, 3], + "input_sequence_lengths": [3, 4], + }, + ), + ( + "request_rate", + "2.0", + { + "request_latencies": [13, 8], + "request_throughputs": [2 / ns_to_sec(15)], + "time_to_first_tokens": [2, 3], + "inter_token_latencies": [4, 2], + "output_token_throughputs_per_request": [ + 4 / ns_to_sec(13), + 3 / ns_to_sec(8), + ], + "output_token_throughputs": [7 / ns_to_sec(15)], + "output_sequence_lengths": [4, 3], + "input_sequence_lengths": [3, 4], + }, + ), + ], + ) + def test_tensorrtllm_engine_llm_profile_data( + self, + mock_read_write: pytest.MonkeyPatch, + infer_mode, + load_level, + expected_metrics, + ) -> None: + """Collect LLM metrics from profile export data and check values. + + Metrics + * request_latencies + - experiment 1: [8 - 1, 11 - 2] = [7, 9] + - experiment 2: [18 - 5, 11 -3] = [13, 8] + * request_throughputs + - experiment 1: [2/(11 - 1)] = [1/5] + - experiment 2: [2/(18 - 3)] = [2/15] + * time to first tokens + - experiment 1: [3 - 1, 4 - 2] = [2, 2] + - experiment 2: [7 - 5, 6 - 3] = [2, 3] + * inter token latencies + - experiment 1: [((8 - 1) - 2)/(3 - 1), ((11 - 2) - 2)/(3 - 1)] + : [2.5, 3.5] + : [2, 4] # rounded + - experiment 2: [((18 - 5) - 2)/(4 - 1), ((11 - 3) - 3)/(3 - 1)] + : [11/3, 2.5] + : [4, 2] # rounded + * output token throughputs per request + - experiment 1: [3/(8 - 1), 3/(11 - 2)] = [3/7, 1/3] + - experiment 2: [4/(18 - 5), 3/(11 - 3)] = [4/13, 3/8] + * output token throughputs + - experiment 1: [(3 + 3)/(11 - 1)] = [3/5] + - experiment 2: [(4 + 3)/(18 - 3)] = [7/15] + * output sequence lengths + - experiment 1: [3, 3] + - experiment 2: [4, 3] + * input sequence lengths + - experiment 1: [3, 4] + - experiment 2: [3, 4] + """ + tokenizer = get_tokenizer(DEFAULT_TOKENIZER) + pd = LLMProfileDataParser( + filename=Path("tensorrtllm_engine_profile_export.json"), + tokenizer=tokenizer, + ) + + statistics = pd.get_statistics(infer_mode=infer_mode, load_level=load_level) + metrics = cast(LLMMetrics, statistics.metrics) + + expected_metrics = LLMMetrics(**expected_metrics) + expected_statistics = Statistics(expected_metrics) + + check_llm_metrics(metrics, expected_metrics) + check_statistics(statistics, expected_statistics) + + # check non-existing profile data + with pytest.raises(KeyError): + pd.get_statistics(infer_mode="concurrency", load_level="30") + def test_merged_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: """Test merging the multiple sse response.""" res_timestamps = [0, 1, 2, 3] @@ -787,3 +910,147 @@ def test_non_sse_response(self, mock_read_write: pytest.MonkeyPatch) -> None: }, ], } + + tensorrtllm_engine_profile_data = { + "service_kind": "triton_c_api", + "endpoint": "", + "experiments": [ + { + "experiment": { + "mode": "concurrency", + "value": 10, + }, + "requests": [ + { + "timestamp": 1, + "request_inputs": { + "streaming": True, + "request_output_len": 3, + "min_length": 3, + "input_lengths": 3, + "input_ids": [ + 111, + 222, + 333, + ], + }, + "response_timestamps": [3, 5, 8], + "response_outputs": [ + { + "output_log_probs": [0, 0], + "output_ids": 123, + }, + { + "output_log_probs": [0, 0], + "output_ids": 456, + }, + { + "output_ids": 789, + }, + ], + }, + { + "timestamp": 2, + "request_inputs": { + "streaming": True, + "request_output_len": 3, + "min_length": 3, + "input_lengths": 4, + "input_ids": [ + 111, + 222, + 333, + 444, + ], + }, + 
"response_timestamps": [4, 7, 11], + "response_outputs": [ + { + "output_log_probs": [0, 0], + "output_ids": 123, + }, + { + "output_log_probs": [0, 0], + "output_ids": 456, + }, + { + "output_log_probs": [0, 0], + "output_ids": 789, + }, + ], + }, + ], + }, + { + "experiment": { + "mode": "request_rate", + "value": 2.0, + }, + "requests": [ + { + "timestamp": 5, + "request_inputs": { + "streaming": True, + "request_output_len": 4, + "min_length": 4, + "input_lengths": 3, + "input_ids": [ + 111, + 222, + 333, + ], + }, + "response_timestamps": [7, 8, 13, 18], + "response_outputs": [ + { + "output_log_probs": [0, 0], + "output_ids": 123, + }, + { + "output_log_probs": [0, 0], + "output_ids": 456, + }, + { + "output_log_probs": [0, 0], + "output_ids": 789, + }, + { + "output_log_probs": [0, 0], + "output_ids": 1011, + }, + ], + }, + { + "timestamp": 3, + "request_inputs": { + "streaming": True, + "request_output_len": 3, + "min_length": 3, + "input_lengths": 4, + "input_ids": [ + 111, + 222, + 333, + 444, + ], + }, + "response_timestamps": [6, 8, 11], + "response_outputs": [ + { + "output_log_probs": [0, 0], + "output_ids": 123, + }, + { + "output_log_probs": [0, 0], + "output_ids": 456, + }, + { + "output_log_probs": [0, 0], + "output_ids": 789, + }, + ], + }, + ], + }, + ], + } diff --git a/src/client_backend/triton_c_api/CMakeLists.txt b/src/client_backend/triton_c_api/CMakeLists.txt index e3312aec..0b954b52 100644 --- a/src/client_backend/triton_c_api/CMakeLists.txt +++ b/src/client_backend/triton_c_api/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions @@ -52,6 +52,8 @@ set( triton_loader.h c_api_infer_results.h scoped_defer.h + response_output.h + alloc_payload.h ) add_library( diff --git a/src/client_backend/triton_c_api/alloc_payload.h b/src/client_backend/triton_c_api/alloc_payload.h new file mode 100644 index 00000000..19a0ff81 --- /dev/null +++ b/src/client_backend/triton_c_api/alloc_payload.h @@ -0,0 +1,67 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include + +#include +#include +#include + +namespace triton { namespace perfanalyzer { namespace clientbackend { +namespace tritoncapi { + +struct AllocPayload { + struct OutputInfo { + enum Kind { BINARY, SHM }; + + Kind kind_; + void* base_; + uint64_t byte_size_; + TRITONSERVER_MemoryType memory_type_; + int64_t device_id_; + + // For shared memory + OutputInfo( + void* base, uint64_t byte_size, TRITONSERVER_MemoryType memory_type, + int64_t device_id) + : kind_(SHM), base_(base), byte_size_(byte_size), + memory_type_(memory_type), device_id_(device_id) + { + } + }; + + ~AllocPayload() + { + for (auto it : output_map_) { + delete it.second; + } + } + + std::unordered_map output_map_; +}; + +}}}} // namespace triton::perfanalyzer::clientbackend::tritoncapi diff --git a/src/client_backend/triton_c_api/c_api_infer_results.h b/src/client_backend/triton_c_api/c_api_infer_results.h index 440a94c0..abc10450 100644 --- a/src/client_backend/triton_c_api/c_api_infer_results.h +++ b/src/client_backend/triton_c_api/c_api_infer_results.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -25,7 +25,14 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#pragma once +#include +#include +#include +#include +#include + #include "common.h" +#include "response_output.h" namespace tc = triton::client; @@ -38,9 +45,12 @@ namespace tritoncapi { class InferResult { public: static void Create( - InferResult** infer_result, const tc::Error& err, const std::string& id) + InferResult** infer_result, const tc::Error& err, const std::string& id, + std::unordered_map&& outputs, + bool is_final_response, bool is_null_response) { - *infer_result = reinterpret_cast(new InferResult(err, id)); + *infer_result = reinterpret_cast(new InferResult( + err, id, std::move(outputs), is_final_response, is_null_response)); } tc::Error Id(std::string* id) const @@ -50,13 +60,56 @@ class InferResult { } tc::Error RequestStatus() const { return status_; } + tc::Error RawData( + const std::string& output_name, const uint8_t** buf, + size_t* byte_size) const + { + auto it = outputs_.find(output_name); + if (it != outputs_.end()) { + *buf = reinterpret_cast(it->second.base); + *byte_size = it->second.byte_size; + } else { + return tc::Error( + "The response does not contain results for output name '" + + output_name + "'"); + } + + return tc::Error::Success; + } + + tc::Error IsFinalResponse(bool* is_final_response) const + { + if (is_final_response == nullptr) { + return tc::Error("is_final_response cannot be nullptr"); + } + *is_final_response = is_final_response_; + return tc::Error::Success; + } + + tc::Error IsNullResponse(bool* is_null_response) const + { + if (is_null_response == nullptr) { + return tc::Error("is_null_response cannot be nullptr"); + } + *is_null_response = is_null_response_; + return tc::Error::Success; + } + private: - InferResult(const tc::Error& err, const std::string& id) - : status_(err), request_id_(id) + InferResult( + const tc::Error& err, const std::string& id, + std::unordered_map&& outputs, + bool is_final_response, bool is_null_response) + : status_(err), request_id_(id), outputs_(std::move(outputs)), + is_final_response_(is_final_response), + is_null_response_(is_null_response) { } std::string request_id_; tc::Error status_; + std::unordered_map outputs_{}; + bool is_final_response_{true}; + bool is_null_response_{false}; }; }}}} // namespace triton::perfanalyzer::clientbackend::tritoncapi diff --git a/src/client_backend/triton_c_api/response_output.h b/src/client_backend/triton_c_api/response_output.h new file mode 100644 index 00000000..26b123d9 --- /dev/null +++ b/src/client_backend/triton_c_api/response_output.h @@ -0,0 +1,48 @@ +// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include + +#include +#include + +namespace triton { namespace perfanalyzer { namespace clientbackend { +namespace tritoncapi { + +struct ResponseOutput { + const char* name{}; + TRITONSERVER_DataType datatype{}; + const int64_t* shape{}; + uint64_t dim_count{}; + const void* base{}; + size_t byte_size{}; + TRITONSERVER_MemoryType memory_type{}; + int64_t memory_type_id{}; + void* userp{}; +}; + +}}}} // namespace triton::perfanalyzer::clientbackend::tritoncapi diff --git a/src/client_backend/triton_c_api/triton_c_api_backend.cc b/src/client_backend/triton_c_api/triton_c_api_backend.cc index e97f1ea8..96dd67cb 100644 --- a/src/client_backend/triton_c_api/triton_c_api_backend.cc +++ b/src/client_backend/triton_c_api/triton_c_api_backend.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -121,6 +121,66 @@ TritonCApiClientBackend::Infer( return Error::Success; } +Error +TritonCApiClientBackend::AsyncInfer( + OnCompleteFn callback, const InferOptions& options, + const std::vector& inputs, + const std::vector& outputs) +{ + auto wrapped_callback = [callback](capi::InferResult* client_result) { + cb::InferResult* result = new TritonCApiInferResult(client_result); + callback(result); + }; + + std::vector triton_inputs; + ParseInferInputToTriton(inputs, &triton_inputs); + + std::vector triton_outputs; + ParseInferRequestedOutputToTriton(outputs, &triton_outputs); + + tc::InferOptions triton_options(options.model_name_); + ParseInferOptionsToTriton(options, &triton_options); + + RETURN_IF_ERROR(triton_loader_->AsyncInfer( + wrapped_callback, triton_options, triton_inputs, triton_outputs, + enable_stats_)); + + return Error::Success; +} + +Error +TritonCApiClientBackend::StartStream(OnCompleteFn callback, bool enable_stats) +{ + stream_callback_ = callback; + enable_stats_ = enable_stats; + return Error::Success; +} + +Error +TritonCApiClientBackend::AsyncStreamInfer( + const InferOptions& options, const std::vector& inputs, + const std::vector& outputs) +{ + auto wrapped_callback = [this](capi::InferResult* client_result) { + cb::InferResult* result = new TritonCApiInferResult(client_result); + stream_callback_(result); + }; + + std::vector triton_inputs; + ParseInferInputToTriton(inputs, &triton_inputs); + + std::vector triton_outputs; + ParseInferRequestedOutputToTriton(outputs, &triton_outputs); + + tc::InferOptions triton_options(options.model_name_); + ParseInferOptionsToTriton(options, &triton_options); + + 
RETURN_IF_ERROR(triton_loader_->AsyncInfer( + wrapped_callback, triton_options, triton_inputs, triton_outputs, + enable_stats_)); + + return Error::Success; +} Error TritonCApiClientBackend::ClientInferStat(InferStat* infer_stat) @@ -324,6 +384,13 @@ TritonCApiInferInput::SetSharedMemory( return Error::Success; } +Error +TritonCApiInferInput::RawData(const uint8_t** buf, size_t* byte_size) +{ + RETURN_IF_TRITON_ERROR(input_->RawData(buf, byte_size)); + return Error::Success; +} + TritonCApiInferInput::TritonCApiInferInput( const std::string& name, const std::string& datatype) : InferInput(BackendKind::TRITON_C_API, name, datatype) @@ -339,7 +406,7 @@ TritonCApiInferRequestedOutput::Create( const size_t class_count, const std::string& datatype) { TritonCApiInferRequestedOutput* local_infer_output = - new TritonCApiInferRequestedOutput(name); + new TritonCApiInferRequestedOutput(name, datatype); tc::InferRequestedOutput* triton_infer_output; RETURN_IF_TRITON_ERROR(tc::InferRequestedOutput::Create( @@ -360,8 +427,8 @@ TritonCApiInferRequestedOutput::SetSharedMemory( } TritonCApiInferRequestedOutput::TritonCApiInferRequestedOutput( - const std::string& name) - : InferRequestedOutput(BackendKind::TRITON_C_API, name) + const std::string& name, const std::string& datatype) + : InferRequestedOutput(BackendKind::TRITON_C_API, name, datatype) { } @@ -391,9 +458,22 @@ TritonCApiInferResult::RawData( const std::string& output_name, const uint8_t** buf, size_t* byte_size) const { - return Error( - "Output retrieval is not currently supported for Triton C API client " - "backend"); + RETURN_IF_TRITON_ERROR(result_->RawData(output_name, buf, byte_size)); + return Error::Success; +} + +Error +TritonCApiInferResult::IsFinalResponse(bool* is_final_response) const +{ + RETURN_IF_TRITON_ERROR(result_->IsFinalResponse(is_final_response)); + return Error::Success; +} + +Error +TritonCApiInferResult::IsNullResponse(bool* is_null_response) const +{ + RETURN_IF_TRITON_ERROR(result_->IsNullResponse(is_null_response)); + return Error::Success; } //============================================================================== diff --git a/src/client_backend/triton_c_api/triton_c_api_backend.h b/src/client_backend/triton_c_api/triton_c_api_backend.h index 0f9f5def..33d519f8 100644 --- a/src/client_backend/triton_c_api/triton_c_api_backend.h +++ b/src/client_backend/triton_c_api/triton_c_api_backend.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -98,6 +98,20 @@ class TritonCApiClientBackend : public ClientBackend { const std::vector& inputs, const std::vector& outputs) override; + /// See ClientBackend::AsyncInfer() + Error AsyncInfer( + OnCompleteFn callback, const InferOptions& options, + const std::vector& inputs, + const std::vector& outputs) override; + + /// See ClientBackend::StartStream() + Error StartStream(OnCompleteFn callback, bool enable_stats) override; + + /// See ClientBackend::AsyncStreamInfer() + Error AsyncStreamInfer( + const InferOptions& options, const std::vector& inputs, + const std::vector& outputs) override; + /// See ClientBackend::ClientInferStat() Error ClientInferStat(InferStat* infer_stat) override; @@ -140,6 +154,8 @@ class TritonCApiClientBackend : public ClientBackend { void ParseInferStat( const tc::InferStat& triton_infer_stat, InferStat* infer_stat); TritonLoader* triton_loader_; + OnCompleteFn stream_callback_{}; + bool enable_stats_{true}; }; //============================================================== @@ -171,6 +187,9 @@ class TritonCApiInferInput : public InferInput { Error SetSharedMemory( const std::string& name, size_t byte_size, size_t offset = 0) override; + /// See InferInput::RawData() + Error RawData(const uint8_t** buf, size_t* byte_size) override; + private: explicit TritonCApiInferInput( const std::string& name, const std::string& datatype); @@ -196,7 +215,8 @@ class TritonCApiInferRequestedOutput : public InferRequestedOutput { const std::string& name, size_t byte_size, size_t offset = 0) override; private: - explicit TritonCApiInferRequestedOutput(const std::string& name); + explicit TritonCApiInferRequestedOutput( + const std::string& name, const std::string& datatype); std::unique_ptr output_; }; @@ -216,6 +236,10 @@ class TritonCApiInferResult : public cb::InferResult { Error RawData( const std::string& output_name, const uint8_t** buf, size_t* byte_size) const override; + /// See InferResult::IsFinalResponse() + Error IsFinalResponse(bool* is_final_response) const override; + /// See InferResult::IsNullResponse() + Error IsNullResponse(bool* is_null_response) const override; private: std::unique_ptr result_; diff --git a/src/client_backend/triton_c_api/triton_loader.cc b/src/client_backend/triton_c_api/triton_loader.cc index 35f7657f..27e1eda6 100644 --- a/src/client_backend/triton_c_api/triton_loader.cc +++ b/src/client_backend/triton_c_api/triton_loader.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -46,36 +46,6 @@ namespace triton { namespace perfanalyzer { namespace clientbackend { namespace tritoncapi { namespace { -struct AllocPayload { - struct OutputInfo { - enum Kind { BINARY, SHM }; - - Kind kind_; - void* base_; - uint64_t byte_size_; - TRITONSERVER_MemoryType memory_type_; - int64_t device_id_; - - // For shared memory - OutputInfo( - void* base, uint64_t byte_size, TRITONSERVER_MemoryType memory_type, - int64_t device_id) - : kind_(SHM), base_(base), byte_size_(byte_size), - memory_type_(memory_type), device_id_(device_id) - { - } - }; - - ~AllocPayload() - { - for (auto it : output_map_) { - delete it.second; - } - } - - std::unordered_map output_map_; -}; - bool helper_verbose = false; /// Helper function for allocating memory TRITONSERVER_Error* @@ -272,7 +242,8 @@ TritonLoader::StartTriton() // Check API version. uint32_t api_version_major, api_version_minor; REPORT_TRITONSERVER_ERROR( - api_version_fn_(&api_version_major, &api_version_minor)); + api_version_fn_(&api_version_major, &api_version_minor), + "unable to get api version"); if ((TRITONSERVER_API_VERSION_MAJOR != api_version_major) || (TRITONSERVER_API_VERSION_MINOR > api_version_minor)) { std::stringstream sstream; @@ -294,10 +265,10 @@ TritonLoader::StartTriton() RETURN_IF_TRITONSERVER_ERROR( set_cuda_memory_pool_byte_size_(server_options, 0, 1073741824), "setting cuda memory pool byte size failed."); - RETURN_IF_TRITONSERVER_ERROR( + REPORT_TRITONSERVER_ERROR( set_log_verbose_fn_(server_options, verbose_level_), "setting verbose logging level"); - RETURN_IF_TRITONSERVER_ERROR( + REPORT_TRITONSERVER_ERROR( set_log_info_fn_(server_options, verbose_), "setting if log verbose level is true"); RETURN_IF_TRITONSERVER_ERROR( @@ -365,6 +336,30 @@ TritonLoader::StartTriton() "deleting status metadata"); } + // Create the allocator that will be used to allocate buffers for + // the result tensors. 
+ RETURN_IF_TRITONSERVER_ERROR( + GetSingleton()->response_allocator_new_fn_( + &allocator_, + reinterpret_cast< + TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator, + const char* tensor_name, size_t byte_size, + TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id, void* userp, + void** buffer, void** buffer_userp, + TRITONSERVER_MemoryType* + actual_memory_type, + int64_t* actual_memory_type_id)>( + ResponseAlloc), + reinterpret_cast< + TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator, + void* buffer, void* buffer_userp, + size_t byte_size, + TRITONSERVER_MemoryType memory_type, + int64_t memory_type_id)>(ResponseRelease), + nullptr /* start_fn */), + "creating response allocator"); + return Error::Success; } @@ -918,16 +913,15 @@ TritonLoader::Infer( return Error("Server is not ready and/or requested model is not loaded"); } - TRITONSERVER_ResponseAllocator* allocator = nullptr; TRITONSERVER_InferenceRequest* irequest = nullptr; TRITONSERVER_InferenceResponse* completed_response = nullptr; tc::RequestTimers timer; timer.Reset(); timer.CaptureTimestamp(tc::RequestTimers::Kind::REQUEST_START); - RETURN_IF_ERROR(InitializeRequest(options, outputs, &allocator, &irequest)); - ScopedDefer error_handler([&error, &completed_response, &allocator, this] { - error = CleanUp(completed_response, allocator); + RETURN_IF_ERROR(InitializeRequest(options, outputs, &irequest)); + ScopedDefer error_handler([&error, &completed_response, this] { + error = CleanUp(completed_response); }); RETURN_IF_ERROR(AddInputs(inputs, irequest)); RETURN_IF_ERROR(AddOutputs(outputs, irequest)); @@ -966,7 +960,7 @@ TritonLoader::Infer( std::future completed = p->get_future(); RETURN_IF_TRITONSERVER_ERROR( inference_request_set_response_callback_fn_( - irequest, allocator, &alloc_payload /* response_allocator_userp */, + irequest, allocator_, &alloc_payload /* response_allocator_userp */, InferResponseComplete, reinterpret_cast(p)), "setting response callback"); RETURN_IF_TRITONSERVER_ERROR( @@ -990,26 +984,174 @@ TritonLoader::Infer( std::cerr << "Failed to update context stat: " << err << std::endl; } - InferResult::Create(result, err, id); + std::unordered_map response_outputs{}; + GetOutputs(completed_response, response_outputs); + + // Synchronous mode requests only ever have one response, which is "final" + bool is_final_response{true}; + bool is_null_response{completed_response == nullptr}; + + InferResult::Create( + result, err, id, std::move(response_outputs), is_final_response, + is_null_response); - // CleanUp the response allocators + // CleanUp the response error_handler.Complete(); return error; } Error -TritonLoader::CleanUp( - TRITONSERVER_InferenceResponse* completed_response, - TRITONSERVER_ResponseAllocator* allocator) +TritonLoader::GetOutputs( + TRITONSERVER_InferenceResponse* response, + std::unordered_map& outputs) +{ + uint32_t count{}; + RETURN_IF_TRITONSERVER_ERROR( + inference_response_output_count_fn_(response, &count), + "inference_response_output_count_fn_ error"); + + for (uint32_t index{0}; index < count; index++) { + const char* name{}; + TRITONSERVER_DataType datatype{}; + const int64_t* shape{}; + uint64_t dim_count{}; + const void* base{}; + size_t byte_size{}; + TRITONSERVER_MemoryType memory_type{}; + int64_t memory_type_id{}; + void* userp{}; + + RETURN_IF_TRITONSERVER_ERROR( + inference_response_output_fn_( + response, index, &name, &datatype, &shape, &dim_count, &base, + &byte_size, &memory_type, &memory_type_id, &userp), + 
"inference_response_output_fn_ error"); + + outputs.emplace( + name, ResponseOutput{ + name, datatype, shape, dim_count, base, byte_size, + memory_type, memory_type_id, userp}); + } + + return Error::Success; +} + +void +InferResponseCompleteAsyncNonMember( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp) +{ + TritonLoader::GetSingleton()->InferResponseCompleteAsync( + response, flags, + reinterpret_cast(userp)); +} + +void +TritonLoader::InferResponseCompleteAsync( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + AsyncRequestInfo* async_request_info) +{ + REPORT_TRITONSERVER_ERROR( + inference_response_error_fn_(response), + "unable to get inference response error"); + + if (async_request_info->enable_stats) { + tc::RequestTimers timer{*async_request_info->timer}; + + timer.CaptureTimestamp(tc::RequestTimers::Kind::RECV_START); + timer.CaptureTimestamp(tc::RequestTimers::Kind::RECV_END); + timer.CaptureTimestamp(tc::RequestTimers::Kind::REQUEST_END); + + { + std::lock_guard lock(update_infer_stat_mutex_); + tc::Error err{UpdateInferStat(timer)}; + if (!err.IsOk()) { + std::cerr << "Failed to update context stat: " << err << std::endl; + } + } + } + + std::unordered_map outputs{}; + Error err{GetOutputs(response, outputs)}; + if (!err.IsOk()) { + std::cerr << "Failed to get outputs: " << err << std::endl; + } + + bool is_final_response{flags == TRITONSERVER_RESPONSE_COMPLETE_FINAL}; + bool is_null_response{response == nullptr}; + + InferResult* infer_result{}; + InferResult::Create( + &infer_result, tc::Error::Success, async_request_info->request_id, + std::move(outputs), is_final_response, is_null_response); + + async_request_info->callback(infer_result); + + if (is_final_response) { + delete async_request_info; + } + + CleanUp(response); +} + +Error +TritonLoader::AsyncInfer( + OnCompleteFn callback, const tc::InferOptions& options, + const std::vector& inputs, + const std::vector& outputs, + bool enable_stats) +{ + Error error = Error::Success; + if (!ServerIsReady() || !ModelIsLoaded()) { + return Error("Server is not ready and/or requested model is not loaded"); + } + + TRITONSERVER_InferenceRequest* irequest = nullptr; + TRITONSERVER_InferenceResponse* completed_response = nullptr; + std::shared_ptr timer{ + std::make_shared()}; + timer->Reset(); + timer->CaptureTimestamp(tc::RequestTimers::Kind::REQUEST_START); + + RETURN_IF_ERROR(InitializeRequest(options, outputs, &irequest)); + RETURN_IF_ERROR(AddInputs(inputs, irequest)); + RETURN_IF_ERROR(AddOutputs(outputs, irequest)); + + const char* cid = nullptr; + RETURN_IF_TRITONSERVER_ERROR( + request_id_fn_(irequest, &cid), "Failed to get request id"); + + // Perform inference... 
+ timer->CaptureTimestamp(tc::RequestTimers::Kind::SEND_START); + AsyncRequestInfo* async_request_info{new AsyncRequestInfo}; + async_request_info->alloc_payload = std::make_unique(); + async_request_info->request_id = cid; + async_request_info->timer = timer; + async_request_info->callback = callback; + async_request_info->enable_stats = enable_stats; + RETURN_IF_TRITONSERVER_ERROR( + inference_request_set_response_callback_fn_( + irequest, allocator_, + async_request_info->alloc_payload + .get() /* response_allocator_userp */, + InferResponseCompleteAsyncNonMember, async_request_info), + "setting response callback"); + RETURN_IF_TRITONSERVER_ERROR( + infer_async_fn_((server_).get(), irequest, nullptr /* trace */), + "running inference"); + timer->CaptureTimestamp(tc::RequestTimers::Kind::SEND_END); + + return error; +} + +Error +TritonLoader::CleanUp(TRITONSERVER_InferenceResponse* completed_response) { TRITONSERVER_Error* response_err = nullptr; if (completed_response != nullptr) { response_err = inference_response_delete_fn_(completed_response); + RETURN_IF_TRITONSERVER_ERROR(response_err, "deleting inference response"); } - TRITONSERVER_Error* allocator_err = response_allocator_delete_fn_(allocator); - RETURN_IF_TRITONSERVER_ERROR(response_err, "deleting inference response"); - RETURN_IF_TRITONSERVER_ERROR(allocator_err, "deleting response allocator"); return Error::Success; } @@ -1017,33 +1159,8 @@ Error TritonLoader::InitializeRequest( const tc::InferOptions& options, const std::vector& outputs, - TRITONSERVER_ResponseAllocator** allocator, TRITONSERVER_InferenceRequest** irequest) { - // Create the allocator that will be used to allocate buffers for - // the result tensors. - RETURN_IF_TRITONSERVER_ERROR( - GetSingleton()->response_allocator_new_fn_( - allocator, - reinterpret_cast< - TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator, - const char* tensor_name, size_t byte_size, - TRITONSERVER_MemoryType memory_type, - int64_t memory_type_id, void* userp, - void** buffer, void** buffer_userp, - TRITONSERVER_MemoryType* - actual_memory_type, - int64_t* actual_memory_type_id)>( - ResponseAlloc), - reinterpret_cast< - TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator, - void* buffer, void* buffer_userp, - size_t byte_size, - TRITONSERVER_MemoryType memory_type, - int64_t memory_type_id)>(ResponseRelease), - nullptr /* start_fn */), - "creating response allocator"); - // set up inference request RETURN_IF_TRITONSERVER_ERROR( inference_request_new_fn_( @@ -1233,6 +1350,9 @@ TritonLoader::GetSingleton() TritonLoader::~TritonLoader() { + TRITONSERVER_Error* allocator_err = response_allocator_delete_fn_(allocator_); + FAIL_IF_TRITONSERVER_ERROR(allocator_err, "deleting response allocator"); + FAIL_IF_ERR(Delete(), "dereferencing server instance..."); FAIL_IF_ERR(CloseLibraryHandle(dlhandle_), "error on closing triton loader"); ClearHandles(); diff --git a/src/client_backend/triton_c_api/triton_loader.h b/src/client_backend/triton_c_api/triton_loader.h index 1a18176c..4ea840f1 100644 --- a/src/client_backend/triton_c_api/triton_loader.h +++ b/src/client_backend/triton_c_api/triton_loader.h @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -34,7 +34,9 @@ #include #include "../client_backend.h" +#include "alloc_payload.h" #include "common.h" +#include "response_output.h" #include "shared_library.h" #include "shared_memory_manager.h" #include "triton/core/tritonserver.h" @@ -66,11 +68,12 @@ } \ } while (false) -#define REPORT_TRITONSERVER_ERROR(E) \ +#define REPORT_TRITONSERVER_ERROR(E, MSG) \ do { \ TRITONSERVER_Error* err__ = (E); \ if (err__ != nullptr) { \ - std::cout << GetSingleton()->error_message_fn_(err__) << std::endl; \ + std::cerr << "error: " << (MSG) << ": " \ + << GetSingleton()->error_message_fn_(err__) << std::endl; \ GetSingleton()->error_delete_fn_(err__); \ } \ } while (false) @@ -84,6 +87,16 @@ class InferResult; class TritonLoader : public tc::InferenceServerClient { public: + using OnCompleteFn = std::function; + + struct AsyncRequestInfo { + std::unique_ptr alloc_payload{}; + std::string request_id{}; + std::shared_ptr timer{}; + OnCompleteFn callback{}; + bool enable_stats{true}; + }; + ~TritonLoader(); static Error Create( @@ -110,9 +123,17 @@ class TritonLoader : public tc::InferenceServerClient { const std::vector& outputs, InferResult** result); - Error CleanUp( - TRITONSERVER_InferenceResponse* completed_response, - TRITONSERVER_ResponseAllocator* allocator); + void InferResponseCompleteAsync( + TRITONSERVER_InferenceResponse* response, const uint32_t flags, + AsyncRequestInfo* async_request_info); + + Error AsyncInfer( + OnCompleteFn callback, const tc::InferOptions& options, + const std::vector& inputs, + const std::vector& outputs, + bool enable_stats); + + Error CleanUp(TRITONSERVER_InferenceResponse* completed_response); Error ModelInferenceStatistics( const std::string& model_name, const std::string& model_version, @@ -419,7 +440,6 @@ class TritonLoader : public tc::InferenceServerClient { Error InitializeRequest( const tc::InferOptions& options, const std::vector& outputs, - TRITONSERVER_ResponseAllocator** allocator, TRITONSERVER_InferenceRequest** irequest); Error AddInputs( @@ -430,6 +450,10 @@ class TritonLoader : public tc::InferenceServerClient { const std::vector& outputs, TRITONSERVER_InferenceRequest* irequest); + Error GetOutputs( + TRITONSERVER_InferenceResponse* response, + std::unordered_map& outputs); + void* dlhandle_; TritonServerApiVersionFn_t api_version_fn_; TritonServerOptionsNewFn_t options_new_fn_; @@ -514,6 +538,8 @@ class TritonLoader : public tc::InferenceServerClient { bool model_is_loaded_{false}; bool server_is_ready_{false}; std::unique_ptr shm_manager_{nullptr}; + TRITONSERVER_ResponseAllocator* allocator_{}; + std::mutex update_infer_stat_mutex_{}; }; }}}} // namespace triton::perfanalyzer::clientbackend::tritoncapi diff --git a/src/command_line_parser.cc b/src/command_line_parser.cc index 8003be71..1197454e 100644 --- a/src/command_line_parser.cc +++ b/src/command_line_parser.cc @@ -1752,8 +1752,9 @@ CLParser::VerifyOptions() "Failed to parse -i (protocol). 
The value should be either HTTP or " "gRPC."); } - if (params_->streaming && (params_->protocol != cb::ProtocolType::GRPC)) { - Usage("Streaming is only allowed with gRPC protocol."); + if (params_->streaming && (params_->protocol != cb::ProtocolType::GRPC && + params_->kind != cb::BackendKind::TRITON_C_API)) { + Usage("Streaming is only allowed with gRPC protocol and Triton C API."); } if (params_->using_grpc_compression && (params_->protocol != cb::ProtocolType::GRPC)) { @@ -1956,10 +1957,12 @@ CLParser::VerifyOptions() "service-kind=triton_c_api."); } - if (params_->async) { + // Decoupled models run via Triton C API do not support shared memory + if (params_->async && params_->streaming && + params_->shared_memory_type != SharedMemoryType::NO_SHARED_MEMORY) { Usage( - "Async mode is not supported by triton_c_api service " - "kind."); + "Cannot use --shared-memory=system or --shared-memory=cuda with " + "--service-kind=triton_c_api and --async and --streaming."); } params_->protocol = cb::ProtocolType::UNKNOWN; diff --git a/src/mock_profile_data_exporter.h b/src/mock_profile_data_exporter.h index 90e96d73..910310e6 100644 --- a/src/mock_profile_data_exporter.h +++ b/src/mock_profile_data_exporter.h @@ -59,6 +59,16 @@ class NaggyMockProfileDataExporter : public ProfileDataExporter { entry, experiment, raw_experiment); }); + ON_CALL( + *this, AddDataToJSON(testing::_, testing::_, testing::_, testing::_)) + .WillByDefault( + [this]( + rapidjson::Value& json, const uint8_t* buf, + const size_t byte_size, const std::string& data_type) -> void { + this->ProfileDataExporter::AddDataToJSON( + json, buf, byte_size, data_type); + }); + ON_CALL(*this, AddServiceKind(testing::_)) .WillByDefault([this](cb::BackendKind& service_kind) -> void { this->ProfileDataExporter::AddServiceKind(service_kind); @@ -82,6 +92,9 @@ class NaggyMockProfileDataExporter : public ProfileDataExporter { MOCK_METHOD( void, AddExperiment, (rapidjson::Value&, rapidjson::Value&, const Experiment&), (override)); + MOCK_METHOD( + void, AddDataToJSON, + (rapidjson::Value&, const uint8_t*, const size_t, const std::string&)); MOCK_METHOD(void, OutputToFile, (std::string&), (override)); MOCK_METHOD(void, AddServiceKind, (cb::BackendKind&)); MOCK_METHOD(void, AddEndpoint, (std::string&)); diff --git a/src/perf_analyzer.cc b/src/perf_analyzer.cc index c10101e1..f48e6300 100644 --- a/src/perf_analyzer.cc +++ b/src/perf_analyzer.cc @@ -165,6 +165,13 @@ PerfAnalyzer::CreateAnalyzerObjects() params_->async = true; } + if (parser_->IsDecoupled() && + (params_->async == false || params_->streaming == false)) { + throw pa::PerfAnalyzerException( + "Decoupled models must be run with `--async` and `--streaming` and " + "either `-i grpc` or `--service-kind=triton_c_api`"); + } + std::unique_ptr manager; if (params_->targeting_concurrency()) { if ((parser_->SchedulerType() == pa::ModelParser::SEQUENCE) || diff --git a/src/perf_analyzer_exception.h b/src/perf_analyzer_exception.h index a0b8ae70..b9f0788a 100644 --- a/src/perf_analyzer_exception.h +++ b/src/perf_analyzer_exception.h @@ -1,4 +1,4 @@ -// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -35,6 +35,8 @@ namespace triton { namespace perfanalyzer {
 //
 class PerfAnalyzerException : public std::exception {
  public:
+  PerfAnalyzerException(const std::string& message) : message_(message) {}
+
   PerfAnalyzerException(uint32_t error) : error_(error) {}
 
   PerfAnalyzerException(const std::string& message, uint32_t error)
@@ -48,7 +50,7 @@ class PerfAnalyzerException : public std::exception {
 
  private:
   const std::string message_{""};
-  uint32_t error_;
+  uint32_t error_{GENERIC_ERROR};
 };
 
 }} // namespace triton::perfanalyzer
diff --git a/src/perf_utils.cc b/src/perf_utils.cc
index 6088c1b6..773c5e22 100644
--- a/src/perf_utils.cc
+++ b/src/perf_utils.cc
@@ -413,4 +413,40 @@ ParseTensorFormat(const std::string& content_type_str)
   }
 }
 
+std::optional<size_t>
+GetDataTypeSize(const std::string& data_type)
+{
+  if (data_type == "BOOL") {
+    return sizeof(bool);
+  } else if (data_type == "UINT8") {
+    return sizeof(uint8_t);
+  } else if (data_type == "UINT16") {
+    return sizeof(uint16_t);
+  } else if (data_type == "UINT32") {
+    return sizeof(uint32_t);
+  } else if (data_type == "UINT64") {
+    return sizeof(uint64_t);
+  } else if (data_type == "INT8") {
+    return sizeof(int8_t);
+  } else if (data_type == "INT16") {
+    return sizeof(int16_t);
+  } else if (data_type == "INT32") {
+    return sizeof(int32_t);
+  } else if (data_type == "INT64") {
+    return sizeof(int64_t);
+  } else if (data_type == "FP32") {
+    return sizeof(float);
+  } else if (data_type == "FP64") {
+    return sizeof(double);
+  } else if (data_type == "BYTES") {
+    return sizeof(char);
+  } else if (data_type == "JSON") {
+    return sizeof(char);
+  } else {
+    std::cerr << "WARNING: unsupported data type: '" + data_type + "'"
+              << std::endl;
+    return {};
+  }
+}
+
 }} // namespace triton::perfanalyzer
diff --git a/src/perf_utils.h b/src/perf_utils.h
index 6975d694..818a2223 100644
--- a/src/perf_utils.h
+++ b/src/perf_utils.h
@@ -1,4 +1,4 @@
-// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -36,6 +36,7 @@
 #include
 #include
 #include
+#include <optional>
 #include
 
 #include "client_backend/client_backend.h"
@@ -137,4 +138,7 @@ std::function ScheduleDistribution(
 // Parse the HTTP tensor format
 cb::TensorFormat ParseTensorFormat(const std::string& tensor_format_str);
 
+// Returns the size of a given data type in bytes.
+std::optional<size_t> GetDataTypeSize(const std::string& data_type);
+
 }} // namespace triton::perfanalyzer
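The AddDataToJSON() change in the next file relies on GetDataTypeSize() to turn a buffer's byte size into an element count before choosing between a JSON scalar and a JSON array. The standalone C++ sketch below illustrates that arithmetic only; it is not part of the patch, and its map-based DataTypeSize() helper is a simplified stand-in for the if/else chain above:

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for GetDataTypeSize(): bytes per element,
// or std::nullopt for an unrecognized type string.
std::optional<size_t> DataTypeSize(const std::string& data_type)
{
  static const std::unordered_map<std::string, size_t> sizes{
      {"BOOL", sizeof(bool)},     {"UINT8", sizeof(uint8_t)},
      {"INT32", sizeof(int32_t)}, {"INT64", sizeof(int64_t)},
      {"FP32", sizeof(float)},    {"FP64", sizeof(double)}};
  const auto it = sizes.find(data_type);
  if (it == sizes.end()) {
    return std::nullopt;
  }
  return it->second;
}

int main()
{
  // A 12-byte buffer holding three INT32 values, as a request input might.
  const std::vector<int32_t> values{7, -8, 9};
  const auto* buf = reinterpret_cast<const uint8_t*>(values.data());
  const size_t byte_size = values.size() * sizeof(int32_t);

  const std::optional<size_t> elem_size = DataTypeSize("INT32");
  if (!elem_size) {
    return 1;
  }
  const size_t data_size = byte_size / elem_size.value();  // 12 / 4 == 3

  // With data_size > 1 the exporter emits a JSON array; a single element
  // would be written as a scalar instead.
  std::cout << "elements: " << data_size
            << (data_size > 1 ? " -> JSON array" : " -> JSON scalar")
            << std::endl;
  for (size_t i = 0; i < data_size; ++i) {
    std::cout << reinterpret_cast<const int32_t*>(buf)[i] << std::endl;
  }
  return 0;
}

Dividing the 12-byte INT32 buffer by the 4-byte element size yields three elements, which is why the exporter switches the value to an array in that case.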
diff --git a/src/profile_data_exporter.cc b/src/profile_data_exporter.cc
index ea79d685..3af462a7 100644
--- a/src/profile_data_exporter.cc
+++ b/src/profile_data_exporter.cc
@@ -160,6 +160,81 @@ ProfileDataExporter::AddResponseTimestamps(
   }
 }
 
+void
+ProfileDataExporter::SetValueToJSON(
+    rapidjson::Value& json, const size_t index, const uint8_t* buf,
+    const size_t byte_size, const std::string& data_type)
+{
+  if (data_type == "BOOL") {
+    json.SetBool(reinterpret_cast<const bool*>(buf)[index]);
+  } else if (data_type == "UINT8") {
+    json.SetUint(reinterpret_cast<const uint8_t*>(buf)[index]);
+  } else if (data_type == "UINT16") {
+    json.SetUint(reinterpret_cast<const uint16_t*>(buf)[index]);
+  } else if (data_type == "UINT32") {
+    json.SetUint(reinterpret_cast<const uint32_t*>(buf)[index]);
+  } else if (data_type == "UINT64") {
+    json.SetUint64(reinterpret_cast<const uint64_t*>(buf)[index]);
+  } else if (data_type == "INT8") {
+    json.SetInt(reinterpret_cast<const int8_t*>(buf)[index]);
+  } else if (data_type == "INT16") {
+    json.SetInt(reinterpret_cast<const int16_t*>(buf)[index]);
+  } else if (data_type == "INT32") {
+    json.SetInt(reinterpret_cast<const int32_t*>(buf)[index]);
+  } else if (data_type == "INT64") {
+    json.SetInt64(reinterpret_cast<const int64_t*>(buf)[index]);
+  } else if (data_type == "FP32") {
+    json.SetFloat(reinterpret_cast<const float*>(buf)[index]);
+  } else if (data_type == "FP64") {
+    json.SetDouble(reinterpret_cast<const double*>(buf)[index]);
+  } else if (data_type == "BYTES" || data_type == "JSON") {
+    json.SetString(
+        reinterpret_cast<const char*>(buf), byte_size,
+        document_.GetAllocator());
+  } else {
+    std::cerr << "WARNING: data type '" + data_type +
+                     "' is not supported with JSON."
+              << std::endl;
+  }
+}
+
+void
+ProfileDataExporter::AddDataToJSON(
+    rapidjson::Value& json, const uint8_t* buf, const size_t byte_size,
+    const std::string& data_type)
+{
+  // TPA-268: support N-dimensional tensor
+  size_t data_size;
+  // TODO TPA-283: Add support for N-dimensional string tensors
+  if (data_type == "BYTES" || data_type == "JSON") {
+    // return string as is instead of array of chars
+    data_size = 1;
+  } else {
+    const std::optional<size_t> data_type_size{GetDataTypeSize(data_type)};
+    if (!data_type_size) {
+      return;
+    }
+    data_size = byte_size / data_type_size.value();
+    if (data_size > 1) {
+      json.SetArray();
+    }
+  }
+
+  if (buf != nullptr) {
+    if (json.IsArray()) {
+      for (int i = 0; i < data_size; i++) {
+        rapidjson::Value data_val;
+        SetValueToJSON(data_val, i, buf, byte_size, data_type);
+        json.PushBack(data_val, document_.GetAllocator());
+      }
+    } else {
+      SetValueToJSON(json, 0 /* index */, buf, byte_size, data_type);
+    }
+  } else {
+    json.SetString("", 0, document_.GetAllocator());
+  }
+}
+
 void
 ProfileDataExporter::AddRequestInputs(
     rapidjson::Value& request_inputs_json,
@@ -172,27 +247,8 @@ ProfileDataExporter::AddRequestInputs(
     const auto& byte_size{input.second.size_};
     const auto& data_type{input.second.data_type_};
     rapidjson::Value name_json(name.c_str(), document_.GetAllocator());
-    rapidjson::Value input_json{};
-    // TMA-1777: support other data types
-    if (buf != nullptr) {
-      if (data_type == "BYTES" || data_type == "JSON") {
-        input_json.SetString(
-            reinterpret_cast<const char*>(buf), byte_size,
-            document_.GetAllocator());
-      } else if (data_type == "INT32") {
-        auto* val = reinterpret_cast<const int32_t*>(buf);
-        input_json.SetInt(*val);
-      } else if (data_type == "BOOL") {
-        bool is_true = (*buf > 0);
-        input_json.SetBool(is_true);
-      } else {
-        std::cerr << "WARNING: data type '" + data_type +
-                         "' is not supported with JSON."
-                  << std::endl;
-      }
-    } else {
-      input_json.SetString("", 0, document_.GetAllocator());
-    }
+    rapidjson::Value input_json;
+    AddDataToJSON(input_json, buf, byte_size, data_type);
     request_inputs_json.AddMember(
         name_json, input_json, document_.GetAllocator());
   }
@@ -210,16 +266,10 @@ ProfileDataExporter::AddResponseOutputs(
     const auto& name{output.first};
     const auto& buf{output.second.data_.get()};
     const auto& byte_size{output.second.size_};
+    const auto& data_type{output.second.data_type_};
     rapidjson::Value name_json(name.c_str(), document_.GetAllocator());
-    rapidjson::Value output_json{};
-    // TMA-1777: support other data types
-    if (buf != nullptr) {
-      output_json.SetString(
-          reinterpret_cast<const char*>(buf), byte_size,
-          document_.GetAllocator());
-    } else {
-      output_json.SetString("", 0, document_.GetAllocator());
-    }
+    rapidjson::Value output_json;
+    AddDataToJSON(output_json, buf, byte_size, data_type);
     response_output_json.AddMember(
         name_json, output_json, document_.GetAllocator());
   }
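When AddDataToJSON() above takes the array path, it calls SetArray() on the target value and pushes one element per item in the buffer. The self-contained sketch below, which is illustrative and not part of the patch, builds a rapidjson value the same way for a three-element UINT8 buffer and serializes it to show the form that ends up in the export file:

#include <cstdint>
#include <iostream>

#include "rapidjson/document.h"
#include "rapidjson/stringbuffer.h"
#include "rapidjson/writer.h"

int main()
{
  rapidjson::Document document;
  rapidjson::Value json;

  // Mirror the array path of AddDataToJSON for a three-element UINT8 buffer.
  const uint8_t data[3] = {1, 2, 3};
  json.SetArray();
  for (const uint8_t value : data) {
    rapidjson::Value element;
    element.SetUint(value);
    json.PushBack(element, document.GetAllocator());
  }

  // Serialize to see the exported form.
  rapidjson::StringBuffer buffer;
  rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
  json.Accept(writer);
  std::cout << buffer.GetString() << std::endl;  // prints: [1,2,3]
  return 0;
}

The BYTES and JSON branches skip this path and store the raw bytes as a single JSON string instead.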
diff --git a/src/profile_data_exporter.h b/src/profile_data_exporter.h
index 820148d7..6c2045e6 100644
--- a/src/profile_data_exporter.h
+++ b/src/profile_data_exporter.h
@@ -75,6 +75,12 @@ class ProfileDataExporter {
   void AddRequests(
       rapidjson::Value& entry, rapidjson::Value& requests,
       const Experiment& raw_experiment);
+  void SetValueToJSON(
+      rapidjson::Value& json, const size_t index, const uint8_t* buf,
+      const size_t byte_size, const std::string& data_type);
+  void AddDataToJSON(
+      rapidjson::Value& json, const uint8_t* buf, const size_t byte_size,
+      const std::string& data_type);
   void AddRequestInputs(
       rapidjson::Value& inputs_json,
       const std::vector<RequestRecord::RequestInput>& inputs);
diff --git a/src/request_record.h b/src/request_record.h
index 91b5ca19..bd5e4a55 100644
--- a/src/request_record.h
+++ b/src/request_record.h
@@ -35,7 +35,7 @@ namespace triton { namespace perfanalyzer {
 /// A record containing the data of a single request input or response output
 struct RecordData {
-  RecordData(const uint8_t* buf, size_t size, std::string data_type = "")
+  RecordData(const uint8_t* buf, size_t size, std::string data_type)
   {
     uint8_t* array = new uint8_t[size];
     std::memcpy(array, buf, size);
diff --git a/src/test_command_line_parser.cc b/src/test_command_line_parser.cc
index 2d17bbc2..7738b130 100644
--- a/src/test_command_line_parser.cc
+++ b/src/test_command_line_parser.cc
@@ -579,7 +579,7 @@ TEST_CASE("Testing Command Line Parser")
     //
     CHECK_THROWS_WITH_AS(
         act = parser.Parse(argc, argv),
-        "Streaming is only allowed with gRPC protocol.",
+        "Streaming is only allowed with gRPC protocol and Triton C API.",
         PerfAnalyzerException);
 
     check_params = false;
@@ -592,7 +592,7 @@ TEST_CASE("Testing Command Line Parser")
 
     CHECK_THROWS_WITH_AS(
         act = parser.Parse(argc, argv),
-        "Streaming is only allowed with gRPC protocol.",
+        "Streaming is only allowed with gRPC protocol and Triton C API.",
         PerfAnalyzerException);
 
     check_params = false;
diff --git a/src/test_profile_data_collector.cc b/src/test_profile_data_collector.cc
index 926a9015..167f7034 100644
--- a/src/test_profile_data_collector.cc
+++ b/src/test_profile_data_collector.cc
@@ -70,14 +70,14 @@ TEST_CASE("profile_data_collector: AddData")
   uint8_t fake_data_in[] = {0x01, 0x02, 0x03, 0x04};
   uint8_t fake_data_out[] = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08};
   RequestRecord::RequestInput request1_request_input{
-      {"key1", RecordData(fake_data_in, 1)},
-      {"key2", RecordData(fake_data_in, 2)}};
+      {"key1", RecordData(fake_data_in, 1, "fake_datatype")},
+      {"key2", RecordData(fake_data_in, 2, "fake_datatype")}};
   RequestRecord::ResponseOutput request1_response1_output{
-      {"key1", RecordData(fake_data_out, 1)},
-      {"key2", RecordData(fake_data_out, 2)}};
+      {"key1", RecordData(fake_data_out, 1, "fake_datatype")},
+      {"key2", RecordData(fake_data_out, 2, "fake_datatype")}};
   RequestRecord::ResponseOutput request1_response2_output{
-      {"key3", RecordData(fake_data_out, 3)},
-      {"key4", RecordData(fake_data_out, 4)}};
+      {"key3", RecordData(fake_data_out, 3, "fake_datatype")},
+      {"key4", RecordData(fake_data_out, 4, "fake_datatype")}};
 
   RequestRecord request_record1{
       request1_timestamp,
@@ -95,14 +95,14 @@ TEST_CASE("profile_data_collector: AddData")
   auto request2_response1_timestamp{clock_epoch + nanoseconds(5)};
   auto request2_response2_timestamp{clock_epoch + nanoseconds(6)};
   RequestRecord::RequestInput request2_request_input{
-      {"key3", RecordData(fake_data_in, 3)},
-      {"key4", RecordData(fake_data_in, 4)}};
+      {"key3", RecordData(fake_data_in, 3, "fake_datatype")},
+      {"key4", RecordData(fake_data_in, 4, "fake_datatype")}};
   RequestRecord::ResponseOutput request2_response1_output{
-      {"key5", RecordData(fake_data_out, 5)},
-      {"key6", RecordData(fake_data_out, 6)}};
+      {"key5", RecordData(fake_data_out, 5, "fake_datatype")},
+      {"key6", RecordData(fake_data_out, 6, "fake_datatype")}};
   RequestRecord::ResponseOutput request2_response2_output{
-      {"key7", RecordData(fake_data_out, 7)},
-      {"key8", RecordData(fake_data_out, 8)}};
+      {"key7", RecordData(fake_data_out, 7, "fake_datatype")},
+      {"key8", RecordData(fake_data_out, 8, "fake_datatype")}};
 
   RequestRecord request_record2{
       request2_timestamp,
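These collector tests, like the exporter tests in the next file, now pass an explicit type string because the request_record.h hunk above removed the default value of RecordData's data_type parameter. The following standalone sketch uses a simplified, hypothetical SimpleRecordData type, not the patch's RecordData, to show the shape such call sites take:

#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>

// Simplified analogue of RecordData: deep-copies the raw buffer and keeps the
// data type string needed later for JSON export. No default for data_type,
// so callers (like the updated tests) must always supply one.
struct SimpleRecordData {
  SimpleRecordData(const uint8_t* buf, size_t size, std::string data_type)
      : data_(new uint8_t[size]), size_(size), data_type_(std::move(data_type))
  {
    std::memcpy(data_.get(), buf, size);
  }

  std::unique_ptr<uint8_t[]> data_;
  size_t size_;
  std::string data_type_;
};

int main()
{
  const uint8_t fake_data_in[] = {0x01, 0x02, 0x03, 0x04};
  // Mirrors the updated test call sites: a size and an explicit type label.
  SimpleRecordData record(fake_data_in, 2, "fake_datatype");
  std::cout << record.size_ << " bytes of " << record.data_type_ << std::endl;
  return 0;
}

Making the parameter mandatory ensures every recorded buffer carries the type information the exporter needs when writing JSON.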
"UINT8"); + CHECK(json[0] == 1); + CHECK(json[1] == 2); + CHECK(json[2] == 3); + } + + SUBCASE("Test uint16") + { + const uint16_t data[3] = {4, 5, 6}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "UINT16"); + CHECK(json[0] == 4); + CHECK(json[1] == 5); + CHECK(json[2] == 6); + } + + SUBCASE("Test uint32") + { + const uint32_t data[3] = {7, 8, 9}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "UINT32"); + CHECK(json[0] == 7); + CHECK(json[1] == 8); + CHECK(json[2] == 9); + } + + SUBCASE("Test uint64") + { + const uint64_t data[3] = {10, 11, 12}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "UINT64"); + CHECK(json[0] == 10); + CHECK(json[1] == 11); + CHECK(json[2] == 12); + } + + SUBCASE("Test int8") + { + const int8_t data[3] = {1, -2, 3}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "INT8"); + CHECK(json[0] == 1); + CHECK(json[1] == -2); + CHECK(json[2] == 3); + } + + SUBCASE("Test int16") + { + const int16_t data[3] = {4, -5, 6}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "INT16"); + CHECK(json[0] == 4); + CHECK(json[1] == -5); + CHECK(json[2] == 6); + } + + SUBCASE("Test int32") + { + const int32_t data[3] = {7, -8, 9}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "INT32"); + CHECK(json[0] == 7); + CHECK(json[1] == -8); + CHECK(json[2] == 9); + } + + SUBCASE("Test int64") + { + const int64_t data[3] = {10, -11, 12}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "INT64"); + CHECK(json[0] == 10); + CHECK(json[1] == -11); + CHECK(json[2] == 12); + } + + SUBCASE("Test fp32") + { + const float data[3] = {1.0, -2.0, 3.0}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "FP32"); + CHECK(json[0] == 1.0); + CHECK(json[1] == -2.0); + CHECK(json[2] == 3.0); + } + + SUBCASE("Test fp64") + { + const double data[3] = {4.0, -5.0, 6.0}; + buf = reinterpret_cast(data); + exporter.AddDataToJSON(json, buf, sizeof(data), "FP64"); + CHECK(json[0] == 4.0); + CHECK(json[1] == -5.0); + CHECK(json[2] == 6.0); + } +} + TEST_CASE("profile_data_exporter: AddExperiment") { MockProfileDataExporter exporter{};