From 2e3fc12f0d827ba816f7673f54f8e01e36f631e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20M=C3=BCller?= <38459088+jo-mueller@users.noreply.github.com> Date: Mon, 13 Nov 2023 17:17:44 +0100 Subject: [PATCH 1/3] passed viewer instance down through classes --- napari_bioimageio/_bmm.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/napari_bioimageio/_bmm.py b/napari_bioimageio/_bmm.py index 3af3e35..1a2c4d4 100644 --- a/napari_bioimageio/_bmm.py +++ b/napari_bioimageio/_bmm.py @@ -208,6 +208,7 @@ def __init__( downloaded, parent: QWidget = None, select_mode=False, + napari_viewer: 'napari.viewer.Viewer' = None ): super().__init__(parent) self.parent = parent @@ -217,6 +218,7 @@ def __init__( self.selected_version = versions[0] self.model_name = model_info["name"] self.model_description = model_info["description"] + self.viewer = napari_viewer nickname_icon = '' if "nickname_icon" in model_info: nickname_icon = model_info["nickname_icon"] @@ -341,11 +343,12 @@ def handle_action(self, model_info, action_name): self.parent.ui_parent.run_thread("select", model_info, self.selected_version) class QtModelList(QListWidget): - def __init__(self, parent, ui_parent, select_mode): + def __init__(self, parent, ui_parent, select_mode, napari_viewer: 'napari.viewer.Viewer' = None): super().__init__(parent) self.ui_parent = ui_parent self.select_mode = select_mode self.setSortingEnabled(True) + self.viewer = napari_viewer def addItem( self, @@ -361,6 +364,7 @@ def addItem( downloaded=downloaded, parent=self, select_mode=self.select_mode, + napari_viewer=self.viewer ) item.widget = widg @@ -368,7 +372,12 @@ def addItem( self.setItemWidget(item, widg) class QtBioImageIOModelManager(QDialog): - def __init__(self, parent=None, filter_id=None, filter_tag=None, select_mode=False): + def __init__(self, + parent=None, + filter_id=None, + filter_tag=None, + select_mode=False, + napari_viewer='napari.viewer.Viewer'): super().__init__(parent) self.setStyleSheet(custom_style) self.models_folder = _utils.get_models_path() @@ -377,6 +386,7 @@ def __init__(self, parent=None, filter_id=None, filter_tag=None, select_mode=Fal self.RUNNING = False self.select_mode = select_mode self.selected = None + self.viewer = napari_viewer self.filter_id = filter_id self.filter_tag = filter_tag self.setup_ui() @@ -572,7 +582,7 @@ def setup_ui(self): mid_layout.addWidget(self.downloaded_label) mid_layout.addStretch() lay.addLayout(mid_layout) - self.downloaded_list = QtModelList(downloaded, self, self.select_mode) + self.downloaded_list = QtModelList(downloaded, self, self.select_mode, napari_viewer=self.viewer) self.downloaded_list.setFixedHeight(250) lay.addWidget(self.downloaded_list) @@ -584,7 +594,7 @@ def setup_ui(self): mid_layout.addWidget(self.avail_label) mid_layout.addStretch() lay.addLayout(mid_layout) - self.available_list = QtModelList(available, self, False) + self.available_list = QtModelList(available, self, False, napari_viewer=self.viewer) self.available_list.setFixedHeight(250) lay.addWidget(self.available_list) @@ -626,8 +636,8 @@ def getvalidation(self): self.run_thread("validate") -def show_model_selector(filter_id=None, filter_tag=None): - d = QtBioImageIOModelManager(filter_id=filter_id, filter_tag=filter_tag, select_mode=True) +def show_model_selector(filter_id=None, filter_tag=None, viewer: 'napari.viewer.Viewer' = None): + d = QtBioImageIOModelManager(filter_id=filter_id, filter_tag=filter_tag, select_mode=True, napari_viewer=viewer) d.setObjectName("QtBioImageIOModelManager") d.setWindowTitle("BioImageIO Model Selector") d.setWindowModality(Qt.ApplicationModal) From 447d6ec82093f829922b85c938bc8d6526aac660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20M=C3=BCller?= <38459088+jo-mueller@users.noreply.github.com> Date: Mon, 13 Nov 2023 17:18:34 +0100 Subject: [PATCH 2/3] added `run_inference` model to run model based on rdf_path --- napari_bioimageio/_inference.py | 50 +++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 napari_bioimageio/_inference.py diff --git a/napari_bioimageio/_inference.py b/napari_bioimageio/_inference.py new file mode 100644 index 0000000..f346662 --- /dev/null +++ b/napari_bioimageio/_inference.py @@ -0,0 +1,50 @@ +def run_inference(image: 'napari.types.ImageData', + rdf_path: str, + halo: int = 16) -> 'napari.types.LayerDataTuple': + """Run inference on a napari image using a bioimage.io model. + + Parameters + ---------- + image : napari.types.ImageData + Image to run inference on. + rdf_path : str + Path to the model RDF file. + + Returns + ------- + napari.types.LayerDataTuple + Inference result. + """ + from bioimageio.core import (load_resource_description, + predict_with_tiling, + create_prediction_pipeline) + import xarray as xr + + model_resource = load_resource_description(rdf_path) + + prediction_pipeline = create_prediction_pipeline( + model_resource, devices=None, weight_format=None + ) + + tiling = {"tile": {"x": model_resource.inputs[0].shape[-1], + "y": model_resource.inputs[0].shape[-2], + "z": model_resource.inputs[0].shape[-3]}, + "halo": {"x": halo, "y": halo}} + + input_array = xr.DataArray( + image, + dims=tuple(model_resource.inputs[0].axes)) + + # run prediction and throw clear error message if dimensions don't match + prediction = predict_with_tiling(prediction_pipeline, + input_array, + tiling=tiling, + verbose=True) + properties = { + 'name': 'prediction', + 'colormap': 'inferno', + 'blending': 'additive', + 'opacity': 0.5 + } + + return (prediction, properties, 'image') From 9dc56ed6fadcf762a84b987cb79362c94f444683 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20M=C3=BCller?= <38459088+jo-mueller@users.noreply.github.com> Date: Mon, 13 Nov 2023 17:18:52 +0100 Subject: [PATCH 3/3] Added run action to possible actions for downloaded models --- napari_bioimageio/_bmm.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/napari_bioimageio/_bmm.py b/napari_bioimageio/_bmm.py index 1a2c4d4..2b6b32f 100644 --- a/napari_bioimageio/_bmm.py +++ b/napari_bioimageio/_bmm.py @@ -3,7 +3,7 @@ import napari.resources from napari._qt.qt_resources import QColoredSVGIcon, get_stylesheet -from qtpy.QtCore import QObject, QSize, Qt, QThread, Signal +from qtpy.QtCore import QObject, QSize, Qt, QThread, Signal, QEvent from qtpy.QtGui import QFont, QMovie from qtpy.QtWidgets import ( QAction, @@ -26,7 +26,8 @@ QMessageBox, ) from superqt import QElidingLabel - +from ._inference import run_inference +from magicgui import magicgui from . import _utils # TODO find a proper way to import style from napari @@ -302,6 +303,10 @@ def change_version(index): inspectAction.triggered.connect(lambda: self.handle_action(self.model_info, "inspect")) action_menu.addAction(inspectAction) + applyAction = QAction('Apply', self) + applyAction.triggered.connect(lambda: self.handle_action(self.model_info, "apply")) + action_menu.addAction(applyAction) + if self.select_mode: selectAction = QAction('Select', self) selectAction.triggered.connect(lambda: self.handle_action(self.model_info, "select")) @@ -341,6 +346,15 @@ def handle_action(self, model_info, action_name): self.parent.ui_parent.run_thread("inspect", model_info, self.selected_version) elif action_name == "select": self.parent.ui_parent.run_thread("select", model_info, self.selected_version) + elif action_name == "apply": + from functools import partial + run_inference_partial = partial( + run_inference, + rdf_path=model_info["id"]) + widget = magicgui(run_inference_partial) + run_inference_partial.__name__ = model_info["name"] + ' predictor' + self.viewer.window.add_dock_widget(widget, area='right') + class QtModelList(QListWidget): def __init__(self, parent, ui_parent, select_mode, napari_viewer: 'napari.viewer.Viewer' = None):