diff --git a/xarray_ms/backend/msv2/structure.py b/xarray_ms/backend/msv2/structure.py
index a7a05da..0194623 100644
--- a/xarray_ms/backend/msv2/structure.py
+++ b/xarray_ms/backend/msv2/structure.py
@@ -105,6 +105,10 @@ def is_partition_key(key: PartitionKeyT) -> bool:
   )


+def partition_args(data: npt.NDArray, chunk: int) -> List[npt.NDArray]:
+  return [data[i : i + chunk] for i in range(0, len(data), chunk)]
+
+
 DEFAULT_PARTITION_COLUMNS: List[str] = [
   "DATA_DESC_ID",
   "FIELD_ID",
@@ -159,10 +163,14 @@ class PartitionData:
   chan_freq: npt.NDArray[np.float64]
   chan_width: npt.NDArray[np.float64]
   corr_type: npt.NDArray[np.int32]
+  field_names: List[str] | None
   spw_name: str
   spw_freq_group_name: str
   spw_ref_freq: float
   spw_frame: str
+  scan_numbers: List[int] | None
+  sub_scan_numbers: List[int] | None
+
   row_map: npt.NDArray[np.int64]
@@ -200,44 +208,47 @@ def partition(
     nworkers = pool._max_workers
     chunk = (nrow + nworkers - 1) // nworkers

+    all_columns = index.column_names
     ordered_columns = self._partitionby + self._sortby + self._other

-    # breakpoint()
-    # Create a dictionary out of the pyarrow table
     table_dict = {k: index[k].to_numpy() for k in ordered_columns}

-    # Partition the range over the workers in the pool
+    # Partition the data over the workers in the pool
     partitions = [
       {k: v[s : s + chunk] for k, v in table_dict.items()}
       for s in range(0, nrow, chunk)
     ]

+    # Sort each partition in parallel
     def sort_partition(p):
-      indices = np.lexsort(tuple(reversed(p.values())))
+      sort_arrays = tuple(p[k] for k in reversed(ordered_columns))
+      indices = np.lexsort(sort_arrays)
       return {k: v[indices] for k, v in p.items()}

-    # Sort each partition in parallel
     partitions = list(pool.map(sort_partition, partitions))

     # Merge partitions
     merged = merge_np_partitions(partitions)

-    # Find the group start and end points in parallel
-    def find_edges(p, s):
-      diffs = [np.diff(p[v]) > 0 for v in self._partitionby]
-      return np.where(np.logical_or.reduce(diffs))[0] + s + 1
-
+    # Find the edges of the group partitions in parallel by
+    # partitioning the sorted merged values into chunks, including
+    # the starting value of the next chunk.
     starts = list(range(0, nrow, chunk))
-
-    group_diffs = [
+    group_values = [
       {k: v[s : s + chunk + 1] for k, v in merged.items() if k in self._partitionby}
       for s in starts
     ]
-    assert len(starts) == len(group_diffs)
-    edges = list(pool.map(find_edges, group_diffs, starts))
+    assert len(starts) == len(group_values)
+
+    # Find the group start and end points in parallel by finding edges
+    def find_edges(p, s):
+      diffs = [np.diff(p[v]) > 0 for v in self._partitionby]
+      return np.where(np.logical_or.reduce(diffs))[0] + s + 1
+
+    edges = list(pool.map(find_edges, group_values, starts))
     group_offsets = np.concatenate([[0]] + edges + [[nrow]])

     # Create the grouped partitions
-    groups = {}
+    groups: Dict[PartitionKeyT, Dict[str, npt.NDArray]] = {}

     for start, end in zip(group_offsets[:-1], group_offsets[1:]):
       key = tuple(sorted((k, merged[k][start].item()) for k in self._partitionby))
@@ -423,10 +434,12 @@ def maybe_get_source_id(
     source_id = np.empty_like(field_id)
     chunk = (len(source_id) + ncpus - 1) // ncpus

-    def par_copy(i: int):
-      source_id[i : i + chunk] = field_source_id[field_id[i : i + chunk]]
+    def par_copy(source, field):
+      source[:] = field_source_id[field]

-    pool.map(par_copy, range(0, len(source_id), chunk))
+    pool.map(
+      par_copy, partition_args(source_id, chunk), partition_args(field_id, chunk)
+    )

     return source_id
@@ -480,7 +493,7 @@ def partition_columns_from_schema(
   def par_unique(pool, ncpus, data, return_inverse=False):
     """Parallel unique function using the associated threadpool"""
     chunk_size = (len(data) + ncpus - 1) // ncpus
-    data_chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
+    data_chunks = partition_args(data, chunk_size)
     if return_inverse:
       unique_fn = partial(np.unique, return_inverse=True)
       udatas, indices = zip(*pool.map(unique_fn, data_chunks))
@@ -489,8 +502,14 @@ def inv_fn(data, idx):
         return np.searchsorted(udata, data)[idx]

+      def par_assign(target, data):
+        target[:] = data
+
       data_ids = pool.map(inv_fn, udatas, indices)
-      return udata, np.concatenate(list(data_ids))
+      inverse = np.empty(len(data), dtype=indices[0].dtype)
+      pool.map(par_assign, partition_args(inverse, chunk_size), data_ids)
+
+      return udata, inverse
     else:
       udata = list(pool.map(np.unique, data_chunks))
       return np.unique(np.concatenate(udata))
@@ -533,6 +552,107 @@ def read_subtables(
     return subtables, coldescs

+  def partition_data_factory(
+    self,
+    name: str,
+    auto_corrs: bool,
+    key: PartitionKeyT,
+    value: Dict[str, npt.NDArray],
+    pool: cf.Executor,
+    ncpus: int,
+  ) -> PartitionData:
+    """Generate a `PartitionData` object"""
+    time = value["TIME"]
+    interval = value["INTERVAL"]
+    ant1 = value["ANTENNA1"]
+    ant2 = value["ANTENNA2"]
+    rows = value["row"]
+
+    # Compute the unique times and their inverse index
+    utime, time_ids = self.par_unique(pool, ncpus, time, return_inverse=True)
+
+    # Compute unique intervals
+    uinterval = self.par_unique(pool, ncpus, interval)
+
+    try:
+      ddid = next(i for (c, i) in key if c == "DATA_DESC_ID")
+    except StopIteration:
+      raise KeyError(f"DATA_DESC_ID must be present in partition key {key}")
+
+    if ddid >= len(self._ddid):
+      raise InvalidMeasurementSet(
+        f"DATA_DESC_ID {ddid} does not exist in {name}::DATA_DESCRIPTION"
+      )
+
+    spw_id = self._ddid["SPECTRAL_WINDOW_ID"][ddid].as_py()
+    pol_id = self._ddid["POLARIZATION_ID"][ddid].as_py()
+
+    if spw_id >= len(self._spw):
+      raise InvalidMeasurementSet(
+        f"SPECTRAL_WINDOW_ID {spw_id} does not exist in {name}::SPECTRAL_WINDOW"
+      )
+
+    if pol_id >= len(self._pol):
+      raise InvalidMeasurementSet(
+        f"POLARIZATION_ID {pol_id} does not exist in {name}::POLARIZATION"
+      )
+
+    scan_number = value.get("SCAN_NUMBER")
+    field_id = value.get("FIELD_ID")
+    source_id = value.get("SOURCE_ID")
+    state_id = value.get("STATE_ID")
+
+    field_names: List[str] | None = None
+
+    if field_id is not None and len(self._field) > 0:
+      ufield_ids = self.par_unique(pool, ncpus, field_id)
+      fields = self._field.take(ufield_ids)
+      field_names = fields["NAME"].to_numpy()
+
+    scan_numbers: List[int] | None = None
+
+    if scan_number is not None:
+      scan_numbers = list(self.par_unique(pool, ncpus, scan_number))
+
+    corr_type = Polarisations.from_values(self._pol["CORR_TYPE"][pol_id].as_py())
+    chan_freq = self._spw["CHAN_FREQ"][spw_id].as_py()
+    uchan_width = np.unique(self._spw["CHAN_WIDTH"][spw_id].as_py())
+
+    spw_name = self._spw["NAME"][spw_id].as_py()
+    spw_freq_group_name = self._spw["FREQ_GROUP_NAME"][spw_id].as_py()
+    spw_ref_freq = self._spw["REF_FREQUENCY"][spw_id].as_py()
+    spw_meas_freq_ref = self._spw["MEAS_FREQ_REF"][spw_id].as_py()
+    spw_frame = FrequencyMeasures(spw_meas_freq_ref).name.lower()
+
+    row_map = np.full(utime.size * self.nbl, -1, dtype=np.int64)
+    chunk_size = (len(rows) + ncpus - 1) // ncpus
+
+    def gen_row_map(time_ids, ant1, ant2, rows):
+      bl_ids = baseline_id(ant1, ant2, self.na, auto_corrs=auto_corrs)
+      row_map[time_ids * self.nbl + bl_ids] = rows
+
+    assert len(ant1) == len(rows)
+
+    # Generate the row map in parallel
+    s = partial(partition_args, chunk=chunk_size)
+    pool.map(gen_row_map, s(time_ids), s(ant1), s(ant2), s(rows))
+
+    return PartitionData(
+      time=utime,
+      interval=uinterval,
+      chan_freq=chan_freq,
+      chan_width=uchan_width,
+      corr_type=corr_type.to_str(),
+      field_names=field_names,
+      spw_name=spw_name,
+      spw_freq_group_name=spw_freq_group_name,
+      spw_ref_freq=spw_ref_freq,
+      spw_frame=spw_frame,
+      row_map=row_map.reshape(utime.size, self.nbl),
+      scan_numbers=scan_numbers,
+      sub_scan_numbers=None,
+    )
+
   def __init__(
     self, ms: TableFactory, partition_schema: List[str], auto_corrs: bool = True
   ):
@@ -582,78 +702,8 @@ def __init__(
     self._partitions = {}

     for k, v in partitions.items():
-      time = v["TIME"]
-      interval = v["INTERVAL"]
-      ant1 = v["ANTENNA1"]
-      ant2 = v["ANTENNA2"]
-      rows = v["row"]
-
-      # Compute the unique times and their inverse index
-      utime, time_ids = self.par_unique(pool, ncpus, time, return_inverse=True)
-
-      # Compute unique intervals
-      uinterval = self.par_unique(pool, ncpus, interval)
-
-      try:
-        ddid = next(i for (c, i) in k if c == "DATA_DESC_ID")
-      except StopIteration:
-        raise KeyError(f"DATA_DESC_ID must be present in partition key {k}")
-
-      if ddid >= len(self._ddid):
-        raise InvalidMeasurementSet(
-          f"DATA_DESC_ID {ddid} does not exist in {name}::DATA_DESCRIPTION"
-        )
-
-      spw_id = self._ddid["SPECTRAL_WINDOW_ID"][ddid].as_py()
-      pol_id = self._ddid["POLARIZATION_ID"][ddid].as_py()
-
-      if spw_id >= len(self._spw):
-        raise InvalidMeasurementSet(
-          f"SPECTRAL_WINDOW_ID {spw_id} does not exist in {name}::SPECTRAL_WINDOW"
-        )
-
-      if pol_id >= len(self._pol):
-        raise InvalidMeasurementSet(
-          f"POLARIZATION_ID {pol_id} does not exist in {name}::POLARIZATION"
-        )
-
-      corr_type = Polarisations.from_values(self._pol["CORR_TYPE"][pol_id].as_py())
-      chan_freq = self._spw["CHAN_FREQ"][spw_id].as_py()
-      uchan_width = np.unique(self._spw["CHAN_WIDTH"][spw_id].as_py())
-
-      spw_name = self._spw["NAME"][spw_id].as_py()
-      spw_freq_group_name = self._spw["FREQ_GROUP_NAME"][spw_id].as_py()
-      spw_ref_freq = self._spw["REF_FREQUENCY"][spw_id].as_py()
self._spw["REF_FREQUENCY"][spw_id].as_py() - spw_meas_freq_ref = self._spw["MEAS_FREQ_REF"][spw_id].as_py() - spw_frame = FrequencyMeasures(spw_meas_freq_ref).name.lower() - - row_map = np.full(utime.size * self.nbl, -1, dtype=np.int64) - chunk_size = (len(rows) + ncpus - 1) // ncpus - - def gen_row_map(time_ids, ant1, ant2, rows): - bl_ids = baseline_id(ant1, ant2, self.na, auto_corrs=auto_corrs) - row_map[time_ids * self.nbl + bl_ids] = rows - - def split_arg(data, chunk): - return [data[i : i + chunk] for i in range(0, len(data), chunk)] - - assert len(ant1) == len(rows) - - # Generate the row map in parallel - s = partial(split_arg, chunk=chunk_size) - pool.map(gen_row_map, s(time_ids), s(ant1), s(ant2), s(rows)) - - self._partitions[k] = PartitionData( - time=utime, - interval=uinterval, - chan_freq=chan_freq, - chan_width=uchan_width, - corr_type=corr_type.to_str(), - spw_name=spw_name, - spw_freq_group_name=spw_freq_group_name, - spw_ref_freq=spw_ref_freq, - spw_frame=spw_frame, - row_map=row_map.reshape(utime.size, self.nbl), + self._partitions[k] = self.partition_data_factory( + name, auto_corrs, k, v, pool, ncpus ) logger.info("Reading %s structure in took %fs", name, modtime.time() - start)