Commit

Further align partitioning code with MSv4
sjperkins committed Feb 4, 2025
1 parent a5e4cf2 commit 07eefd1
Showing 8 changed files with 174 additions and 102 deletions.
4 changes: 2 additions & 2 deletions doc/source/api.rst
@@ -13,10 +13,10 @@ be used to open either a :class:`~xarray.Dataset` or a
 
     >>> dataset = xarray.open_dataset(
             "/data/data.ms",
-            partition_columns=["DATA_DESC_ID", "FIELD_ID"])
+            partition_schema=["DATA_DESC_ID", "FIELD_ID"])
     >>> datatree = xarray.backends.api.open_datatree(
             "/data/data.ms",
-            partition_columns=["DATA_DESC_ID", "FIELD_ID"])
+            partition_schema=["DATA_DESC_ID", "FIELD_ID"])
 
 These methods defer to the relevant methods on the
 `Entrypoint Class <entrypoint-class_>`_.
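For orientation, the renamed keyword is used as follows — a minimal sketch, assuming the ``xarray-ms:msv2`` engine name used elsewhere in this commit and a placeholder path:

    import xarray

    # Open a single MSv2 partition as a Dataset; the keyword is now
    # `partition_schema` (formerly `partition_columns`).
    ds = xarray.open_dataset(
      "/data/data.ms",  # placeholder path
      engine="xarray-ms:msv2",
      partition_schema=["DATA_DESC_ID", "FIELD_ID"],
    )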
16 changes: 8 additions & 8 deletions doc/source/tutorial.rst
@@ -13,10 +13,10 @@ collection of Datasets of ndarrays on a regular grid.
 To move data between the two formats, it is necessary to partition
 or group MSv2 rows by the same shape and configuration.
 
-In xarray-ms, this is accomplished by specifying ``partition_columns``
+In xarray-ms, this is accomplished by specifying ``partition_schema``
 when opening a Measurement Set.
-Different columns may be used to define the partition, but
-:code:`[DATA_DESC_ID, FIELD_ID, OBSERVATION_ID]` is a reasonable choice.
+Different columns may be used to define the partition.
+:code:`[DATA_DESC_ID, OBS_MODE, OBSERVATION_ID]` is the default.
 
 Opening a Measurement Set
 -------------------------
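Since ``partition_schema`` now defaults to ``[DATA_DESC_ID, OBS_MODE, OBSERVATION_ID]``, omitting the argument should be equivalent to passing that list explicitly — a sketch under that assumption, with a placeholder path:

    from xarray.backends.api import open_datatree

    # Both calls should produce identically partitioned trees.
    dt_default = open_datatree("/data/data.ms")
    dt_explicit = open_datatree(
      "/data/data.ms",
      partition_schema=["DATA_DESC_ID", "OBS_MODE", "OBSERVATION_ID"],
    )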
@@ -39,7 +39,7 @@ to open multiple partitions of a Measurement Set.
             (8, ("XX", "XY", "YX", "YY")),
             (4, ("RR", "LL"))])
 
-    dt = open_datatree(ms, partition_columns=[
+    dt = open_datatree(ms, partition_schema=[
         "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"])
     dt
@@ -62,7 +62,7 @@ For example, one could select some specific dimensions out:
 .. ipython:: python
 
     dt = open_datatree(ms,
-        partition_columns=["DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"])
+        partition_schema=["DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"])
     subdt = dt.isel(time=slice(1, 3), baseline_id=[1, 3, 5], frequency=slice(2, 4))
     subdt
@@ -92,7 +92,7 @@ can be enabled by specifying the ``chunks`` parameter:
 
 .. ipython:: python
 
-    dt = open_datatree(ms, partition_columns=[
+    dt = open_datatree(ms, partition_schema=[
         "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
         chunks={"time": 2, "frequency": 2})
@@ -108,7 +108,7 @@ to specify different chunking setups for each partition.
 
 .. ipython:: python
 
-    dt = open_datatree(ms, partition_columns=[
+    dt = open_datatree(ms, partition_schema=[
         "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
        chunks={},
         preferred_chunks={
@@ -137,7 +137,7 @@ this to a zarr_ store.
     import os.path
     import tempfile
 
-    dt = open_datatree(ms, partition_columns=[
+    dt = open_datatree(ms, partition_schema=[
         "DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
         chunks={},
         preferred_chunks={
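The truncated ``preferred_chunks`` examples above map partition keys to per-partition chunk dicts. A hedged sketch of that shape — the key tuples are assumed to mirror the resolved partitions, and the concrete values here are illustrative:

    from xarray.backends.api import open_datatree

    dt = open_datatree(
      "/data/data.ms",  # placeholder path
      partition_schema=["DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
      chunks={},  # enable lazy (dask-backed) arrays
      preferred_chunks={
        # assumed key structure: ((column, value), ...) per partition
        (("DATA_DESC_ID", 0), ("FIELD_ID", 0), ("OBSERVATION_ID", 0)): {"time": 2},
        (("DATA_DESC_ID", 1), ("FIELD_ID", 0), ("OBSERVATION_ID", 0)): {"frequency": 2},
      },
    )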
18 changes: 9 additions & 9 deletions tests/test_backend.py
@@ -119,13 +119,13 @@ def test_open_dataset(simmed_ms):
   ids=id_string,
 )
 @pytest.mark.parametrize(
-  "partition_columns", [["DATA_DESC_ID", "OBSERVATION_ID", "FIELD_ID"]]
+  "partition_schema", [["DATA_DESC_ID", "OBSERVATION_ID", "FIELD_ID"]]
 )
 def test_open_dataset_partition_keys(
-  simmed_ms, partition_columns, partition_key, pols, nfreq
+  simmed_ms, partition_schema, partition_key, pols, nfreq
 ):
   ds = xarray.open_dataset(
-    simmed_ms, partition_columns=partition_columns, partition_key=partition_key
+    simmed_ms, partition_schema=partition_schema, partition_key=partition_key
   )
   assert_array_equal(ds.polarization.values, pols)
   assert {("frequency", nfreq), ("polarization", len(pols))}.issubset(ds.sizes.items())
@@ -150,18 +150,18 @@ def test_open_datatree(simmed_ms):
   with ExitStack() as stack:
     mem_dt = open_datatree(simmed_ms)
     mem_dt.load()
-    for ds in mem_dt.subtree:
-      if "version" in ds.attrs:
-        ds.attrs.pop("creation_date", None)
-        assert isinstance(ds.VISIBILITY.data, np.ndarray)
+    for node in mem_dt.subtree:
+      if node.attrs.get("type") == "visibility":
+        node.attrs.pop("creation_date", None)
+        assert isinstance(node.VISIBILITY.data, np.ndarray)
 
   chunks = {"time": 2, "frequency": 2}
 
   # Works with default dask scheduler
   with ExitStack() as stack:
     dt = open_datatree(simmed_ms, preferred_chunks=chunks)
     for node in dt.subtree:
-      if "version" in node.attrs:
+      if node.attrs.get("type") == "visibility":
         del node.attrs["creation_date"]
     xt.assert_identical(dt, mem_dt)
 
@@ -171,7 +171,7 @@ def test_open_datatree(simmed_ms):
     stack.enter_context(Client(cluster))
     dt = open_datatree(simmed_ms, preferred_chunks=chunks)
     for node in dt.subtree:
-      if "version" in node.attrs:
+      if node.attrs.get("type") == "visibility":
         del node.attrs["creation_date"]
     xt.assert_identical(dt, mem_dt)
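The tests now identify visibility nodes via the new ``type`` attribute rather than probing for ``version``. The same pattern outside the test suite — a sketch with a placeholder path; ``creation_date`` is popped because it is a wall-clock timestamp that differs between opens:

    import numpy as np
    from xarray.backends.api import open_datatree

    dt = open_datatree("/data/data.ms")  # placeholder path
    dt.load()  # materialise in memory, as in the first test above
    for node in dt.subtree:
      # Visibility partitions are tagged with attrs["type"] == "visibility"
      if node.attrs.get("type") == "visibility":
        node.attrs.pop("creation_date", None)  # non-deterministic timestamp
        assert isinstance(node.VISIBILITY.data, np.ndarray)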
8 changes: 8 additions & 0 deletions tests/test_structure.py
@@ -60,24 +60,32 @@ def test_table_partitioner():
     groups,
     {
       (("DATA_DESC_ID", 0), ("FIELD_ID", 0)): {
+        "DATA_DESC_ID": [0, 0],
+        "FIELD_ID": [0, 0],
         "TIME": [2.0, 4.0],
         "ANTENNA1": [0, 0],
         "ANTENNA2": [1, 1],
         "row": [4, 2],
       },
       (("DATA_DESC_ID", 0), ("FIELD_ID", 1)): {
+        "DATA_DESC_ID": [0],
+        "FIELD_ID": [1],
         "TIME": [3.0],
         "ANTENNA1": [0],
         "ANTENNA2": [1],
         "row": [3],
       },
       (("DATA_DESC_ID", 1), ("FIELD_ID", 0)): {
+        "DATA_DESC_ID": [1],
+        "FIELD_ID": [0],
         "TIME": [1.0],
         "ANTENNA1": [0],
         "ANTENNA2": [1],
         "row": [1],
       },
       (("DATA_DESC_ID", 1), ("FIELD_ID", 1)): {
+        "DATA_DESC_ID": [1],
+        "FIELD_ID": [1],
         "TIME": [0.0],
         "ANTENNA1": [0],
         "ANTENNA2": [1],
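The expected groups now include the partitioning columns themselves. A toy, dependency-free sketch of that grouping behaviour — not the library's partitioner; the within-group ordering by TIME is inferred from the expected output above:

    from collections import defaultdict

    rows = [
      {"DATA_DESC_ID": 1, "FIELD_ID": 1, "TIME": 0.0, "ANTENNA1": 0, "ANTENNA2": 1, "row": 0},
      {"DATA_DESC_ID": 1, "FIELD_ID": 0, "TIME": 1.0, "ANTENNA1": 0, "ANTENNA2": 1, "row": 1},
      {"DATA_DESC_ID": 0, "FIELD_ID": 0, "TIME": 4.0, "ANTENNA1": 0, "ANTENNA2": 1, "row": 2},
      {"DATA_DESC_ID": 0, "FIELD_ID": 1, "TIME": 3.0, "ANTENNA1": 0, "ANTENNA2": 1, "row": 3},
      {"DATA_DESC_ID": 0, "FIELD_ID": 0, "TIME": 2.0, "ANTENNA1": 0, "ANTENNA2": 1, "row": 4},
    ]

    # Group rows by (DATA_DESC_ID, FIELD_ID), keeping the partitioning
    # columns in each group's data, as the updated expectations require.
    groups = defaultdict(lambda: defaultdict(list))
    for r in sorted(rows, key=lambda r: r["TIME"]):
      key = (("DATA_DESC_ID", r["DATA_DESC_ID"]), ("FIELD_ID", r["FIELD_ID"]))
      for column, value in r.items():
        groups[key][column].append(value)

    key = (("DATA_DESC_ID", 0), ("FIELD_ID", 0))
    assert groups[key]["TIME"] == [2.0, 4.0] and groups[key]["row"] == [4, 2]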
82 changes: 43 additions & 39 deletions xarray_ms/backend/msv2/entrypoint.py
@@ -3,6 +3,7 @@
 import os
 import warnings
 from datetime import datetime, timezone
+from importlib.metadata import version as importlib_version
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Tuple
 from uuid import uuid4
 
@@ -73,7 +74,7 @@ def initialise_default_args(
   auto_corrs: bool,
   epoch: str | None,
   table_factory: TableFactory | None,
-  partition_columns: List[str] | None,
+  partition_schema: List[str] | None,
   structure_factory: MSv2StructureFactory | None,
 ) -> Tuple[str, TableFactory, List[str], MSv2StructureFactory]:
   """
@@ -90,11 +91,11 @@
     lockoptions="nolock",
   )
   epoch = epoch or uuid4().hex[:8]
-  partition_columns = partition_columns or DEFAULT_PARTITION_COLUMNS
+  partition_schema = partition_schema or DEFAULT_PARTITION_COLUMNS
   structure_factory = structure_factory or MSv2StructureFactory(
-    table_factory, partition_columns, epoch, auto_corrs=auto_corrs
+    table_factory, partition_schema, epoch, auto_corrs=auto_corrs
   )
-  return epoch, table_factory, partition_columns, structure_factory
+  return epoch, table_factory, partition_schema, structure_factory
 
 
 class MSv2Store(AbstractWritableDataStore):
@@ -103,7 +104,7 @@ class MSv2Store(AbstractWritableDataStore):
   __slots__ = (
     "_table_factory",
     "_structure_factory",
-    "_partition_columns",
+    "_partition_schema",
     "_partition_key",
     "_preferred_chunks",
     "_auto_corrs",
@@ -113,7 +114,7 @@ class MSv2Store(AbstractWritableDataStore):
 
   _table_factory: TableFactory
   _structure_factory: MSv2StructureFactory
-  _partition_columns: List[str]
+  _partition_schema: List[str]
   _preferred_chunks: Dict[str, int]
   _partition: PartitionKeyT
   _autocorrs: bool
@@ -124,7 +125,7 @@ def __init__(
     self,
     table_factory: TableFactory,
     structure_factory: MSv2StructureFactory,
-    partition_columns: List[str],
+    partition_schema: List[str],
     partition_key: PartitionKeyT,
     preferred_chunks: Dict[str, int],
     auto_corrs: bool,
@@ -133,7 +134,7 @@ def __init__(
   ):
     self._table_factory = table_factory
     self._structure_factory = structure_factory
-    self._partition_columns = partition_columns
+    self._partition_schema = partition_schema
     self._partition_key = partition_key
     self._preferred_chunks = preferred_chunks
     self._auto_corrs = auto_corrs
@@ -145,7 +146,7 @@ def open(
     cls,
     ms: str,
     drop_variables=None,
-    partition_columns: List[str] | None = None,
+    partition_schema: List[str] | None = None,
     partition_key: PartitionKeyT | None = None,
     preferred_chunks: Dict[str, int] | None = None,
     auto_corrs: bool = True,
@@ -156,16 +157,14 @@ def open(
     if not isinstance(ms, str):
       raise ValueError("Measurement Sets paths must be strings")
 
-    epoch, table_factory, partition_columns, structure_factory = (
-      initialise_default_args(
-        ms,
-        ninstances,
-        auto_corrs,
-        epoch,
-        None,
-        partition_columns,
-        structure_factory,
-      )
-    )
+    epoch, table_factory, partition_schema, structure_factory = initialise_default_args(
+      ms,
+      ninstances,
+      auto_corrs,
+      epoch,
+      None,
+      partition_schema,
+      structure_factory,
+    )
 
     # Resolve the user supplied partition key against actual
@@ -188,7 +187,7 @@ def open(
     return cls(
       table_factory,
       structure_factory,
-      partition_columns=partition_columns,
+      partition_schema=partition_schema,
       partition_key=partition_key,
       preferred_chunks=preferred_chunks,
      auto_corrs=auto_corrs,
@@ -210,17 +209,22 @@ def get_variables(self):
     return factory.get_variables()
 
   def get_attrs(self):
-    try:
-      ddid = next(iter(v for k, v in self._partition_key if k == "DATA_DESC_ID"))
-    except StopIteration:
-      raise KeyError("DATA_DESC_ID not found in partition")
+    factory = MainDatasetFactory(
+      self._partition_key,
+      self._preferred_chunks,
+      self._table_factory,
+      self._structure_factory,
+    )
 
-    return {
-      "version": "4.0.0",
+    attrs = {
+      "schema_version": "4.0.0",
       "creation_date": datetime.now(timezone.utc).isoformat(),
-      "data_description_id": ddid,
+      "type": "visibility",
+      "xarray_ms_version": importlib_version("xarray-ms"),
     }
 
+    return {**attrs, **factory.get_attrs()}
+
   def get_dimensions(self):
     return None
 
@@ -231,7 +235,7 @@ def get_encoding(self):
 
 class MSv2EntryPoint(BackendEntrypoint):
   open_dataset_parameters = [
     "filename_or_obj",
-    "partition_columns",
+    "partition_schema",
     "partition_key",
     "preferred_chunks",
     "auto_corrs",
@@ -266,7 +270,7 @@ def open_dataset(
     filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
     *,
     drop_variables: str | Iterable[str] | None = None,
-    partition_columns: List[str] | None = None,
+    partition_schema: List[str] | None = None,
     partition_key: PartitionKeyT | None = None,
     preferred_chunks: Dict[str, int] | None = None,
     auto_corrs: bool = True,
@@ -280,7 +284,7 @@ def open_dataset(
     Args:
       filename_or_obj: The path to the MSv2 CASA Measurement Set file.
       drop_variables: Variables to drop from the dataset.
-      partition_columns: The columns to use for partitioning the Measurement set.
+      partition_schema: The columns to use for partitioning the Measurement Set.
         Defaults to :code:`{DEFAULT_PARTITION_COLUMNS}`.
      partition_key: A key corresponding to an individual partition.
         For example :code:`(('DATA_DESC_ID', 0), ('FIELD_ID', 0))`.
@@ -295,14 +299,14 @@ def open_dataset(
     Returns:
       A :class:`~xarray.Dataset` referring to the unique
-      partition specified by :code:`partition_columns` and :code:`partition_key`.
+      partition specified by :code:`partition_schema` and :code:`partition_key`.
     """
     filename_or_obj = _normalize_path(filename_or_obj)
 
     store = MSv2Store.open(
       filename_or_obj,
       drop_variables=drop_variables,
-      partition_columns=partition_columns,
+      partition_schema=partition_schema,
       partition_key=partition_key,
       preferred_chunks=preferred_chunks,
       auto_corrs=auto_corrs,
@@ -320,7 +324,7 @@ def open_datatree(
     *,
     preferred_chunks: Dict[str, Any] | None = None,
     drop_variables: str | Iterable[str] | None = None,
-    partition_columns: List[str] | None = None,
+    partition_schema: List[str] | None = None,
     auto_corrs: bool = True,
     ninstances: int = 8,
     epoch: str | None = None,
@@ -359,7 +363,7 @@ def open_datatree(
         for more information.
       drop_variables: Variables to drop from the dataset.
-      partition_columns: The columns to use for partitioning the Measurement set.
+      partition_schema: The columns to use for partitioning the Measurement Set.
         Defaults to :code:`{DEFAULT_PARTITION_COLUMNS}`.
       auto_corrs: Include/Exclude auto-correlations.
       ninstances: The number of Measurement Set instances to open for parallel I/O.
@@ -374,7 +378,7 @@ def open_datatree(
     groups_dict = self.open_groups_as_dict(
       filename_or_obj,
       drop_variables=drop_variables,
-      partition_columns=partition_columns,
+      partition_schema=partition_schema,
       preferred_chunks=preferred_chunks,
       auto_corrs=auto_corrs,
       ninstances=ninstances,
@@ -390,7 +394,7 @@ def open_groups_as_dict(
     filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
     *,
     drop_variables: str | Iterable[str] | None = None,
-    partition_columns: List[str] | None = None,
+    partition_schema: List[str] | None = None,
     preferred_chunks: Dict[str, int] | None = None,
     auto_corrs: bool = True,
     ninstances: int = 8,
@@ -408,8 +412,8 @@ def open_groups_as_dict(
     else:
       raise ValueError("Measurement Set paths must be strings")
 
-    epoch, _, partition_columns, structure_factory = initialise_default_args(
-      ms, ninstances, auto_corrs, epoch, None, partition_columns, None
+    epoch, _, partition_schema, structure_factory = initialise_default_args(
+      ms, ninstances, auto_corrs, epoch, None, partition_schema, None
     )
 
     # /path/to/some_name.ext -> some_name
@@ -424,7 +428,7 @@ def open_groups_as_dict(
         ms,
         drop_variables=drop_variables,
         engine="xarray-ms:msv2",
-        partition_columns=partition_columns,
+        partition_schema=partition_schema,
         partition_key=partition_key,
         preferred_chunks=pchunks[partition_key]
         if isinstance(pchunks, Mapping)
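The reworked ``get_attrs`` merges MSv4-style store metadata with the dataset factory's attributes. A consumer-side sketch — the path is a placeholder, the partition values follow the tests above, and the printed version string is whatever ``importlib.metadata`` reports for the installed package:

    import xarray

    ds = xarray.open_dataset(
      "/data/data.ms",  # placeholder path
      engine="xarray-ms:msv2",
      partition_schema=["DATA_DESC_ID", "FIELD_ID"],
      partition_key=(("DATA_DESC_ID", 0), ("FIELD_ID", 0)),
    )

    assert ds.attrs["schema_version"] == "4.0.0"
    assert ds.attrs["type"] == "visibility"
    print(ds.attrs["xarray_ms_version"])  # e.g. "0.2.1" (illustrative)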
(3 additional changed files not shown)