diff --git a/audonnx/__init__.py b/audonnx/__init__.py
index d6a7955..8d6ce5f 100644
--- a/audonnx/__init__.py
+++ b/audonnx/__init__.py
@@ -3,6 +3,7 @@
 from audonnx.core.model import Model
 from audonnx.core.node import InputNode
 from audonnx.core.node import OutputNode
+from audonnx.core.ort import device_to_providers
 
 
 __all__ = []
diff --git a/audonnx/core/model.py b/audonnx/core/model.py
index 80d2916..617bf3b 100644
--- a/audonnx/core/model.py
+++ b/audonnx/core/model.py
@@ -1,5 +1,4 @@
 import os
-import re
 import typing
 
 import numpy as np
@@ -12,6 +11,7 @@
 from audonnx.core.node import InputNode
 from audonnx.core.node import OutputNode
+from audonnx.core.ort import device_to_providers
 from audonnx.core.typing import Device
 from audonnx.core.typing import Labels
 from audonnx.core.typing import Transform
 
@@ -117,7 +117,7 @@ def __init__(
         self.path = audeer.path(path) if isinstance(path, str) else None
         r"""Model path"""
 
-        providers = _device_to_providers(device)
+        providers = device_to_providers(device)
         self.sess = onnxruntime.InferenceSession(
             self.path if isinstance(path, str) else path.SerializeToString(),
             providers=providers,
@@ -363,39 +363,6 @@ def to_yaml(
         super().to_yaml(fp, include_version=include_version)
 
 
-def _device_to_providers(
-        device: typing.Union[
-            str,
-            typing.Tuple[str, typing.Dict],
-            typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]],
-        ],
-) -> typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]]:
-    r"""Converts device into a list of providers."""
-    if isinstance(device, str):
-        if device == 'cpu':
-            providers = ['CPUExecutionProvider']
-        elif device.startswith('cuda'):
-            match = re.search(r'^cuda:(\d+)$', device)
-            if match:
-                device_id = match.group(1)
-                providers = [
-                    (
-                        'CUDAExecutionProvider', {
-                            'device_id': device_id,
-                        }
-                    ),
-                ]
-            else:
-                providers = ['CUDAExecutionProvider']
-        else:
-            providers = [device]
-    elif isinstance(device, tuple):
-        providers = [device]
-    else:
-        providers = device
-    return providers
-
-
 def _concat(
         y: typing.Dict[str, np.ndarray],
         shapes: typing.Sequence[typing.List[int]],
diff --git a/audonnx/core/ort.py b/audonnx/core/ort.py
new file mode 100644
index 0000000..80cf88a
--- /dev/null
+++ b/audonnx/core/ort.py
@@ -0,0 +1,52 @@
+import re
+import typing
+
+
+def device_to_providers(
+        device: typing.Union[
+            str,
+            typing.Tuple[str, typing.Dict],
+            typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]],
+        ],
+) -> typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]]:
+    r"""Converts device into a list of providers.
+
+    Args:
+        device: ``'cpu'``,
+            ``'cuda'``,
+            ``'cuda:<id>'``,
+            or a (list of) `ONNX Runtime Execution Providers`_
+
+    Returns:
+        sequence of `ONNX Runtime Execution Providers`_
+
+    Examples:
+        >>> device_to_providers('cpu')
+        ['CPUExecutionProvider']
+
+    .. _ONNX Runtime Execution Providers: https://onnxruntime.ai/docs/execution-providers/
+
+    """
+    if isinstance(device, str):
+        if device == 'cpu':
+            providers = ['CPUExecutionProvider']
+        elif device.startswith('cuda'):
+            match = re.search(r'^cuda:(\d+)$', device)
+            if match:
+                device_id = match.group(1)
+                providers = [
+                    (
+                        'CUDAExecutionProvider', {
+                            'device_id': device_id,
+                        }
+                    ),
+                ]
+            else:
+                providers = ['CUDAExecutionProvider']
+        else:
+            providers = [device]
+    elif isinstance(device, tuple):
+        providers = [device]
+    else:
+        providers = device
+    return providers
diff --git a/docs/api-src/audonnx.rst b/docs/api-src/audonnx.rst
index 70d7ec7..1f72bcc 100644
--- a/docs/api-src/audonnx.rst
+++ b/docs/api-src/audonnx.rst
@@ -11,4 +11,5 @@ audonnx
     Model
     InputNode
     OutputNode
+    device_to_providers
     load
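
A minimal usage sketch of the helper made public by this change, assuming a package install that includes the diff above; the expected outputs follow directly from the branches of device_to_providers and are shown here in the same doctest style as the new docstring:

    >>> import audonnx
    >>> # plain device strings map to provider names
    >>> audonnx.device_to_providers('cpu')
    ['CPUExecutionProvider']
    >>> # 'cuda:<id>' additionally selects the GPU via provider options
    >>> audonnx.device_to_providers('cuda:1')
    [('CUDAExecutionProvider', {'device_id': '1'})]
    >>> # provider names or (provider, options) tuples pass through unchanged
    >>> audonnx.device_to_providers(('CUDAExecutionProvider', {'device_id': 0}))
    [('CUDAExecutionProvider', {'device_id': 0})]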