Parallelise row partitioning (#30)
sjperkins authored Oct 4, 2024
1 parent: aabca42 · commit: 8734bec
Showing 3 changed files with 125 additions and 53 deletions.
doc/source/changelog.rst (2 changes: 1 addition & 1 deletion)
@@ -5,8 +5,8 @@ Changelog

X.Y.Z (DD-MM-YYYY)
------------------
* Parallelise row partitioning (:pr:`28`, :pr:`30`)
* Upgrade to arcae 0.2.5 (:pr:`29`)
* Parallelise row map generation (:pr:`28`)
* Rename antenna{1,2}_name to baseline_antenna{1,2}_name (:pr:`26`)
* Update Cloud Storage write documentation (:pr:`25`, :pr:`27`)
* Use datatree as the primary representation (:pr:`24`)
tests/test_structure.py (59 changes: 57 additions & 2 deletions)
@@ -1,11 +1,17 @@
import concurrent.futures as cf
import pickle

import numpy as np
import pyarrow as pa
import pytest
from arcae.lib.arrow_tables import Table
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_equal

from xarray_ms.backend.msv2.structure import MSv2StructureFactory, baseline_id
from xarray_ms.backend.msv2.structure import (
MSv2StructureFactory,
TablePartitioner,
baseline_id,
)
from xarray_ms.backend.msv2.table_factory import TableFactory


@@ -29,3 +35,52 @@ def test_structure_factory(simmed_ms):

keys = tuple(k for kv in structure_factory().keys() for k, _ in kv)
assert tuple(sorted(partition_columns)) == keys


def test_table_partitioner():
table = pa.Table.from_pydict(
{
"DATA_DESC_ID": pa.array([1, 1, 0, 0, 0], pa.int32()),
"FIELD_ID": pa.array([1, 0, 0, 1, 0], pa.int32()),
"TIME": pa.array([0, 1, 4, 3, 2], pa.float64()),
"ANTENNA1": pa.array([0, 0, 0, 0, 0], pa.int32()),
"ANTENNA2": pa.array([1, 1, 1, 1, 1], pa.int32()),
}
)

partitioner = TablePartitioner(
["DATA_DESC_ID", "FIELD_ID"], ["TIME", "ANTENNA1", "ANTENNA2"], ["row"]
)

with cf.ThreadPoolExecutor(max_workers=4) as pool:
groups = partitioner.partition(table, pool)

assert_equal(
groups,
{
(("DATA_DESC_ID", 0), ("FIELD_ID", 0)): {
"TIME": [2.0, 4.0],
"ANTENNA1": [0, 0],
"ANTENNA2": [1, 1],
"row": [4, 2],
},
(("DATA_DESC_ID", 0), ("FIELD_ID", 1)): {
"TIME": [3.0],
"ANTENNA1": [0],
"ANTENNA2": [1],
"row": [3],
},
(("DATA_DESC_ID", 1), ("FIELD_ID", 0)): {
"TIME": [1.0],
"ANTENNA1": [0],
"ANTENNA2": [1],
"row": [1],
},
(("DATA_DESC_ID", 1), ("FIELD_ID", 1)): {
"TIME": [0.0],
"ANTENNA1": [0],
"ANTENNA2": [1],
"row": [0],
},
},
)
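
Note (illustrative, not part of the diff): the expected row order above, e.g. [4, 2] for the (DATA_DESC_ID=0, FIELD_ID=0) group, follows from sorting each partition by TIME, then ANTENNA1, then ANTENNA2, carrying the original row indices along. A minimal numpy sketch of that ordering, using only values from the test table:

import numpy as np

# Rows of the test table that fall into the (DATA_DESC_ID=0, FIELD_ID=0) group
rows = np.array([2, 4])
time = np.array([4.0, 2.0])
ant1 = np.array([0, 0])
ant2 = np.array([1, 1])

# np.lexsort sorts by its *last* key first, so pass (ANTENNA2, ANTENNA1, TIME)
order = np.lexsort((ant2, ant1, time))
print(time[order])  # [2. 4.]
print(rows[order])  # [4 2] -- matches the expected "row" entry for this group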
xarray_ms/backend/msv2/structure.py (117 changes: 67 additions & 50 deletions)
@@ -5,7 +5,7 @@
import logging
import multiprocessing as mp
from collections import defaultdict
from functools import partial
from functools import partial, reduce
from numbers import Integral
from typing import (
Any,
@@ -18,10 +18,10 @@
Tuple,
)

import arcae
import numpy as np
import numpy.typing as npt
import pyarrow as pa
from arcae.lib.arrow_tables import Table, merge_np_partitions
from cacheout import Cache
from xarray.core.utils import FrozenDict

@@ -125,7 +125,7 @@ def is_partition_key(key: PartitionKeyT) -> bool:

@dataclasses.dataclass
class PartitionData:
"""Dataclass described data unique to a partition"""
"""Dataclass describing data unique to a partition"""

time: npt.NDArray[np.float64]
interval: npt.NDArray[np.float64]
@@ -156,8 +156,9 @@ def __init__(
self._sortby = list(sortby)
self._other = list(other)

def partition(self, index: pa.Table) -> Dict[PartitionKeyT, pa.Table]:
sortby = set(self._sortby)
def partition(
self, index: pa.Table, pool: cf.ThreadPoolExecutor
) -> Dict[PartitionKeyT, Dict[str, npt.NDArray]]:
other = set(self._other)

try:
@@ -168,33 +169,52 @@ def partition(self, index: pa.Table) -> Dict[PartitionKeyT, pa.Table]:
except KeyError:
pass

maybe_row = {"row"} if "row" in index.column_names else set()
read_columns = sortby | other
if not read_columns.issubset(set(index.column_names) - maybe_row):
raise ValueError(f"{read_columns} is not a subset of {index.column_names}")
nrow = len(index)
nworkers = pool._max_workers
chunk = (nrow + nworkers - 1) // nworkers

agg_cmd = [
(c, "list") for c in (maybe_row | set(read_columns) - set(self._partitionby))
ordered_columns = self._partitionby + self._sortby + self._other

# Create a dictionary out of the pyarrow table
table_dict = {k: index[k].to_numpy() for k in ordered_columns}
# Partition the range over the workers in the pool
partitions = [
{k: v[s : s + chunk] for k, v in table_dict.items()}
for s in range(0, nrow, chunk)
]
partitions = index.group_by(self._partitionby).aggregate(agg_cmd)
renames = {f"{c}_list": c for c, _ in agg_cmd}
partitions = partitions.rename_columns(
renames.get(c, c) for c in partitions.column_names
)

partition_map: Dict[PartitionKeyT, pa.Table] = {}
def sort_partition(p):
indices = np.lexsort(tuple(reversed(p.values())))
return {k: v[indices] for k, v in p.items()}

for p in range(len(partitions)):
key: PartitionKeyT = tuple(
sorted((c, int(partitions[c][p].as_py())) for c in self._partitionby)
)
table_dict = {c: partitions[c][p].values for c in read_columns | maybe_row}
partition_table = pa.Table.from_pydict(table_dict)
if sortby:
partition_table = partition_table.sort_by([(c, "ascending") for c in sortby])
partition_map[key] = partition_table
# Sort each partition in parallel
partitions = list(pool.map(sort_partition, partitions))
# Merge partitions
merged = merge_np_partitions(partitions)

# Find the group start and end points in parallel
def find_edges(p, s):
diffs = [np.diff(p[v]) > 0 for v in self._partitionby]
return np.where(np.logical_or(*diffs))[0] + s + 1

group_diffs = [
{k: v[s : s + chunk + 1] for k, v in merged.items() if k in self._partitionby}
for s in range(0, nrow, chunk)
]
starts = reduce(lambda x, y: x + [x[-1] + y], [chunk] * (len(group_diffs) - 1), [0])
assert len(starts) == len(group_diffs)
edges = list(pool.map(find_edges, group_diffs, starts))
group_offsets = np.concatenate([[0]] + edges + [[nrow]])

# Create the grouped partitions
groups = {}

return partition_map
for start, end in zip(group_offsets[:-1], group_offsets[1:]):
key = tuple(sorted((k, merged[k][start].item()) for k in self._partitionby))
data = {k: merged[k][start:end] for k in self._sortby + self._other}
groups[key] = data

return groups
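
Note (illustrative, not part of the diff): the new partition method follows a chunk, parallel lexsort, merge, edge-detect pattern: the index is split across the pool's workers, each chunk is sorted by the partitioning, sorting and remaining columns, the sorted chunks are combined with arcae's merge_np_partitions, and group boundaries are located where a partitioning column changes value. The self-contained sketch below mirrors that shape with plain numpy; it substitutes a global re-sort for the merge step and a simple inequality test for the edge detection, so it illustrates the idea rather than reproducing the implementation.

import concurrent.futures as cf

import numpy as np


def partition_sketch(table_dict, partitionby, sortby, nworkers=4):
    """Toy version of the chunk/sort/merge/group pattern used above.

    table_dict maps column names to equal-length numpy arrays. The real
    code merges the pre-sorted chunks with arcae's merge_np_partitions;
    here the chunks are concatenated and re-sorted, which is slower but
    keeps the sketch self-contained.
    """
    nrow = len(next(iter(table_dict.values())))
    chunk = (nrow + nworkers - 1) // nworkers  # ceiling division, as in the diff
    keys = list(partitionby) + list(sortby)

    def sort_chunk(start):
        piece = {k: v[start : start + chunk] for k, v in table_dict.items()}
        # np.lexsort treats its last key as the primary key, hence the reversal
        order = np.lexsort(tuple(reversed([piece[k] for k in keys])))
        return {k: v[order] for k, v in piece.items()}

    with cf.ThreadPoolExecutor(max_workers=nworkers) as pool:
        sorted_chunks = list(pool.map(sort_chunk, range(0, nrow, chunk)))

    # Stand-in for merge_np_partitions: concatenate and globally re-sort
    merged = {k: np.concatenate([c[k] for c in sorted_chunks]) for k in table_dict}
    order = np.lexsort(tuple(reversed([merged[k] for k in keys])))
    merged = {k: v[order] for k, v in merged.items()}

    # A group boundary occurs wherever any partitioning column changes value
    changes = np.zeros(max(nrow - 1, 0), dtype=bool)
    for k in partitionby:
        changes |= np.diff(merged[k]) != 0
    offsets = np.concatenate([[0], np.where(changes)[0] + 1, [nrow]])

    # Slice the merged columns into per-partition groups
    groups = {}
    for start, end in zip(offsets[:-1], offsets[1:]):
        key = tuple(sorted((k, merged[k][start].item()) for k in partitionby))
        groups[key] = {
            k: v[start:end] for k, v in merged.items() if k not in partitionby
        }
    return groups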


def on_get_keep_alive(key, value, exists):
@@ -376,65 +396,62 @@ def __init__(
{c: ColumnDesc.from_descriptor(c, table_desc) for c in table.columns()}
)

with arcae.table(f"{name}::ANTENNA", lockoptions="nolock") as A:
with Table.from_filename(f"{name}::ANTENNA", lockoptions="nolock") as A:
self._ant = A.to_arrow()
table_desc = A.tabledesc()
col_descs["ANTENNA"] = FrozenDict(
{c: ColumnDesc.from_descriptor(c, table_desc) for c in A.columns()}
)

with arcae.table(f"{name}::FEED", lockoptions="nolock") as F:
with Table.from_filename(f"{name}::FEED", lockoptions="nolock") as F:
self._feed = F.to_arrow()
table_desc = F.tabledesc()
col_descs["FEED"] = FrozenDict(
{c: ColumnDesc.from_descriptor(c, table_desc) for c in F.columns()}
)

with arcae.table(f"{name}::DATA_DESCRIPTION", lockoptions="nolock") as D:
with Table.from_filename(f"{name}::DATA_DESCRIPTION", lockoptions="nolock") as D:
self._ddid = D.to_arrow()
table_desc = D.tabledesc()
col_descs["DATA_DESCRIPTION"] = FrozenDict(
{c: ColumnDesc.from_descriptor(c, table_desc) for c in D.columns()}
)

with arcae.table(f"{name}::SPECTRAL_WINDOW", lockoptions="nolock") as S:
with Table.from_filename(f"{name}::SPECTRAL_WINDOW", lockoptions="nolock") as S:
self._spw = S.to_arrow()
table_desc = S.tabledesc()
col_descs["SPECTRAL_WINDOW"] = FrozenDict(
{c: ColumnDesc.from_descriptor(c, table_desc) for c in S.columns()}
)

with arcae.table(f"{name}::POLARIZATION", lockoptions="nolock") as P:
with Table.from_filename(f"{name}::POLARIZATION", lockoptions="nolock") as P:
self._pol = P.to_arrow()
table_desc = P.tabledesc()
col_descs["POLARIZATION"] = FrozenDict(
{c: ColumnDesc.from_descriptor(c, table_desc) for c in P.columns()}
)

self._column_descs = FrozenDict(col_descs)

other_columns = ["INTERVAL"]
read_columns = set(partition_columns) | set(SORT_COLUMNS) | set(other_columns)
index = table.to_arrow(columns=read_columns)
partitions = TablePartitioner(
partition_columns, SORT_COLUMNS, other_columns + ["row"]
).partition(index)

self._partitions = {}

ncpus = mp.cpu_count()
unique_inv_fn = partial(np.unique, return_inverse=True)

with cf.ThreadPoolExecutor(max_workers=ncpus) as pool:
other_columns = ["INTERVAL"]
read_columns = set(partition_columns) | set(SORT_COLUMNS) | set(other_columns)
partitions = TablePartitioner(
partition_columns, SORT_COLUMNS, other_columns + ["row"]
).partition(table.to_arrow(columns=read_columns), pool)
self._partitions = {}

unique_inv_fn = partial(np.unique, return_inverse=True)

for k, v in partitions.items():
time = v["TIME"].to_numpy()
interval = v["INTERVAL"].to_numpy()
ant1 = v["ANTENNA1"].to_numpy()
ant2 = v["ANTENNA2"].to_numpy()
rows = v["row"].to_numpy()
time = v["TIME"]
interval = v["INTERVAL"]
ant1 = v["ANTENNA1"]
ant2 = v["ANTENNA2"]
rows = v["row"]

# Compute the unique times and their inverse index
chunk_size = len(time) // ncpus
chunk_size = (len(time) + ncpus - 1) // ncpus
time_chunks = [
time[i : i + chunk_size] for i in range(0, len(time), chunk_size)
]
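
Note (illustrative, not part of the diff): the hunk above is truncated. It chunks the partition's TIME column across ncpus workers and maps unique_inv_fn (np.unique with return_inverse=True) over the chunks. One way to combine such per-chunk results into global unique times and a global inverse index is sketched below; this is an assumption about the follow-on step, not necessarily the project's code.

import numpy as np


def combine_unique(results):
    """Combine per-chunk (uniques, inverse) pairs, e.g. the output of
    pool.map(unique_inv_fn, time_chunks), into a global pair."""
    # Global unique values across all chunks
    utime = np.unique(np.concatenate([u for u, _ in results]))
    # Remap each chunk's local inverse into the global unique array:
    # a chunk value is u[inv], so its global index is searchsorted(utime, u)[inv]
    inverse = np.concatenate(
        [np.searchsorted(utime, u)[inv] for u, inv in results]
    )
    return utime, inverse

Because the chunks are contiguous slices of the partition, concatenating the remapped inverses preserves row order, so utime[inverse] reconstructs the original TIME column.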
