From 8e8aa0dd9ff72e10e36ba6578db29c914e3108b9 Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Thu, 30 Jan 2025 11:36:23 +0200
Subject: [PATCH] WIP

---
 xarray_ms/backend/msv2/structure.py | 28 ++++++++++++++++++----------
 xarray_ms/testing/simulator.py      | 12 ++++++++++++
 2 files changed, 30 insertions(+), 10 deletions(-)

diff --git a/xarray_ms/backend/msv2/structure.py b/xarray_ms/backend/msv2/structure.py
index 0194623..b662f35 100644
--- a/xarray_ms/backend/msv2/structure.py
+++ b/xarray_ms/backend/msv2/structure.py
@@ -163,13 +163,13 @@ class PartitionData:
   chan_freq: npt.NDArray[np.float64]
   chan_width: npt.NDArray[np.float64]
   corr_type: npt.NDArray[np.int32]
-  field_names: List[str] | None
+  field_names: List[str]
   spw_name: str
   spw_freq_group_name: str
   spw_ref_freq: float
   spw_frame: str
-  scan_numbers: List[int] | None
-  sub_scan_numbers: List[int] | None
+  scan_numbers: List[int]
+  sub_scan_numbers: List[int]
   row_map: npt.NDArray[np.int64]


@@ -208,8 +208,16 @@ def partition(
     nworkers = pool._max_workers
     chunk = (nrow + nworkers - 1) // nworkers

-    all_columns = index.column_names
+    # Order columns by
+    #
+    # 1. Partitioning columns
+    # 2. Sorting columns
+    # 3. Others (such as row and INTERVAL)
+    # 4. Remaining columns
+    #
+    # 4 is needed for merge_np_partitions to work
     ordered_columns = self._partitionby + self._sortby + self._other
+    ordered_columns += list(set(index.column_names) - set(ordered_columns))

     # Create a dictionary out of the pyarrow table
     table_dict = {k: index[k].to_numpy() for k in ordered_columns}
@@ -252,8 +260,7 @@ def find_edges(p, s):

     for start, end in zip(group_offsets[:-1], group_offsets[1:]):
       key = tuple(sorted((k, merged[k][start].item()) for k in self._partitionby))
-      data = {k: merged[k][start:end] for k in self._sortby + self._other}
-      groups[key] = data
+      groups[key] = {k: v[start:end] for k, v in merged.items()}

     return groups

@@ -602,17 +609,17 @@ def partition_data_factory(
     source_id = value.get("SOURCE_ID")
     state_id = value.get("STATE_ID")

-    field_names: List[str] | None = None
+    field_names: List[str] = []

     if field_id is not None and len(self._field) > 0:
       ufield_ids = self.par_unique(pool, ncpus, field_id)
       fields = self._field.take(ufield_ids)
-      field_names = fields["NAME"].to_numpy()
+      field_names = fields["NAME"].unique().to_numpy(zero_copy_only=False).tolist()

-    scan_numbers: List[int] | None = None
+    scan_numbers: List[int] = []

     if scan_number is not None:
-      scan_numbers = list(self.par_unique(pool, ncpus, scan_number))
+      scan_numbers = self.par_unique(pool, ncpus, scan_number).tolist()

     corr_type = Polarisations.from_values(self._pol["CORR_TYPE"][pol_id].as_py())
     chan_freq = self._spw["CHAN_FREQ"][spw_id].as_py()
@@ -705,5 +712,6 @@ def __init__(
       self._partitions[k] = self.partition_data_factory(
         name, auto_corrs, k, v, pool, ncpus
      )
+      print(dataclasses.asdict(self._partitions[k]))

     logger.info("Reading %s structure in took %fs", name, modtime.time() - start)
diff --git a/xarray_ms/testing/simulator.py b/xarray_ms/testing/simulator.py
index b133777..d07ee00 100644
--- a/xarray_ms/testing/simulator.py
+++ b/xarray_ms/testing/simulator.py
@@ -103,6 +103,7 @@ class MSStructureSimulator:

   ntime: int
   nantenna: int
+  nfield: int
   auto_corrs: bool
   dump_rate: float
   time_chunks: int
@@ -175,6 +176,7 @@ def __init__(

     self.ntime = ntime
     self.nantenna = nantenna
+    self.nfield = nfield
     self.auto_corrs = auto_corrs
     self.dump_rate = dump_rate
     self.time_chunks = time_chunks
@@ -197,6 +199,7 @@ def nfeed(self) -> int:
   def simulate_ms(self, output_ms: str) -> None:
"""Simulate data into the given measurement set name""" table_desc = ADDITIONAL_COLUMNS if self.simulate_data else {} + # Generate descriptors, create simulated data from the descriptors # and write simulated data to the main Measurement Set with Table.ms_from_descriptor(output_ms, "MAIN", table_desc) as T: @@ -279,6 +282,15 @@ def simulate_ms(self, output_ms: str) -> None: T.putcol("MOUNT", np.asarray(["ALT-AZ" for _ in range(self.nantenna)])) T.putcol("STATION", np.asarray([f"STATION-{i}" for i in range(self.nantenna)])) + with Table.from_filename(f"{output_ms}::FIELD", **kw) as T: + T.addrows(self.nfield) + T.putcol("NAME", np.asarray([f"FIELD-{i}" for i in range(self.nfield)])) + T.putcol("SOURCE_ID", np.arange(self.nfield)) + + with Table.ms_from_descriptor(output_ms, "SOURCE") as T: + T.addrows(self.nfield) + T.putcol("NAME", np.asarray([f"SOURCE-{i}" for i in range(self.nfield)])) + def generate_descriptors(self) -> Generator[PartitionDescriptor, None, None]: """Generates a sequence of descriptors, each corresponding to a partition""" chunk_id = 0