Skip to content

Commit

Permalink
Add num_workers to audonnx.Model and audonnx.load()
Browse files Browse the repository at this point in the history
  • Loading branch information
audeerington committed Dec 15, 2023
1 parent 8ea3b90 commit a3a2c49
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
6 changes: 6 additions & 0 deletions audonnx/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def load(
typing.Tuple[str, typing.Dict],
typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]],
] = 'cpu',
num_workers: typing.Optional[int] = 1,
auto_install: bool = False,
) -> Model:
r"""Load model from folder.
Expand All @@ -44,6 +45,9 @@ def load(
device: set device
(``'cpu'``, ``'cuda'``, or ``'cuda:<id>'``)
or a (list of) provider_
num_workers: number of threads for running
onnxruntime inference on CPU.
If ``None``, onnxruntime chooses the number of threads
auto_install: install missing packages needed to create the object
.. _provider: https://onnxruntime.ai/docs/execution-providers/
Expand Down Expand Up @@ -85,6 +89,7 @@ def load(
auto_install=auto_install,
override_args={
'device': device,
'num_workers': num_workers,
},
)

Expand All @@ -111,6 +116,7 @@ def load(
labels=labels,
transform=transform,
device=device,
num_workers=num_workers,
)

return model
11 changes: 11 additions & 0 deletions audonnx/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Model(audobject.Object):
device: set device
(``'cpu'``, ``'cuda'``, or ``'cuda:<id>'``)
or a (list of) provider(s)_
num_workers: number of threads for running
onnxruntime inference on CPU.
If ``None``, onnxruntime chooses the number of threads
Examples:
>>> import audiofile
Expand Down Expand Up @@ -97,6 +100,7 @@ class Model(audobject.Object):
},
hide=[
'device',
'num_workers',
],
)
def __init__(
Expand All @@ -106,6 +110,7 @@ def __init__(
labels: Labels = None,
transform: Transform = None,
device: Device = 'cpu',
num_workers: typing.Optional[int] = 1,
):
# keep original arguments to store them
# when object is serialized
Expand All @@ -117,9 +122,15 @@ def __init__(
self.path = audeer.path(path) if isinstance(path, str) else None
r"""Model path"""

session_options = onnxruntime.SessionOptions()
if num_workers is not None:
session_options.inter_op_num_threads = num_workers
session_options.intra_op_num_threads = num_workers

providers = device_to_providers(device)
self.sess = onnxruntime.InferenceSession(
self.path if isinstance(path, str) else path.SerializeToString(),
sess_options=session_options,
providers=providers,
)
r"""Inference session"""
Expand Down
5 changes: 5 additions & 0 deletions audonnx/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def create_model(
dtype: int = onnx.TensorProto.FLOAT,
opset_version: int = 14,
device: Device = 'cpu',
num_workers: typing.Optional[int] = 1,
) -> Model:
r"""Create test model.
Expand All @@ -36,6 +37,9 @@ def create_model(
device: set device
(``'cpu'``, ``'cuda'``, or ``'cuda:<id>'``)
or a (list of) provider(s)_
num_workers: number of threads for running
onnxruntime inference on CPU.
If ``None``, onnxruntime chooses the number of threads
Returns:
model object
Expand Down Expand Up @@ -99,6 +103,7 @@ def create_model(
object,
transform=transform,
device=device,
num_workers=num_workers,
)

return model
Expand Down
23 changes: 23 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,29 @@ def test_call_concat(model, outputs, expected):
np.testing.assert_equal(y, expected)


@pytest.mark.parametrize(
    'device, num_workers',
    [
        ('cpu', None),
        ('cuda:0', None),
        ('cpu', 1),
        ('cuda:0', 1),
        ('cpu', 2),
        ('cuda:0', 2),
    ]
)
def test_call_num_workers(device, num_workers):
    r"""Inference output must not depend on the threading configuration."""
    model = audonnx.testing.create_model(
        [[2]],
        device=device,
        num_workers=num_workers,
    )
    result = model(
        pytest.SIGNAL,
        pytest.SAMPLING_RATE,
    )
    # The test model produces an all-zero output of shape (2,)
    np.testing.assert_equal(result, np.zeros(2, dtype=np.float32))


@pytest.mark.parametrize(
'device',
[
Expand Down

0 comments on commit a3a2c49

Please sign in to comment.