diff --git a/README.md b/README.md
index d37742f..e43df23 100644
--- a/README.md
+++ b/README.md
@@ -66,6 +66,8 @@ zarr.config.set({
 })
 ```
 
+If the `ZarrsCodecPipeline` is pickled and then un-pickled, and one of `store_empty_chunks`, `chunk_concurrent_minimum`, `chunk_concurrent_maximum`, or `num_threads` has changed in the interim, the newly un-pickled version will pick up the new value. However, once a `ZarrsCodecPipeline` object has been instantiated, these values are fixed. This may change in the future as guidance from the `zarr` community becomes clear.
+
 ## Concurrency
 
 Concurrency can be classified into two types:
diff --git a/python/zarrs/pipeline.py b/python/zarrs/pipeline.py
index 6c5912d..f6552ba 100644
--- a/python/zarrs/pipeline.py
+++ b/python/zarrs/pipeline.py
@@ -3,7 +3,7 @@
 import asyncio
 import json
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, TypedDict
 
 import numpy as np
 from zarr.abc.codec import (
@@ -14,7 +14,7 @@
 
 if TYPE_CHECKING:
     from collections.abc import Iterable, Iterator
-    from typing import Self
+    from typing import Any, Self
 
     from zarr.abc.store import ByteGetter, ByteSetter
     from zarr.core.array_spec import ArraySpec
@@ -32,10 +32,40 @@
 )
 
 
-@dataclass(frozen=True)
+def get_codec_pipeline_impl(codec_metadata_json: str) -> CodecPipelineImpl:
+    return CodecPipelineImpl(
+        codec_metadata_json,
+        validate_checksums=config.get("codec_pipeline.validate_checksums", None),
+        # TODO: upstream zarr-python array.write_empty_chunks is not merged yet #2429
+        store_empty_chunks=config.get("array.write_empty_chunks", None),
+        chunk_concurrent_minimum=config.get(
+            "codec_pipeline.chunk_concurrent_minimum", None
+        ),
+        chunk_concurrent_maximum=config.get(
+            "codec_pipeline.chunk_concurrent_maximum", None
+        ),
+        num_threads=config.get("threading.max_workers", None),
+    )
+
+
+class ZarrsCodecPipelineState(TypedDict):
+    codec_metadata_json: str
+    codecs: tuple[Codec, ...]
+
+
+@dataclass
 class ZarrsCodecPipeline(CodecPipeline):
     codecs: tuple[Codec, ...]
     impl: CodecPipelineImpl
+    codec_metadata_json: str
+
+    def __getstate__(self) -> ZarrsCodecPipelineState:
+        return {"codec_metadata_json": self.codec_metadata_json, "codecs": self.codecs}
+
+    def __setstate__(self, state: ZarrsCodecPipelineState):
+        self.codecs = state["codecs"]
+        self.codec_metadata_json = state["codec_metadata_json"]
+        self.impl = get_codec_pipeline_impl(self.codec_metadata_json)
 
     def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
         raise NotImplementedError("evolve_from_array_spec")
@@ -49,22 +79,9 @@ def from_codecs(cls, codecs: Iterable[Codec]) -> Self:
         # https://github.com/zarr-developers/zarr-python/issues/2409
         # https://github.com/zarr-developers/zarr-python/pull/2429
         return cls(
+            codec_metadata_json=codec_metadata_json,
             codecs=tuple(codecs),
-            impl=CodecPipelineImpl(
-                codec_metadata_json,
-                validate_checksums=config.get(
-                    "codec_pipeline.validate_checksums", None
-                ),
-                # TODO: upstream zarr-python array.write_empty_chunks is not merged yet #2429
-                store_empty_chunks=config.get("array.write_empty_chunks", None),
-                chunk_concurrent_minimum=config.get(
-                    "codec_pipeline.chunk_concurrent_minimum", None
-                ),
-                chunk_concurrent_maximum=config.get(
-                    "codec_pipeline.chunk_concurrent_maximum", None
-                ),
-                num_threads=config.get("threading.max_workers", None),
-            ),
+            impl=get_codec_pipeline_impl(codec_metadata_json),
         )
 
     @property
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index a4270f5..1af6f0f 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import operator
+import pickle
 import tempfile
 from collections.abc import Callable
 from contextlib import contextmanager
@@ -229,3 +230,13 @@ def test_ellipsis_indexing_invalid(arr: zarr.Array):
     # zarrs-python error: ValueError: operands could not be broadcast together with shapes (4,) (3,)
     # numpy error: ValueError: could not broadcast input array from shape (3,) into shape (4,)
     arr[2, ...] = stored_value
+
+
+def test_pickle(arr: zarr.Array, tmp_path: Path):
+    arr[:] = np.arange(reduce(operator.mul, arr.shape, 1)).reshape(arr.shape)
+    expected = arr[:]
+    with Path.open(tmp_path / "arr.pickle", "wb") as f:
+        pickle.dump(arr._async_array.codec_pipeline, f)
+    with Path.open(tmp_path / "arr.pickle", "rb") as f:
+        object.__setattr__(arr._async_array, "codec_pipeline", pickle.load(f))
+    assert (arr[:] == expected).all()
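
For context, here is a minimal sketch of the round-trip behaviour the README hunk describes. It assumes `zarrs` is installed and registered as the codec pipeline; the store path, array shape, and config values below are placeholders, not part of this change:

```python
import pickle

import numpy as np
import zarr

# Route codec work through zarrs, as in the README configuration example.
zarr.config.set({"codec_pipeline": {"path": "zarrs.ZarrsCodecPipeline"}})

arr = zarr.open("example.zarr", mode="w", shape=(8, 8), chunks=(4, 4), dtype="i4")
arr[:] = np.arange(64).reshape(8, 8)

# __getstate__ keeps only (codec_metadata_json, codecs); the Rust-side
# CodecPipelineImpl is not included in the pickle.
state = pickle.dumps(arr._async_array.codec_pipeline)

# A config change made between dump and load is picked up, because
# __setstate__ rebuilds the impl via get_codec_pipeline_impl(), which reads
# the current zarr config. Once rebuilt, the values are fixed again.
zarr.config.set({"threading.max_workers": 2})
restored = pickle.loads(state)
```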