Skip to content

Commit

Permalink
Merge pull request bioimage-io#266 from bioimage-io/image-helper
Browse files Browse the repository at this point in the history
Refactor image functionality
  • Loading branch information
constantinpape authored May 7, 2022
2 parents 7af0e0f + 94bcd3f commit f241d3a
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 135 deletions.
180 changes: 180 additions & 0 deletions bioimageio/core/image_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import os
from copy import deepcopy
from typing import Dict, List, Optional, Sequence, Tuple, Union

import imageio
import numpy as np
from xarray import DataArray
from bioimageio.core.resource_io.nodes import InputTensor, OutputTensor


#
# helper functions to transform input images / output tensors to the required axes
#


def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None):
"""Transform input image into output tensor with desired axes.
Args:
image: the input image
tensor_axes: the desired tensor axes
input_axes: the axes of the input image (optional)
"""
# if the image axes are not given deduce them from the required axes and image shape
if image_axes is None:
has_z_axis = "z" in tensor_axes
ndim = image.ndim
if ndim == 2:
image_axes = "yx"
elif ndim == 3:
image_axes = "zyx" if has_z_axis else "cyx"
elif ndim == 4:
image_axes = "czyx"
elif ndim == 5:
image_axes = "bczyx"
else:
raise ValueError(f"Invalid number of image dimensions: {ndim}")
tensor = DataArray(image, dims=tuple(image_axes))
# expand the missing image axes
missing_axes = tuple(set(tensor_axes) - set(image_axes))
tensor = tensor.expand_dims(dim=missing_axes)
# transpose to the correct axis order
tensor = tensor.transpose(*tuple(tensor_axes))
# return numpy array
return tensor.values


def _drop_axis_default(axis_name, axis_len):
# spatial axes: drop at middle coordnate
# other axes (channel or batch): drop at 0 coordinate
return axis_len // 2 if axis_name in "zyx" else 0


def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default):
"""Transform output tensor into image with desired axes.
Args:
tensor: the output tensor
tensor_axes: bioimageio model spec
output_axes: the desired output axes
drop_function: function that determines how to drop unwanted axes
"""
if len(tensor_axes) != tensor.ndim:
raise ValueError(f"Number of axes {len(tensor_axes)} and dimension of tensor {tensor.ndim} don't match")
shape = {ax_name: sh for ax_name, sh in zip(tensor_axes, tensor.shape)}
output = DataArray(tensor, dims=tuple(tensor_axes))
# drop unwanted axes
drop_axis_names = tuple(set(tensor_axes) - set(output_axes))
drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names}
output = output[drop_axes]
# transpose to the desired axis order
output = output.transpose(*tuple(output_axes))
# return numpy array
return output.values


def to_channel_last(image):
chan_id = image.dims.index("c")
if chan_id != image.ndim - 1:
target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",)
image = image.transpose(*target_axes)
return image


#
# helper functions for loading and saving images
#


def load_image(in_path, axes: Sequence[str]) -> DataArray:
ext = os.path.splitext(in_path)[1]
if ext == ".npy":
im = np.load(in_path)
else:
is_volume = "z" in axes
im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)
im = transform_input_image(im, axes)
return DataArray(im, dims=axes)


def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]:
return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]


def save_image(out_path, image):
ext = os.path.splitext(out_path)[1]
if ext == ".npy":
np.save(out_path, image)
else:
is_volume = "z" in image.dims

# squeeze batch or channel axes if they are singletons
squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
image = image[squeeze]

if "b" in image.dims:
raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
if "c" in image.dims: # image formats need channel last
image = to_channel_last(image)

save_function = imageio.volsave if is_volume else imageio.imsave
# most image formats only support channel dimensions of 1, 3 or 4;
# if not we need to save the channels separately
ndim = 3 if is_volume else 2
save_as_single_image = image.ndim == ndim or (image.shape[-1] in (3, 4))

if save_as_single_image:
save_function(out_path, image)
else:
out_prefix, ext = os.path.splitext(out_path)
for c in range(image.shape[-1]):
chan_out_path = f"{out_prefix}-c{c}{ext}"
save_function(chan_out_path, image[..., c])


#
# helper function for padding
#


def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]:
assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}"

padding_ = deepcopy(padding)
mode = padding_.pop("mode", "dynamic")
assert mode in ("dynamic", "fixed")

is_volume = "z" in axes
if is_volume:
assert len(padding_) == 3
else:
assert len(padding_) == 2

if isinstance(pad_right, bool):
pad_right = len(axes) * [pad_right]

pad_width = []
crop = {}
for ax, dlen, pr in zip(axes, image.shape, pad_right):

if ax in "zyx":
pad_to = padding_[ax]

if mode == "dynamic":
r = dlen % pad_to
pwidth = 0 if r == 0 else (pad_to - r)
else:
if pad_to < dlen:
msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}."
raise RuntimeError(msg)
pwidth = pad_to - dlen

pad_width.append([0, pwidth] if pr else [pwidth, 0])
crop[ax] = slice(0, dlen) if pr else slice(pwidth, None)
else:
pad_width.append([0, 0])
crop[ax] = slice(None)

image = np.pad(image, pad_width, mode="symmetric")
return image, crop
139 changes: 5 additions & 134 deletions bioimageio/core/prediction.py
Original file line number Diff line number Diff line change
@@ -1,150 +1,21 @@
import collections
import os
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union

import imageio
import numpy as np
import xarray as xr
from tqdm import tqdm

from bioimageio.core import image_helper
from bioimageio.core import load_resource_description
from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
from bioimageio.core.resource_io.nodes import ImplicitOutputShape, InputTensor, Model, ResourceDescription, OutputTensor
from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescription
from bioimageio.spec.shared import raw_nodes
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription


#
# utility functions for prediction
#
def _require_axes(im, axes):
is_volume = "z" in axes
# we assume images / volumes are loaded as one of
# yx, yxc, zyxc
if im.ndim == 2:
im_axes = ("y", "x")
elif im.ndim == 3:
im_axes = ("z", "y", "x") if is_volume else ("y", "x", "c")
elif im.ndim == 4:
raise NotImplementedError
else: # ndim >= 5 not implemented
raise RuntimeError

# add singleton channel dimension if not present
if "c" not in im_axes:
im = im[..., None]
im_axes = im_axes + ("c",)

# add singleton batch dim
im = im[None]
im_axes = ("b",) + im_axes

# permute the axes correctly
assert set(axes) == set(im_axes)
axes_permutation = tuple(im_axes.index(ax) for ax in axes)
im = im.transpose(axes_permutation)
return im


def _pad(im, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]:
assert im.ndim == len(axes), f"{im.ndim}, {len(axes)}"

padding_ = deepcopy(padding)
mode = padding_.pop("mode", "dynamic")
assert mode in ("dynamic", "fixed")

is_volume = "z" in axes
if is_volume:
assert len(padding_) == 3
else:
assert len(padding_) == 2

if isinstance(pad_right, bool):
pad_right = len(axes) * [pad_right]

pad_width = []
crop = {}
for ax, dlen, pr in zip(axes, im.shape, pad_right):

if ax in "zyx":
pad_to = padding_[ax]

if mode == "dynamic":
r = dlen % pad_to
pwidth = 0 if r == 0 else (pad_to - r)
else:
if pad_to < dlen:
msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}."
raise RuntimeError(msg)
pwidth = pad_to - dlen

pad_width.append([0, pwidth] if pr else [pwidth, 0])
crop[ax] = slice(0, dlen) if pr else slice(pwidth, None)
else:
pad_width.append([0, 0])
crop[ax] = slice(None)

im = np.pad(im, pad_width, mode="symmetric")
return im, crop


def _load_image(in_path, axes: Sequence[str]) -> xr.DataArray:
ext = os.path.splitext(in_path)[1]
if ext == ".npy":
im = np.load(in_path)
else:
is_volume = "z" in axes
im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)
im = _require_axes(im, axes)
return xr.DataArray(im, dims=axes)


def _load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[xr.DataArray]:
return [_load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]


def _to_channel_last(image):
chan_id = image.dims.index("c")
if chan_id != image.ndim - 1:
target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",)
image = image.transpose(*target_axes)
return image


def _save_image(out_path, image):
ext = os.path.splitext(out_path)[1]
if ext == ".npy":
np.save(out_path, image)
else:
is_volume = "z" in image.dims

# squeeze batch or channel axes if they are singletons
squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)}
image = image[squeeze]

if "b" in image.dims:
raise RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file")
if "c" in image.dims: # image formats need channel last
image = _to_channel_last(image)

save_function = imageio.volsave if is_volume else imageio.imsave
# most image formats only support channel dimensions of 1, 3 or 4;
# if not we need to save the channels separately
ndim = 3 if is_volume else 2
save_as_single_image = image.ndim == ndim or (image.shape[-1] in (3, 4))

if save_as_single_image:
save_function(out_path, image)
else:
out_prefix, ext = os.path.splitext(out_path)
for c in range(image.shape[-1]):
chan_out_path = f"{out_prefix}-c{c}{ext}"
save_function(chan_out_path, image[..., c])


def _apply_crop(data, crop):
crop = tuple(crop[ax] for ax in data.dims)
return data[crop]
Expand Down Expand Up @@ -345,7 +216,7 @@ def predict_with_padding(
assert len(padding) == len(prediction_pipeline.input_specs)
inputs, crops = zip(
*[
_pad(inp, spec.axes, p, pad_right=pad_right)
image_helper.pad(inp, spec.axes, p, pad_right=pad_right)
for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding)
]
)
Expand Down Expand Up @@ -508,7 +379,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
if padding and tiling:
raise ValueError("Only one of padding or tiling is supported")

input_data = _load_tensors(inputs, prediction_pipeline.input_specs)
input_data = image_helper.load_tensors(inputs, prediction_pipeline.input_specs)
if padding is not None:
result = predict_with_padding(prediction_pipeline, input_data, padding)
elif tiling is not None:
Expand All @@ -519,7 +390,7 @@ def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
assert isinstance(result, list)
assert len(result) == len(outputs)
for res, out in zip(result, outputs):
_save_image(out, res)
image_helper.save_image(out, res)


def predict_image(
Expand Down
2 changes: 1 addition & 1 deletion bioimageio/core/resource_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _validate_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) ->
if ref_tensor not in input_shapes:
raise ValidationError(f"The reference tensor name {ref_tensor} is not in {input_shapes}")
ipt_shape = numpy.array(input_shapes[ref_tensor])
scale = numpy.array(shape_spec.scale)
scale = numpy.array([0.0 if sc is None else sc for sc in shape_spec.scale])
offset = numpy.array(shape_spec.offset)
exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset

Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"unet2d_nuclei_broad/rdf.yaml"
),
"unet2d_expand_output_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"unet2d_nuclei_broad/rdf_expand_output_shape.yaml"
),
"unet2d_fixed_shape": (
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
"unet2d_fixed_shape/rdf.yaml"
Expand Down Expand Up @@ -205,6 +209,12 @@ def unet2d_diff_output_shape(request):
return pytest.model_packages[request.param]


# written as model group to automatically skip on missing torch
@pytest.fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"])
def unet2d_expand_output_shape(request):
return pytest.model_packages[request.param]


# written as model group to automatically skip on missing torch
@pytest.fixture(params=[] if skip_torch else ["unet2d_fixed_shape"])
def unet2d_fixed_shape(request):
Expand Down
Loading

0 comments on commit f241d3a

Please sign in to comment.