Skip to content

Commit

Permalink
Merge pull request #312 from imagej/new-button-axes
Browse files Browse the repository at this point in the history
WIP: New image, v2
  • Loading branch information
gselzer authored Nov 25, 2024
2 parents 3fde6d2 + c380d96 commit 5b3cbc1
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 48 deletions.
15 changes: 12 additions & 3 deletions src/napari_imagej/types/converters/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,27 @@ def _image_layer_to_dataset(image: Image, **kwargs) -> "jc.Dataset":
:param image: a napari Image layer
:return: a Dataset
"""
# Redefine dimension order if necessary
data = image.data
if hasattr(data, "dims"):
if "dim_order" in kwargs:
dim_remapping = {
old: new for old, new in zip(data.dims, kwargs["dim_order"])
}
data = data.rename(dim_remapping)
kwargs.pop("dim_order")
# Define dimension order if necessary
if "dim_order" not in kwargs:
elif "dim_order" not in kwargs:
# NB "dim_i"s will be overwritten later
dim_order = [f"dim_{i}" for i in range(len(image.data.shape))]
dim_order = [f"dim_{i}" for i in range(len(data.shape))]
# if RGB, last dimension is Channel
if image.rgb:
dim_order[-1] = "Channel"

kwargs["dim_order"] = dim_order

# Construct a dataset from the data
dataset: "jc.Dataset" = nij.ij.py.to_dataset(image.data, **kwargs)
dataset: "jc.Dataset" = nij.ij.py.to_dataset(data, **kwargs)

# Clean up the axes
axes = [
Expand Down
103 changes: 87 additions & 16 deletions src/napari_imagej/widgets/parameter_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
They should align with a SciJava ModuleItem that satisfies some set of conditions.
"""

from __future__ import annotations

import importlib
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from imagej.images import _imglib2_types
from jpype import JArray, JClass, JInt, JLong
Expand All @@ -24,13 +26,29 @@
request_values,
)
from napari import current_viewer
from napari.layers import Layer
from napari.layers import Layer, Image
from napari.utils._magicgui import get_layers
from numpy import dtype
from scyjava import numeric_bounds

from napari_imagej.java import jc

if TYPE_CHECKING:
from typing import Sequence


# Generally, Python libraries treat the dimensions i of an array with n
# dimensions as the CONVENTIONAL_DIMS[n][i] axis
# FIXME: Also in widget_utils.py
CONVENTIONAL_DIMS = [
[],
["X"],
["Y", "X"],
["Z", "Y", "X"],
["Time", "Y", "X", "Channel"],
["Time", "Z", "Y", "X", "Channel"],
]


def widget_supported_java_types() -> List[JClass]:
"""
Expand Down Expand Up @@ -147,6 +165,40 @@ def value(self, value: Any):
return Widget


class MutableDimWidget(Container):
def __init__(
self,
val: int = 256,
idx: int = 0,
**kwargs,
):
layer_tooltip = f"Parameters for the dimension of index {idx}"
self.size_spin = SpinBox(value=val, min=1)
choices = ["X", "Y", "Z", "C", "T"]
layer_kwargs = kwargs.copy()
layer_kwargs["nullable"] = True
self.layer_select = ComboBox(
choices=choices, tooltip=layer_tooltip, **layer_kwargs
)
self._nullable = True
kwargs["widgets"] = [self.size_spin, self.layer_select]
kwargs["labels"] = False
kwargs["layout"] = "horizontal"
kwargs.pop("value")
kwargs.pop("nullable")
super().__init__(**kwargs)
self.margins = (0, 0, 0, 0)

@property
def value(self) -> tuple[int, str]:
return (self.size_spin.value, self.layer_select.value)

@value.setter
def value(self, v: tuple[int, str]) -> None:
self.size_spin.value = v[0]
self.layer_select.value = v[1]


class MutableOutputWidget(Container):
"""
A ComboBox widget combined with a button that creates new layers.
Expand Down Expand Up @@ -205,12 +257,23 @@ def _default_layer(self) -> Optional[Layer]:
selection_name = widget.current_choice
if selection_name != "":
return current_viewer().layers[selection_name]
return None

def _default_new_shape(self):
def _default_new_shape(self) -> Sequence[tuple[int, str]]:
guess = self._default_layer()
if guess:
return guess.data.shape
return [512, 512]
data = guess.data
# xarray has dims, otherwise use conventions
if hasattr(data, "dims"):
dims = data.dims
# Special case: RGB
elif isinstance(guess, Image) and guess.rgb:
dims = list(CONVENTIONAL_DIMS[len(data.shape) - 1])
dims.append("C")
else:
dims = list(CONVENTIONAL_DIMS[len(data.shape)])
return [t for t in zip(data.shape, dims)]
return [(512, "Y"), (512, "X")]

def _default_new_type(self) -> str:
"""
Expand All @@ -237,12 +300,10 @@ def create_new_image(self) -> None:
"""

# Array types that are always included
backing_choices = ["NumPy"]
backing_choices = ["xarray", "NumPy"]
# Array types that may be present
if importlib.util.find_spec("zarr"):
backing_choices.append("Zarr")
if importlib.util.find_spec("xarray"):
backing_choices.append("xarray")

# Define the magicgui widget for parameter harvesting
params = request_values(
Expand All @@ -253,16 +314,19 @@ def create_new_image(self) -> None:
options=dict(tooltip="If blank, a name will be generated"),
),
shape=dict(
annotation=List[int],
annotation=List[MutableDimWidget],
value=self._default_new_shape(),
options=dict(
tooltip="By default, the shape of the first Layer input",
options=dict(min=0, max=2**31 - 10),
layout="vertical",
options=dict(
widget_type=MutableDimWidget,
),
tooltip="The size of each image axis",
),
),
array_type=dict(
annotation=str,
value="NumPy",
value="xarray",
options=dict(
tooltip="The backing data array implementation",
choices=backing_choices,
Expand Down Expand Up @@ -295,7 +359,7 @@ def _add_new_image(self, params: dict):
import numpy as np

data = np.full(
shape=tuple(params["shape"]),
shape=tuple(p[0] for p in params["shape"]),
fill_value=params["fill_value"],
dtype=params["data_type"],
)
Expand All @@ -305,7 +369,7 @@ def _add_new_image(self, params: dict):
import zarr

data = zarr.full(
shape=params["shape"],
shape=tuple(p[0] for p in params["shape"]),
fill_value=params["fill_value"],
dtype=params["data_type"],
)
Expand All @@ -315,12 +379,19 @@ def _add_new_image(self, params: dict):
import numpy as np
import xarray

dims = tuple(p[1] for p in params["shape"])
# Ensure every dimension is unique
if len(dims) != len(set(dims)):
# TODO: Ideally this would be prevented in the widget
raise ValueError("Cannot have repeated dimensions")

data = xarray.DataArray(
data=np.full(
shape=tuple(params["shape"]),
shape=tuple(p[0] for p in params["shape"]),
fill_value=params["fill_value"],
dtype=params["data_type"],
)
),
dims=tuple(p[1] for p in params["shape"]),
)

# give the data array to the viewer.
Expand Down
29 changes: 16 additions & 13 deletions src/napari_imagej/widgets/widget_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
info_for,
)

# Generally, Python libraries treat the dimensions i of an array with n
# dimensions as the CONVENTIONAL_DIMS[n][i] axis
CONVENTIONAL_DIMS = [
[],
["X"],
["Y", "X"],
["Z", "Y", "X"],
["Time", "Y", "X", "Channel"],
["Time", "Z", "Y", "X", "Channel"],
]


def python_actions_for(
result: "jc.SearchResult", output_signal: Signal, parent_widget: QWidget = None
Expand Down Expand Up @@ -168,16 +179,6 @@ def __init__(self, title: str, choices: List[Layer], required=True):
class DimsComboBox(QFrame):
"""A QFrame used to map the axes of a Layer to dimension labels"""

# NB: Strings correspond to supported net.imagej.axis.Axes types
dims = [
[],
["X"],
["Y", "X"],
["Z", "Y", "X"],
["Time", "Y", "X", "Channel"],
["Time", "Z", "Y", "X", "Channel"],
]

def __init__(self, combo_box: LayerComboBox):
super().__init__()
self.selection_box: LayerComboBox = combo_box
Expand All @@ -201,11 +202,13 @@ def update(self, index: int):
selected = self.selection_box.combo.itemData(index)
# Guess dimension labels for the selection
ndim = len(selected.data.shape)
if isinstance(selected, Image) and selected.rgb:
guess = list(self.dims[ndim - 1])
if (dims := getattr(selected.data, "dims", None)) is not None:
guess = dims
elif isinstance(selected, Image) and selected.rgb:
guess = list(CONVENTIONAL_DIMS[ndim - 1])
guess.append("Channel")
else:
guess = self.dims[ndim]
guess = CONVENTIONAL_DIMS[ndim]
# Create dimension selectors for each dimension of the selection.
for i, g in enumerate(guess):
self.layout().addWidget(
Expand Down
22 changes: 17 additions & 5 deletions tests/widgets/test_menu.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
QHBoxLayout,
QMessageBox,
)
from xarray import DataArray

from napari_imagej import settings
from napari_imagej.resources import resource_path
Expand Down Expand Up @@ -328,10 +329,12 @@ def test_advanced_data_transfer(
assert not button.isEnabled()

# Add an image to the viewer
sample_data = numpy.ones((100, 100, 3), dtype=numpy.uint8)
sample_data = numpy.ones((50, 100, 3), dtype=numpy.uint8)
# NB: Unnatural data order used for testing in handler
dims = ("X", "Y", "Z")
sample_data = DataArray(data=sample_data, dims=dims)
image: Image = Image(data=sample_data, name="test_to")
current_viewer().add_layer(image)
assert image.rgb
asserter(lambda: button.isEnabled())

# Add some rois to the viewer
Expand Down Expand Up @@ -361,6 +364,9 @@ def handle_transfer(widget: QDialog) -> bool:
if not len(dim_bars) == 3:
print("Expected more dimension comboboxes")
return False
for i, e in enumerate(dims):
if e != dim_bars[i].combo.currentText():
return False

ok_button = widget.buttons.button(QDialogButtonBox.Ok)
ok_button.clicked.emit()
Expand All @@ -375,10 +381,16 @@ def check_active_display():
dataset = ij.display().getActiveDisplay().getActiveView().getData()
if not dataset.getName() == "test_to":
return False
if not dataset.isRGBMerged():
if dataset.getProperties().get("rois") is None:
return False
if dataset.dimension(jc.Axes.X) != 50:
return False
if dataset.dimension(jc.Axes.Y) != 100:
return False
rois = dataset.getProperties().get("rois")
return rois is not None
if dataset.dimension(jc.Axes.Z) != 3:
return False

return True

asserter(check_active_display)

Expand Down
Loading

0 comments on commit 5b3cbc1

Please sign in to comment.