diff --git a/tests/test_read.py b/tests/test_read.py index dac0cae..60e0462 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -17,7 +17,7 @@ ], indirect=True, ) -def test_basic_read(simmed_ms): +def test_regular_read(simmed_ms): """Test for ramp function values produced by simulator""" xdt = open_datatree(simmed_ms) @@ -33,3 +33,73 @@ def test_basic_read(simmed_ms): nelements = reduce(mul, uvw.shape, 1) expected = np.arange(nelements, dtype=np.float64).reshape(uvw.shape) assert_array_equal(uvw, expected) + + +ANT1_SUBSET = [0, 0, 1] +ANT2_SUBSET = [0, 1, 2] + + +def _select_rows(antenna1, antenna2, ant1_subset, ant2_subset): + dtype = [("a1", antenna1.dtype), ("a2", antenna2.dtype)] + baselines = np.rec.fromarrays([antenna1, antenna2], dtype=dtype) + desired = np.rec.fromarrays([ant1_subset, ant2_subset], dtype=dtype) + return np.isin(baselines, desired) + + +def _excise_rows(data_dict): + _, ant1 = data_dict["ANTENNA1"] + _, ant2 = data_dict["ANTENNA2"] + index = _select_rows(ant1, ant2, ANT1_SUBSET, ANT2_SUBSET) + return {k: (d, v[index]) for k, (d, v) in data_dict.items()} + + +@pytest.mark.parametrize( + "simmed_ms", + [ + { + "name": "backend.ms", + "nantenna": 3, + "data_description": [(8, ["XX", "XY", "YX", "YY"]), (4, ["RR", "LL"])], + "transform_data": _excise_rows, + } + ], + indirect=True, +) +def test_irregular_read(simmed_ms): + xdt = open_datatree(simmed_ms) + + for node in xdt.subtree: + if "data_description_id" in node.attrs: + bl_index = _select_rows( + node.baseline_antenna1_name.values, + node.baseline_antenna2_name.values, + [f"ANTENNA-{i}" for i in ANT1_SUBSET], + [f"ANTENNA-{i}" for i in ANT2_SUBSET], + ) + + vis = node.VISIBILITY.values + # Selected baseline elements are as expected + nelements = reduce(mul, vis.shape, 1) + expected = np.arange(nelements, dtype=np.float32) + expected = (expected + expected * 1j).reshape(vis.shape) + assert_array_equal(vis[:, bl_index], expected[:, bl_index]) + # Other baseline elements are nan + vis = node.VISIBILITY.values + assert np.all(np.isnan((vis[:, ~bl_index]))) + + uvw = node.UVW.values + # Selected baseline elements are as expected + nelements = reduce(mul, uvw.shape, 1) + expected = np.arange(nelements, dtype=np.float64).reshape(uvw.shape) + assert_array_equal(uvw[:, bl_index], expected[:, bl_index]) + # Other baseline elements are nan + assert np.all(np.isnan((uvw[:, ~bl_index, ...]))) + + flag = node.FLAG.values + # Selected baseline elements are as expected + nelements = reduce(mul, flag.shape, 1) + expected = np.where(np.arange(nelements) & 0x1, 0, 1) + expected = expected.reshape(flag.shape) + assert_array_equal(flag[:, bl_index], expected[:, bl_index]) + # Other baseline elements are flagged + assert np.all(flag[:, ~bl_index, ...] == 1) diff --git a/xarray_ms/testing/simulator.py b/xarray_ms/testing/simulator.py index f5f0ca4..8d4f821 100644 --- a/xarray_ms/testing/simulator.py +++ b/xarray_ms/testing/simulator.py @@ -2,6 +2,7 @@ import os import tempfile import typing +from collections.abc import Callable from typing import ( Any, Dict, @@ -88,6 +89,7 @@ class PartitionDescriptor: DDIDArgType = List[Tuple[npt.NDArray[np.float64], List[str]]] +PartitionDataType = Dict[str, Tuple[Tuple[str, ...], npt.NDArray]] class MSStructureSimulator: @@ -124,6 +126,7 @@ def __init__( partition: Tuple[str, ...] = ("PROCESSOR_ID", "FIELD_ID", "DATA_DESC_ID"), auto_corrs: bool = True, simulate_data: bool = True, + transform_data: Callable[[PartitionDataType], PartitionDataType] | None = None, ): assert ntime >= 1 assert time_chunks > 0 @@ -178,6 +181,7 @@ def __init__( self.simulate_data = simulate_data self.partition_names = cbp_names self.partition_indices = bcbp_indices + self.transform_data = transform_data self.model = { "data_description": self.data_description, "feed_map": self.feeds, @@ -199,6 +203,8 @@ def simulate_ms(self, output_ms: str) -> None: for chunk_desc in self.generate_descriptors(): data_dict = self.data_factory(chunk_desc) + if self.transform_data is not None: + data_dict = self.transform_data(data_dict) (nrow,) = data_dict["TIME"][1].shape T.addrows(nrow) @@ -311,9 +317,7 @@ def broadcast_partition_indices( return np.stack([a.ravel() for a in np.broadcast_arrays(*bparts)], axis=1) @staticmethod - def data_factory( - desc: PartitionDescriptor, - ) -> Dict[str, Tuple[Tuple[str, ...], npt.NDArray]]: + def data_factory(desc: PartitionDescriptor) -> PartitionDataType: """Creates simulated MS data from a partition descriptor""" try: ddid = desc.DATA_DESC_ID.item()