Skip to content

Commit

Permalink
[onert/python] Allow calling compilation by infer session (#14530)
Browse files Browse the repository at this point in the history
This commit allows calling compilation(prepare) by infer session.
  - Introduce recreating internal session(nnfw_session) into BaseSession
  - Introduce compile function that can recreate internal session
  - Modify __getattr__ of BaseSession strictly
  - Rename nnpackage_path to path

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jan 10, 2025
1 parent 863750a commit 6905ebf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 5 deletions.
29 changes: 27 additions & 2 deletions runtime/onert/api/python/package/common/basesession.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class BaseSession:
"""
Base class providing common functionality for inference and training sessions.
"""
def __init__(self, backend_session):
def __init__(self, backend_session=None):
"""
Initialize the BaseSession with a backend session.
Args:
Expand All @@ -33,7 +33,24 @@ def __getattr__(self, name):
Returns:
The attribute or method from the bound NNFW_SESSION instance.
"""
return getattr(self.session, name)
if name in self.__dict__:
# First, try to get the attribute from the instance's own dictionary
return self.__dict__[name]
elif hasattr(self.session, name):
# If not found, delegate to the session object
return getattr(self.session, name)
else:
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'")

def _recreate_session(self, backend_session):
"""
Protected method to recreate the session.
Subclasses can override this method to provide custom session recreation logic.
"""
if self.session is not None:
del self.session # Clean up the existing session
self.session = backend_session

def set_inputs(self, size, inputs_array=[]):
"""
Expand All @@ -42,6 +59,10 @@ def set_inputs(self, size, inputs_array=[]):
size (int): Number of input tensors.
inputs_array (list): List of numpy arrays for the input data.
"""
if self.session is None:
raise ValueError(
"Session is not initialized with a model. Please compile with a model before setting inputs."
)
for i in range(size):
input_tensorinfo = self.session.input_tensorinfo(i)

Expand All @@ -63,6 +84,10 @@ def set_outputs(self, size):
Args:
size (int): Number of output tensors.
"""
if self.session is None:
raise ValueError(
"Session is not initialized with a model. Please compile a model before setting outputs."
)
for i in range(size):
output_tensorinfo = self.session.output_tensorinfo(i)
output_array = np.zeros((num_elems(output_tensorinfo)),
Expand Down
25 changes: 22 additions & 3 deletions runtime/onert/api/python/package/infer/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,33 @@ class session(BaseSession):
"""
Class for inference using nnfw_session.
"""
def __init__(self, nnpackage_path, backends="cpu"):
def __init__(self, path: str = None, backends: str = "cpu"):
"""
Initialize the inference session.
Args:
nnpackage_path (str): Path to the nnpackage file or directory.
path (str): Path to the model file or nnpackage directory.
backends (str): Backends to use, default is "cpu".
"""
super().__init__(libnnfw_api_pybind.infer.nnfw_session(nnpackage_path, backends))
if path is not None:
super().__init__(libnnfw_api_pybind.infer.nnfw_session(path, backends))
self.session.prepare()
self.set_outputs(self.session.output_size())
else:
super().__init__()

def compile(self, path: str, backends: str = "cpu"):
"""
Prepare the session by recreating it with new parameters.
Args:
path (str): Path to the model file or nnpackage directory. Defaults to the existing path.
backends (str): Backends to use. Defaults to the existing backends.
"""
# Update parameters if provided
if path is None:
raise ValueError("path must not be None.")
# Recreate the session with updated parameters
self._recreate_session(libnnfw_api_pybind.infer.nnfw_session(path, backends))
# Prepare the new session
self.session.prepare()
self.set_outputs(self.session.output_size())

Expand Down

0 comments on commit 6905ebf

Please sign in to comment.