Skip to content

Commit

Permalink
Add audonnx.device_to_providers() (#77)
Browse files Browse the repository at this point in the history
* Add audonnx.device_to_providers()

* Fix docstring

* Add link to docstring
  • Loading branch information
hagenw authored Dec 14, 2023
1 parent 33faab7 commit 8ea3b90
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 35 deletions.
1 change: 1 addition & 0 deletions audonnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from audonnx.core.model import Model
from audonnx.core.node import InputNode
from audonnx.core.node import OutputNode
from audonnx.core.ort import device_to_providers


__all__ = []
Expand Down
37 changes: 2 additions & 35 deletions audonnx/core/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import re
import typing

import numpy as np
Expand All @@ -12,6 +11,7 @@

from audonnx.core.node import InputNode
from audonnx.core.node import OutputNode
from audonnx.core.ort import device_to_providers
from audonnx.core.typing import Device
from audonnx.core.typing import Labels
from audonnx.core.typing import Transform
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
self.path = audeer.path(path) if isinstance(path, str) else None
r"""Model path"""

providers = _device_to_providers(device)
providers = device_to_providers(device)
self.sess = onnxruntime.InferenceSession(
self.path if isinstance(path, str) else path.SerializeToString(),
providers=providers,
Expand Down Expand Up @@ -363,39 +363,6 @@ def to_yaml(
super().to_yaml(fp, include_version=include_version)


def _device_to_providers(
device: typing.Union[
str,
typing.Tuple[str, typing.Dict],
typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]],
],
) -> typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]]:
r"""Converts device into a list of providers."""
if isinstance(device, str):
if device == 'cpu':
providers = ['CPUExecutionProvider']
elif device.startswith('cuda'):
match = re.search(r'^cuda:(\d+)$', device)
if match:
device_id = match.group(1)
providers = [
(
'CUDAExecutionProvider', {
'device_id': device_id,
}
),
]
else:
providers = ['CUDAExecutionProvider']
else:
providers = [device]
elif isinstance(device, tuple):
providers = [device]
else:
providers = device
return providers


def _concat(
y: typing.Dict[str, np.ndarray],
shapes: typing.Sequence[typing.List[int]],
Expand Down
52 changes: 52 additions & 0 deletions audonnx/core/ort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import re
import typing


def device_to_providers(
device: typing.Union[
str,
typing.Tuple[str, typing.Dict],
typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]],
],
) -> typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]]:
r"""Converts device into a list of providers.
Args:
device: ``'cpu'``,
``'cuda'``,
``'cuda:<id>'``,
or a (list of) `ONNX Runtime Execution Providers`_
Returns:
sequence of `ONNX Runtime Execution Providers`_
Examples:
>>> device_to_providers('cpu')
['CPUExecutionProvider']
.. _ONNX Runtime Execution Providers: https://onnxruntime.ai/docs/execution-providers/
"""
if isinstance(device, str):
if device == 'cpu':
providers = ['CPUExecutionProvider']
elif device.startswith('cuda'):
match = re.search(r'^cuda:(\d+)$', device)
if match:
device_id = match.group(1)
providers = [
(
'CUDAExecutionProvider', {
'device_id': device_id,
}
),
]
else:
providers = ['CUDAExecutionProvider']
else:
providers = [device]
elif isinstance(device, tuple):
providers = [device]
else:
providers = device
return providers
1 change: 1 addition & 0 deletions docs/api-src/audonnx.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ audonnx
Model
InputNode
OutputNode
device_to_providers
load

0 comments on commit 8ea3b90

Please sign in to comment.