diff --git a/audonnx/core/api.py b/audonnx/core/api.py
index 1ddca1d..fdd0085 100644
--- a/audonnx/core/api.py
+++ b/audonnx/core/api.py
@@ -20,6 +20,7 @@ def load(
         typing.Tuple[str, typing.Dict],
         typing.Sequence[typing.Union[str, typing.Tuple[str, typing.Dict]]],
     ] = 'cpu',
+    num_workers: typing.Optional[int] = 1,
     auto_install: bool = False,
 ) -> Model:
     r"""Load model from folder.
@@ -44,6 +45,9 @@ def load(
         device: set device
             (``'cpu'``, ``'cuda'``, or ``'cuda:<id>'``)
             or a (list of) provider_
+        num_workers: number of threads for running
+            onnxruntime inference on CPU.
+            If ``None``, onnxruntime chooses the number of threads
         auto_install: install missing packages needed to create the object

    .. _provider: https://onnxruntime.ai/docs/execution-providers/
@@ -85,6 +89,7 @@ def load(
             auto_install=auto_install,
             override_args={
                 'device': device,
+                'num_workers': num_workers,
             },
         )

@@ -111,6 +116,7 @@ def load(
         labels=labels,
         transform=transform,
         device=device,
+        num_workers=num_workers,
     )

     return model
diff --git a/audonnx/core/model.py b/audonnx/core/model.py
index 617bf3b..3d744d0 100644
--- a/audonnx/core/model.py
+++ b/audonnx/core/model.py
@@ -50,6 +50,9 @@ class Model(audobject.Object):
         device: set device
             (``'cpu'``, ``'cuda'``, or ``'cuda:<id>'``)
             or a (list of) provider(s)_
+        num_workers: number of threads for running
+            onnxruntime inference on CPU.
+            If ``None``, onnxruntime chooses the number of threads

     Examples:
         >>> import audiofile
@@ -97,6 +100,7 @@
         },
         hide=[
             'device',
+            'num_workers',
         ],
     )
     def __init__(
@@ -106,6 +110,7 @@ def __init__(
         labels: Labels = None,
         transform: Transform = None,
         device: Device = 'cpu',
+        num_workers: typing.Optional[int] = 1,
     ):
         # keep original arguments to store them
         # when object is serialized
@@ -117,9 +122,15 @@ def __init__(
         self.path = audeer.path(path) if isinstance(path, str) else None
         r"""Model path"""

+        session_options = onnxruntime.SessionOptions()
+        if num_workers is not None:
+            session_options.inter_op_num_threads = num_workers
+            session_options.intra_op_num_threads = num_workers
+
         providers = device_to_providers(device)
         self.sess = onnxruntime.InferenceSession(
             self.path if isinstance(path, str) else path.SerializeToString(),
+            sess_options=session_options,
             providers=providers,
         )
         r"""Inference session"""
diff --git a/audonnx/core/testing.py b/audonnx/core/testing.py
index a9c783d..e377982 100644
--- a/audonnx/core/testing.py
+++ b/audonnx/core/testing.py
@@ -15,6 +15,7 @@ def create_model(
     dtype: int = onnx.TensorProto.FLOAT,
     opset_version: int = 14,
     device: Device = 'cpu',
+    num_workers: typing.Optional[int] = 1,
 ) -> Model:
     r"""Create test model.

@@ -36,6 +37,9 @@ def create_model(
         device: set device
             (``'cpu'``, ``'cuda'``, or ``'cuda:<id>'``)
             or a (list of) provider(s)_
+        num_workers: number of threads for running
+            onnxruntime inference on CPU.
+            If ``None``, onnxruntime chooses the number of threads

     Returns:
         model object
@@ -99,6 +103,7 @@ def create_model(
         object,
         transform=transform,
         device=device,
+        num_workers=num_workers,
     )

     return model
diff --git a/tests/test_model.py b/tests/test_model.py
index 6b657cb..61f33ad 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -228,6 +228,29 @@ def test_call_concat(model, outputs, expected):
     np.testing.assert_equal(y, expected)


+@pytest.mark.parametrize(
+    'device, num_workers',
+    [
+        ('cpu', None),
+        ('cuda:0', None),
+        ('cpu', 1),
+        ('cuda:0', 1),
+        ('cpu', 2),
+        ('cuda:0', 2),
+    ],
+)
+def test_call_num_workers(device, num_workers):
+    model = audonnx.testing.create_model(
+        [[2]], device=device, num_workers=num_workers
+    )
+    y = model(
+        pytest.SIGNAL,
+        pytest.SAMPLING_RATE,
+    )
+    expected = np.array([0.0, 0.0], np.float32)
+    np.testing.assert_equal(y, expected)
+
+
 @pytest.mark.parametrize(
     'device',
     [
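---

Usage sketch (reviewer note, not part of the patch): how the new
num_workers argument introduced above would be used. The model folder
'./model' and the dummy signal are illustrative placeholders, not taken
from this diff.

    import numpy as np

    import audonnx

    # Pin onnxruntime CPU inference to 4 threads (this patch sets both
    # the inter-op and intra-op thread pools to num_workers);
    # num_workers=None would let onnxruntime pick the thread count.
    model = audonnx.load('./model', num_workers=4)

    # Run inference on 1 s of silence at 16 kHz (placeholder input).
    signal = np.zeros((1, 16000), dtype=np.float32)
    output = model(signal, 16000)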