Commit 23d3ceb

refactor
sjperkins committed Jan 27, 2025
1 parent f64d46e commit 23d3ceb
Showing 1 changed file with 142 additions and 92 deletions.
234 changes: 142 additions & 92 deletions xarray_ms/backend/msv2/structure.py
@@ -105,6 +105,10 @@ def is_partition_key(key: PartitionKeyT) -> bool:
)


def partition_args(data: npt.NDArray, chunk: int) -> List[npt.NDArray]:
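"""Split data into a list of consecutive chunks of at most chunk elements."""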
return [data[i : i + chunk] for i in range(0, len(data), chunk)]
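# e.g. partition_args(np.arange(5), 2) -> [array([0, 1]), array([2, 3]), array([4])]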


DEFAULT_PARTITION_COLUMNS: List[str] = [
"DATA_DESC_ID",
"FIELD_ID",
@@ -159,10 +163,14 @@ class PartitionData:
chan_freq: npt.NDArray[np.float64]
chan_width: npt.NDArray[np.float64]
corr_type: npt.NDArray[np.int32]
field_names: List[str] | None
spw_name: str
spw_freq_group_name: str
spw_ref_freq: float
spw_frame: str
scan_numbers: List[int] | None
sub_scan_numbers: List[int] | None

row_map: npt.NDArray[np.int64]


@@ -200,44 +208,47 @@ def partition(
nworkers = pool._max_workers
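# ceiling division: spread the nrow rows evenly over the worker pool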
chunk = (nrow + nworkers - 1) // nworkers

all_columns = index.column_names
ordered_columns = self._partitionby + self._sortby + self._other

# Create a dictionary out of the pyarrow table
table_dict = {k: index[k].to_numpy() for k in ordered_columns}
# Partition the range over the workers in the pool
# Partition the data over the workers in the pool
partitions = [
{k: v[s : s + chunk] for k, v in table_dict.items()}
for s in range(0, nrow, chunk)
]

# Sort each partition in parallel
def sort_partition(p):
indices = np.lexsort(tuple(reversed(p.values())))
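# np.lexsort sorts on the last key first, so the columns are reversed to make the first ordered column the primary sort key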
sort_arrays = tuple(p[k] for k in reversed(ordered_columns))
indices = np.lexsort(sort_arrays)
return {k: v[indices] for k, v in p.items()}

# Sort each partition in parallel
partitions = list(pool.map(sort_partition, partitions))
# Merge partitions
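# (the individually sorted chunks are combined into one globally sorted table)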
merged = merge_np_partitions(partitions)

# Find the group start and end points in parallel
def find_edges(p, s):
diffs = [np.diff(p[v]) > 0 for v in self._partitionby]
return np.where(np.logical_or.reduce(diffs))[0] + s + 1

# Find the edges of the group partitions in parallel by
# partitioning the sorted merged values into chunks, including
# the starting value of the next chunk.
starts = list(range(0, nrow, chunk))

group_diffs = [
group_values = [
{k: v[s : s + chunk + 1] for k, v in merged.items() if k in self._partitionby}
for s in starts
]
assert len(starts) == len(group_diffs)
edges = list(pool.map(find_edges, group_diffs, starts))
assert len(starts) == len(group_values)

# Find the group start and end points in parallel by locating edges in the partitioning columns
def find_edges(p, s):
diffs = [np.diff(p[v]) > 0 for v in self._partitionby]
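# rows where any partitioning column changes value begin a new group; s + 1 shifts chunk-local offsets to global row indices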
return np.where(np.logical_or.reduce(diffs))[0] + s + 1

edges = list(pool.map(find_edges, group_values, starts))
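# bracket the edges with 0 and nrow so that consecutive offsets delimit each group's rows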
group_offsets = np.concatenate([[0]] + edges + [[nrow]])

# Create the grouped partitions
groups = {}
groups: Dict[PartitionKeyT, Dict[str, npt.NDArray]] = {}

for start, end in zip(group_offsets[:-1], group_offsets[1:]):
key = tuple(sorted((k, merged[k][start].item()) for k in self._partitionby))
@@ -423,10 +434,12 @@ def maybe_get_source_id(
source_id = np.empty_like(field_id)
chunk = (len(source_id) + ncpus - 1) // ncpus

def par_copy(i: int):
source_id[i : i + chunk] = field_source_id[field_id[i : i + chunk]]
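# gather FIELD's SOURCE_ID values for each chunk of field ids, writing into the matching source_id chunk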
def par_copy(source, field):
source[:] = field_source_id[field]

pool.map(par_copy, range(0, len(source_id), chunk))
pool.map(
par_copy, partition_args(source_id, chunk), partition_args(field_id, chunk)
)

return source_id

@@ -480,7 +493,7 @@ def partition_columns_from_schema(
def par_unique(pool, ncpus, data, return_inverse=False):
"""Parallel unique function using the associated threadpool"""
chunk_size = (len(data) + ncpus - 1) // ncpus
data_chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
data_chunks = partition_args(data, chunk_size)
if return_inverse:
unique_fn = partial(np.unique, return_inverse=True)
udatas, indices = zip(*pool.map(unique_fn, data_chunks))
@@ -489,8 +502,14 @@ def par_unique(pool, ncpus, data, return_inverse=False):
def inv_fn(data, idx):
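# locate the chunk's unique values in the global udata, then expand via the chunk-local inverse indices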
return np.searchsorted(udata, data)[idx]

def par_assign(target, data):
target[:] = data

data_ids = pool.map(inv_fn, udatas, indices)
return udata, np.concatenate(list(data_ids))
inverse = np.empty(len(data), dtype=indices[0].dtype)
pool.map(par_assign, partition_args(inverse, chunk_size), data_ids)

return udata, inverse
else:
udata = list(pool.map(np.unique, data_chunks))
return np.unique(np.concatenate(udata))
@@ -533,6 +552,107 @@ def read_subtables(

return subtables, coldescs

def partition_data_factory(
self,
name: str,
auto_corrs: bool,
key: PartitionKeyT,
value: Dict[str, npt.NDArray],
pool: cf.Executor,
ncpus: int,
) -> PartitionData:
"""Generate a `PartitionData` object"""
time = value["TIME"]
interval = value["INTERVAL"]
ant1 = value["ANTENNA1"]
ant2 = value["ANTENNA2"]
rows = value["row"]

# Compute the unique times and their inverse index
utime, time_ids = self.par_unique(pool, ncpus, time, return_inverse=True)

# Compute unique intervals
uinterval = self.par_unique(pool, ncpus, interval)

try:
ddid = next(i for (c, i) in key if c == "DATA_DESC_ID")
except StopIteration:
raise KeyError(f"DATA_DESC_ID must be present in partition key {key}")

if ddid >= len(self._ddid):
raise InvalidMeasurementSet(
f"DATA_DESC_ID {ddid} does not exist in {name}::DATA_DESCRIPTION"
)

spw_id = self._ddid["SPECTRAL_WINDOW_ID"][ddid].as_py()
pol_id = self._ddid["POLARIZATION_ID"][ddid].as_py()

if spw_id >= len(self._spw):
raise InvalidMeasurementSet(
f"SPECTRAL_WINDOW_ID {spw_id} does not exist in {name}::SPECTRAL_WINDOW"
)

if pol_id >= len(self._pol):
raise InvalidMeasurementSet(
f"POLARIZATION_ID {pol_id} does not exist in {name}::POLARIZATION"
)

scan_number = value.get("SCAN_NUMBER")
field_id = value.get("FIELD_ID")
source_id = value.get("SOURCE_ID")
state_id = value.get("STATE_ID")

field_names: List[str] | None = None

if field_id is not None and len(self._field) > 0:
ufield_ids = self.par_unique(pool, ncpus, field_id)
fields = self._field.take(ufield_ids)
field_names = fields["NAME"].to_numpy()

scan_numbers: List[int] | None = None

if scan_number is not None:
scan_numbers = list(self.par_unique(pool, ncpus, scan_number))

corr_type = Polarisations.from_values(self._pol["CORR_TYPE"][pol_id].as_py())
chan_freq = self._spw["CHAN_FREQ"][spw_id].as_py()
uchan_width = np.unique(self._spw["CHAN_WIDTH"][spw_id].as_py())

spw_name = self._spw["NAME"][spw_id].as_py()
spw_freq_group_name = self._spw["FREQ_GROUP_NAME"][spw_id].as_py()
spw_ref_freq = self._spw["REF_FREQUENCY"][spw_id].as_py()
spw_meas_freq_ref = self._spw["MEAS_FREQ_REF"][spw_id].as_py()
spw_frame = FrequencyMeasures(spw_meas_freq_ref).name.lower()

row_map = np.full(utime.size * self.nbl, -1, dtype=np.int64)
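# one slot per (time, baseline) pair; -1 marks pairs without a corresponding MS row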
chunk_size = (len(rows) + ncpus - 1) // ncpus

def gen_row_map(time_ids, ant1, ant2, rows):
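# scatter each chunk's MS row numbers into the dense (time, baseline) grid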
bl_ids = baseline_id(ant1, ant2, self.na, auto_corrs=auto_corrs)
row_map[time_ids * self.nbl + bl_ids] = rows

assert len(ant1) == len(rows)

# Generate the row map in parallel
s = partial(partition_args, chunk=chunk_size)
pool.map(gen_row_map, s(time_ids), s(ant1), s(ant2), s(rows))

return PartitionData(
time=utime,
interval=uinterval,
chan_freq=chan_freq,
chan_width=uchan_width,
corr_type=corr_type.to_str(),
field_names=field_names,
spw_name=spw_name,
spw_freq_group_name=spw_freq_group_name,
spw_ref_freq=spw_ref_freq,
spw_frame=spw_frame,
row_map=row_map.reshape(utime.size, self.nbl),
scan_numbers=scan_numbers,
sub_scan_numbers=None,
)

def __init__(
self, ms: TableFactory, partition_schema: List[str], auto_corrs: bool = True
):
@@ -582,78 +702,8 @@ def __init__(
self._partitions = {}

for k, v in partitions.items():
time = v["TIME"]
interval = v["INTERVAL"]
ant1 = v["ANTENNA1"]
ant2 = v["ANTENNA2"]
rows = v["row"]

# Compute the unique times and their inverse index
utime, time_ids = self.par_unique(pool, ncpus, time, return_inverse=True)

# Compute unique intervals
uinterval = self.par_unique(pool, ncpus, interval)

try:
ddid = next(i for (c, i) in k if c == "DATA_DESC_ID")
except StopIteration:
raise KeyError(f"DATA_DESC_ID must be present in partition key {k}")

if ddid >= len(self._ddid):
raise InvalidMeasurementSet(
f"DATA_DESC_ID {ddid} does not exist in {name}::DATA_DESCRIPTION"
)

spw_id = self._ddid["SPECTRAL_WINDOW_ID"][ddid].as_py()
pol_id = self._ddid["POLARIZATION_ID"][ddid].as_py()

if spw_id >= len(self._spw):
raise InvalidMeasurementSet(
f"SPECTRAL_WINDOW_ID {spw_id} does not exist in {name}::SPECTRAL_WINDOW"
)

if pol_id >= len(self._pol):
raise InvalidMeasurementSet(
f"POLARIZATION_ID {pol_id} does not exist in {name}::POLARIZATION"
)

corr_type = Polarisations.from_values(self._pol["CORR_TYPE"][pol_id].as_py())
chan_freq = self._spw["CHAN_FREQ"][spw_id].as_py()
uchan_width = np.unique(self._spw["CHAN_WIDTH"][spw_id].as_py())

spw_name = self._spw["NAME"][spw_id].as_py()
spw_freq_group_name = self._spw["FREQ_GROUP_NAME"][spw_id].as_py()
spw_ref_freq = self._spw["REF_FREQUENCY"][spw_id].as_py()
spw_meas_freq_ref = self._spw["MEAS_FREQ_REF"][spw_id].as_py()
spw_frame = FrequencyMeasures(spw_meas_freq_ref).name.lower()

row_map = np.full(utime.size * self.nbl, -1, dtype=np.int64)
chunk_size = (len(rows) + ncpus - 1) // ncpus

def gen_row_map(time_ids, ant1, ant2, rows):
bl_ids = baseline_id(ant1, ant2, self.na, auto_corrs=auto_corrs)
row_map[time_ids * self.nbl + bl_ids] = rows

def split_arg(data, chunk):
return [data[i : i + chunk] for i in range(0, len(data), chunk)]

assert len(ant1) == len(rows)

# Generate the row map in parallel
s = partial(split_arg, chunk=chunk_size)
pool.map(gen_row_map, s(time_ids), s(ant1), s(ant2), s(rows))

self._partitions[k] = PartitionData(
time=utime,
interval=uinterval,
chan_freq=chan_freq,
chan_width=uchan_width,
corr_type=corr_type.to_str(),
spw_name=spw_name,
spw_freq_group_name=spw_freq_group_name,
spw_ref_freq=spw_ref_freq,
spw_frame=spw_frame,
row_map=row_map.reshape(utime.size, self.nbl),
self._partitions[k] = self.partition_data_factory(
name, auto_corrs, k, v, pool, ncpus
)

logger.info("Reading %s structure in took %fs", name, modtime.time() - start)
