Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zarr compression tests only with versions before 3.0 #8319

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ repos:
)$

- repo: https://github.com/hadialqattan/pycln
rev: v2.4.0
rev: v2.5.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
5 changes: 5 additions & 0 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,3 +607,8 @@ def print_verbose(self) -> None:
print(self)
if self.meta is not None:
print(self.meta.__repr__())


# needed in later versions of Pytorch to indicate the class is safe for serialisation
if hasattr(torch.serialization, "add_safe_globals"):
torch.serialization.add_safe_globals([MetaTensor])
2 changes: 1 addition & 1 deletion monai/utils/jupyter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def plot_engine_status(


def _get_loss_from_output(
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
) -> torch.Tensor:
"""Returns a single value from the network output, which is a dict or tensor."""

Expand Down
4 changes: 2 additions & 2 deletions monai/visualize/img2tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ def _image3_animated_gif(
img_str = b""
for b_data in PIL.GifImagePlugin.getheader(ims[0])[0]:
img_str += b_data
img_str += b"\x21\xFF\x0B\x4E\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2E\x30\x03\x01\x00\x00\x00"
img_str += b"\x21\xff\x0b\x4e\x45\x54\x53\x43\x41\x50" b"\x45\x32\x2e\x30\x03\x01\x00\x00\x00"
for i in ims:
for b_data in PIL.GifImagePlugin.getdata(i):
img_str += b_data
img_str += b"\x3B"
img_str += b"\x3b"

summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pep8-naming
pycodestyle
pyflakes
black>=22.12
isort>=5.1
isort>=5.1, <6.0
ruff
pytype>=2020.6.1; platform_system != "Windows"
types-setuptools
Expand Down
45 changes: 23 additions & 22 deletions tests/test_zarr_avg_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,32 +260,33 @@
TENSOR_4x4,
]

ALL_TESTS = [
TEST_CASE_0_DEFAULT_DTYPE,
TEST_CASE_1_DEFAULT_DTYPE,
TEST_CASE_2_DEFAULT_DTYPE,
TEST_CASE_3_DEFAULT_DTYPE,
TEST_CASE_4_DEFAULT_DTYPE,
TEST_CASE_5_VALUE_DTYPE,
TEST_CASE_6_COUNT_DTYPE,
TEST_CASE_7_COUNT_VALUE_DTYPE,
TEST_CASE_8_DTYPE,
TEST_CASE_9_LARGER_SHAPE,
TEST_CASE_10_DIRECTORY_STORE,
TEST_CASE_11_MEMORY_STORE,
TEST_CASE_12_CHUNKS,
TEST_CASE_16_WITH_LOCK,
TEST_CASE_17_WITHOUT_LOCK,
]

# add compression tests only when using Zarr version before 3.0
if not version_geq(get_package_version("zarr"), "3.0.0"):
ALL_TESTS += [TEST_CASE_13_COMPRESSOR_LZ4, TEST_CASE_14_COMPRESSOR_PICKLE, TEST_CASE_15_COMPRESSOR_LZMA]


@unittest.skipUnless(has_zarr and has_numcodecs, "Requires zarr (and numcodecs) packages.)")
class ZarrAvgMergerTests(unittest.TestCase):

@parameterized.expand(
[
TEST_CASE_0_DEFAULT_DTYPE,
TEST_CASE_1_DEFAULT_DTYPE,
TEST_CASE_2_DEFAULT_DTYPE,
TEST_CASE_3_DEFAULT_DTYPE,
TEST_CASE_4_DEFAULT_DTYPE,
TEST_CASE_5_VALUE_DTYPE,
TEST_CASE_6_COUNT_DTYPE,
TEST_CASE_7_COUNT_VALUE_DTYPE,
TEST_CASE_8_DTYPE,
TEST_CASE_9_LARGER_SHAPE,
TEST_CASE_10_DIRECTORY_STORE,
TEST_CASE_11_MEMORY_STORE,
TEST_CASE_12_CHUNKS,
TEST_CASE_13_COMPRESSOR_LZ4,
TEST_CASE_14_COMPRESSOR_PICKLE,
TEST_CASE_15_COMPRESSOR_LZMA,
TEST_CASE_16_WITH_LOCK,
TEST_CASE_17_WITHOUT_LOCK,
]
)
@parameterized.expand(ALL_TESTS)
def test_zarr_avg_merger_patches(self, arguments, patch_locations, expected):
codec_reg = numcodecs.registry.codec_registry
if "compressor" in arguments:
Expand Down
Loading