Skip to content

Commit

Permalink
Make MSv2Array transform a property
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Feb 5, 2025
1 parent a54a2af commit c5c1e61
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
15 changes: 11 additions & 4 deletions xarray_ms/backend/msv2/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from xarray_ms.backend.msv2.structure import MSv2StructureFactory, PartitionKeyT
from xarray_ms.backend.msv2.table_factory import TableFactory

TransformerT = Callable[[npt.NDArray], npt.NDArray] | None


def slice_length(s, max_len):
if isinstance(s, np.ndarray):
Expand All @@ -35,7 +37,7 @@ class MSv2Array(BackendArray):
_shape: Tuple[int, ...]
_dtype: npt.DTypeLike
_default: Any | None
_transform: Callable[[npt.NDArray], npt.NDArray] | None
_transform: TransformerT

def __init__(
self,
Expand All @@ -46,7 +48,7 @@ def __init__(
shape: Tuple[int, ...],
dtype: npt.DTypeLike,
default: Any | None = None,
transform: Callable[[npt.NDArray], npt.NDArray] | None = None,
transform: TransformerT = None,
):
self._table_factory = table_factory
self._structure_factory = structure_factory
Expand Down Expand Up @@ -76,5 +78,10 @@ def _getitem(self, key) -> npt.NDArray:
result = result.reshape(rows.shape + expected_shape[2:])
return self._transform(result) if self._transform else result

def set_transform(self, transform: Callable[[npt.NDArray], npt.NDArray]):
self._transform = transform
@property
def transform(self) -> TransformerT:
return self._transform

@transform.setter
def transform(self, value: TransformerT):
self._transform = value
4 changes: 2 additions & 2 deletions xarray_ms/backend/msv2/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
attrs.pop("format", None)

if isinstance(data, MSv2Array):
data.set_transform(UTCCoder.encode_array)
data.transform = UTCCoder.encode_array
elif isinstance(data, np.ndarray):
data = UTCCoder.encode_array(data)
else:
Expand All @@ -162,7 +162,7 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
attrs["format"] = "unix"

if isinstance(data, MSv2Array):
data.set_transform(UTCCoder.decode_array)
data.transform = UTCCoder.decode_array
elif isinstance(data, np.ndarray):
data = UTCCoder.decode_array(data)
else:
Expand Down

0 comments on commit c5c1e61

Please sign in to comment.