From 6905ebf0a7967826a4a068c48b2b74940a98a1f0 Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Fri, 10 Jan 2025 10:39:12 +0900 Subject: [PATCH] [onert/python] Allow calling compilation by infer session (#14530) This commit allows calling compilation(prepare) by infer session. - Introduce recreating internal session(nnfw_session) into BaseSession - Introduce compile function that can recreate internal session - Modify __getattr__ of BaseSession strictly - Rename nnpackage_path to path ONE-DCO-1.0-Signed-off-by: ragmani --- .../api/python/package/common/basesession.py | 29 +++++++++++++++++-- .../onert/api/python/package/infer/session.py | 25 ++++++++++++++-- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/runtime/onert/api/python/package/common/basesession.py b/runtime/onert/api/python/package/common/basesession.py index d9f62770474..1521efb452b 100644 --- a/runtime/onert/api/python/package/common/basesession.py +++ b/runtime/onert/api/python/package/common/basesession.py @@ -15,7 +15,7 @@ class BaseSession: """ Base class providing common functionality for inference and training sessions. """ - def __init__(self, backend_session): + def __init__(self, backend_session=None): """ Initialize the BaseSession with a backend session. Args: @@ -33,7 +33,24 @@ def __getattr__(self, name): Returns: The attribute or method from the bound NNFW_SESSION instance. """ - return getattr(self.session, name) + if name in self.__dict__: + # First, try to get the attribute from the instance's own dictionary + return self.__dict__[name] + elif hasattr(self.session, name): + # If not found, delegate to the session object + return getattr(self.session, name) + else: + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'") + + def _recreate_session(self, backend_session): + """ + Protected method to recreate the session. + Subclasses can override this method to provide custom session recreation logic. + """ + if self.session is not None: + del self.session # Clean up the existing session + self.session = backend_session def set_inputs(self, size, inputs_array=[]): """ @@ -42,6 +59,10 @@ def set_inputs(self, size, inputs_array=[]): size (int): Number of input tensors. inputs_array (list): List of numpy arrays for the input data. """ + if self.session is None: + raise ValueError( + "Session is not initialized with a model. Please compile with a model before setting inputs." + ) for i in range(size): input_tensorinfo = self.session.input_tensorinfo(i) @@ -63,6 +84,10 @@ def set_outputs(self, size): Args: size (int): Number of output tensors. """ + if self.session is None: + raise ValueError( + "Session is not initialized with a model. Please compile a model before setting outputs." + ) for i in range(size): output_tensorinfo = self.session.output_tensorinfo(i) output_array = np.zeros((num_elems(output_tensorinfo)), diff --git a/runtime/onert/api/python/package/infer/session.py b/runtime/onert/api/python/package/infer/session.py index da59c4ae653..e0ef4f7f8bb 100644 --- a/runtime/onert/api/python/package/infer/session.py +++ b/runtime/onert/api/python/package/infer/session.py @@ -6,14 +6,33 @@ class session(BaseSession): """ Class for inference using nnfw_session. """ - def __init__(self, nnpackage_path, backends="cpu"): + def __init__(self, path: str = None, backends: str = "cpu"): """ Initialize the inference session. Args: - nnpackage_path (str): Path to the nnpackage file or directory. + path (str): Path to the model file or nnpackage directory. backends (str): Backends to use, default is "cpu". """ - super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends)) + if path is not None: + super().__init__(libnnfw_api_pybind.infer.nnfw_session(path, backends)) + self.session.prepare() + self.set_outputs(self.session.output_size()) + else: + super().__init__() + + def compile(self, path: str, backends: str = "cpu"): + """ + Prepare the session by recreating it with new parameters. + Args: + path (str): Path to the model file or nnpackage directory. Defaults to the existing path. + backends (str): Backends to use. Defaults to the existing backends. + """ + # Update parameters if provided + if path is None: + raise ValueError("path must not be None.") + # Recreate the session with updated parameters + self._recreate_session(libnnfw_api_pybind.infer.nnfw_session(path, backends)) + # Prepare the new session self.session.prepare() self.set_outputs(self.session.output_size())