Skip to content

Commit

Permalink
Collect telemetry metrics from Triton metrics endpoint (#26)
Browse files Browse the repository at this point in the history
* Collect telemetry metrics from Triton metrics endpoint

* Remove one of the print statements

* Fix comments

* Fix pre-commit errors

* Fix test errors

* Add unit tests and fix code

* Fix pre-commit error

* Fix codeql warnings

* Fix comments
  • Loading branch information
lkomali authored Aug 14, 2024
1 parent a812e25 commit e67f9ca
Show file tree
Hide file tree
Showing 13 changed files with 704 additions and 13 deletions.
2 changes: 1 addition & 1 deletion genai-perf/genai_perf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

DEFAULT_HTTP_URL = "localhost:8000"
DEFAULT_GRPC_URL = "localhost:8001"

DEFAULT_TRITON_METRICS_URL = "localhost:8002/metrics"

OPEN_ORCA = "openorca"
CNN_DAILY_MAIL = "cnn_dailymail"
Expand Down
1 change: 1 addition & 0 deletions genai-perf/genai_perf/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def report_output(data_parser: ProfileDataParser, args: Namespace) -> None:
else:
raise GenAIPerfException("No valid infer mode specified")

# TPA-274 - Integrate telemetry metrics with other metrics for export
stats = data_parser.get_statistics(infer_mode, load_level)
reporter = OutputReporter(stats, args)
reporter.report_output()
Expand Down
1 change: 1 addition & 0 deletions genai-perf/genai_perf/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
from genai_perf.metrics.llm_metrics import LLMMetrics
from genai_perf.metrics.metrics import MetricMetadata, Metrics
from genai_perf.metrics.statistics import Statistics
from genai_perf.metrics.telemetry_metrics import TelemetryMetrics
82 changes: 82 additions & 0 deletions genai-perf/genai_perf/metrics/telemetry_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import List

from genai_perf.metrics.metrics import MetricMetadata


class TelemetryMetrics:
"""
A class that contains common telemetry metrics.
Metrics are stored as lists where each inner list corresponds to multiple measurements per GPU.
Each measurement is recorded every second.
"""

TELEMETRY_METRICS = [
MetricMetadata("gpu_power_usage", "watts"),
MetricMetadata("gpu_power_limit", "watts"),
MetricMetadata("energy_consumption", "joules"),
MetricMetadata("gpu_utilization", "percentage"),
MetricMetadata("total_gpu_memory", "bytes"),
MetricMetadata("gpu_memory_used", "bytes"),
]

def __init__(
self,
gpu_power_usage: List[List[float]] = [], # Multiple measurements per GPU
gpu_power_limit: List[List[float]] = [],
energy_consumption: List[List[float]] = [],
gpu_utilization: List[List[float]] = [],
total_gpu_memory: List[List[float]] = [],
gpu_memory_used: List[List[float]] = [],
) -> None:
self.gpu_power_usage = gpu_power_usage
self.gpu_power_limit = gpu_power_limit
self.energy_consumption = energy_consumption
self.gpu_utilization = gpu_utilization
self.total_gpu_memory = total_gpu_memory
self.gpu_memory_used = gpu_memory_used

def update_metrics(self, measurement_data: dict) -> None:
"""Update the metrics with new measurement data"""
for metric in self.TELEMETRY_METRICS:
metric_key = metric.name
if metric_key in measurement_data:
getattr(self, metric_key).append(measurement_data[metric_key])

def __repr__(self):
attr_strs = []
for k, v in self.__dict__.items():
if not k.startswith("_"):
attr_strs.append(f"{k}={v}")
return f"TelemetryMetrics({','.join(attr_strs)})"

@property
def telemetry_metrics(self) -> List[MetricMetadata]:
return self.TELEMETRY_METRICS
17 changes: 16 additions & 1 deletion genai-perf/genai_perf/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
CNN_DAILY_MAIL,
DEFAULT_ARTIFACT_DIR,
DEFAULT_COMPARE_DIR,
DEFAULT_TRITON_METRICS_URL,
OPEN_ORCA,
)
from genai_perf.llm_inputs.llm_inputs import (
Expand Down Expand Up @@ -765,9 +766,23 @@ def compare_handler(args: argparse.Namespace):


def profile_handler(args, extra_args):
from genai_perf.telemetry_data.triton_telemetry_data_collector import (
TritonTelemetryDataCollector,
)
from genai_perf.wrapper import Profiler

Profiler.run(args=args, extra_args=extra_args)
telemetry_data_collector = None
if args.service_kind == "triton":
# TPA-275: pass server url as a CLI option in non-default case
telemetry_data_collector = TritonTelemetryDataCollector(
server_metrics_url=DEFAULT_TRITON_METRICS_URL
)

Profiler.run(
args=args,
extra_args=extra_args,
telemetry_data_collector=telemetry_data_collector,
)


### Parser Initialization ###
Expand Down
27 changes: 27 additions & 0 deletions genai-perf/genai_perf/telemetry_data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from genai_perf.telemetry_data.telemetry_data_collector import TelemetryDataCollector
83 changes: 83 additions & 0 deletions genai-perf/genai_perf/telemetry_data/telemetry_data_collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import time
from abc import ABC, abstractmethod
from threading import Event, Thread
from typing import Optional

import requests
from genai_perf.metrics.telemetry_metrics import TelemetryMetrics


class TelemetryDataCollector(ABC):
def __init__(
self, server_metrics_url: str, collection_interval: float = 1.0 # in seconds
) -> None:
self._server_metrics_url = server_metrics_url
self._collection_interval = collection_interval
self._metrics = TelemetryMetrics()
self._stop_event = Event()
self._thread: Optional[Thread] = None

def start(self) -> None:
"""Start the telemetry data collection thread."""
if self._thread is None or not self._thread.is_alive():
self._stop_event.clear()
self._thread = Thread(target=self._collect_metrics)
self._thread.start()

def stop(self) -> None:
"""Stop the telemetry data collection thread."""
if self._thread is not None and self._thread.is_alive():
self._stop_event.set()
self._thread.join()

def _fetch_metrics(self) -> str:
"""Fetch metrics from the metrics endpoint"""
response = requests.get(self._server_metrics_url)
response.raise_for_status()
return response.text

@abstractmethod
def _process_and_update_metrics(self, metrics_data: str) -> None:
"""This method should be implemented by subclasses."""
pass

def _collect_metrics(self) -> None:
"""Continuously collect telemetry metrics at for every second"""
while not self._stop_event.is_set():
metrics_data = self._fetch_metrics()
self._process_and_update_metrics(metrics_data)
time.sleep(self._collection_interval)

@property
def metrics(self) -> TelemetryMetrics:
"""Return the collected metrics."""
return self._metrics
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Dict, List

import genai_perf.logging as logging
from genai_perf.telemetry_data.telemetry_data_collector import TelemetryDataCollector

logger = logging.getLogger(__name__)


class TritonTelemetryDataCollector(TelemetryDataCollector):
"""Class to collect telemetry metrics from Triton server"""

"""Mapping from Triton metric names to GenAI-Perf telemetry metric names"""
METRIC_NAME_MAPPING = {
"nv_gpu_power_usage": "gpu_power_usage",
"nv_gpu_power_limit": "gpu_power_limit",
"nv_energy_consumption": "energy_consumption",
"nv_gpu_utilization": "gpu_utilization",
"nv_gpu_memory_total_bytes": "total_gpu_memory",
"nv_gpu_memory_used_bytes": "gpu_memory_used",
}

def _process_and_update_metrics(self, metrics_data: str) -> None:
"""Process the response from Triton metrics endpoint and update metrics.
This method extracts metric names and values from the raw data. Metric names
are extracted from the start of each line up to the '{' character, as all metrics
follow the format 'metric_name{labels} value'. Only metrics defined in
METRIC_NAME_MAPPING are processed.
Args:
data (str): Raw metrics data from the Triton endpoint.
Example:
Given the metric data:
```
nv_gpu_power_usage{gpu_uuid="GPU-abschdinjacgdo65gdj7"} 27.01
nv_gpu_utilization{gpu_uuid="GPU-abcdef123456"} 75.5
nv_energy_consumption{gpu_uuid="GPU-xyz789"} 1234.56
```
The method will extract and process:
- `nv_gpu_power_usage` as `gpu_power_usage`
- `nv_gpu_utilization` as `gpu_utilization`
- `nv_energy_consumption` as `energy_consumption`
"""

if not metrics_data.strip():
logger.info("Response from Triton metrics endpoint is empty")
return

current_measurement_interval = {
metric.name: [] for metric in self.metrics.TELEMETRY_METRICS
} # type: Dict[str, List[float]]

for line in metrics_data.splitlines():
line = line.strip()
if not line:
continue

parts = line.split()
if len(parts) < 2:
continue

triton_metric_key = parts[0].split("{")[0]
metric_value = parts[1]

metric_key = self.METRIC_NAME_MAPPING.get(triton_metric_key, None)

if metric_key and metric_key in current_measurement_interval:
current_measurement_interval[metric_key].append(float(metric_value))

self.metrics.update_metrics(current_measurement_interval)
27 changes: 20 additions & 7 deletions genai-perf/genai_perf/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
import genai_perf.utils as utils
from genai_perf.constants import DEFAULT_GRPC_URL, DEFAULT_INPUT_DATA_JSON
from genai_perf.llm_inputs.llm_inputs import OutputFormat
from genai_perf.telemetry_data.triton_telemetry_data_collector import (
TelemetryDataCollector,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -142,10 +145,20 @@ def build_cmd(args: Namespace, extra_args: Optional[List[str]] = None) -> List[s
return cmd

@staticmethod
def run(args: Namespace, extra_args: Optional[List[str]]) -> None:
cmd = Profiler.build_cmd(args, extra_args)
logger.info(f"Running Perf Analyzer : '{' '.join(cmd)}'")
if args and args.verbose:
subprocess.run(cmd, check=True, stdout=None)
else:
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL)
def run(
args: Namespace,
extra_args: Optional[List[str]],
telemetry_data_collector: Optional[TelemetryDataCollector] = None,
) -> None:
try:
if telemetry_data_collector is not None:
telemetry_data_collector.start()
cmd = Profiler.build_cmd(args, extra_args)
logger.info(f"Running Perf Analyzer : '{' '.join(cmd)}'")
if args and args.verbose:
subprocess.run(cmd, check=True, stdout=None)
else:
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL)
finally:
if telemetry_data_collector is not None:
telemetry_data_collector.stop()
Loading

0 comments on commit e67f9ca

Please sign in to comment.