Update partitioning key with subtable partitioning columns (SOURCE_ID and OBS_MODE)
sjperkins committed Feb 4, 2025
1 parent 6aebd38 commit d9f4044
Showing 2 changed files with 46 additions and 17 deletions.
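A hedged illustration of the effect of this commit (all values below are invented, not taken from the repository): a resolved partition key previously carried only MAIN table columns, whereas it may now also carry the subtable columns `SOURCE_ID` (from `FIELD`) and `OBS_MODE` (from `STATE`), and string values become legal key values.

```python
# Hypothetical partition keys, for illustration only.
# Before: only MAIN table partitioning columns.
old_key = (("DATA_DESC_ID", 0), ("FIELD_ID", 0), ("OBSERVATION_ID", 0))

# After: subtable partitioning columns may be appended, and string values
# (e.g. an OBS_MODE intent) are permitted alongside integers.
new_key = (
  ("DATA_DESC_ID", 0),
  ("FIELD_ID", 0),
  ("OBSERVATION_ID", 0),
  ("OBS_MODE", "OBSERVE_TARGET#ON_SOURCE"),
  ("SOURCE_ID", 0),
)

assert dict(old_key).keys() < dict(new_key).keys()
```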
16 changes: 8 additions & 8 deletions tests/test_structure.py
@@ -26,16 +26,16 @@ def test_baseline_id(na, auto_corrs):
@pytest.mark.parametrize("simmed_ms", [{"name": "proxy.ms"}], indirect=True)
@pytest.mark.parametrize("epoch", ["abcdef"])
def test_structure_factory(simmed_ms, epoch):
-partition_columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
+partition_schema = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID", "OBS_MODE"]
table_factory = TableFactory(Table.from_filename, simmed_ms)
-structure_factory = MSv2StructureFactory(table_factory, partition_columns, epoch)
+structure_factory = MSv2StructureFactory(table_factory, partition_schema, epoch)
assert pickle.loads(pickle.dumps(structure_factory)) == structure_factory

-structure_factory2 = MSv2StructureFactory(table_factory, partition_columns, epoch)
+structure_factory2 = MSv2StructureFactory(table_factory, partition_schema, epoch)
assert structure_factory() is structure_factory2()

keys = tuple(k for kv in structure_factory().keys() for k, _ in kv)
-assert tuple(sorted(partition_columns)) == keys
+assert tuple(sorted(partition_schema)) == keys


def test_table_partitioner():
@@ -96,13 +96,13 @@ def test_table_partitioner():


def test_epoch(simmed_ms):
-partition_columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
+partition_schema = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
table_factory = TableFactory(Table.from_filename, simmed_ms)
-structure_factory = MSv2StructureFactory(table_factory, partition_columns, "abc")
-structure_factory2 = MSv2StructureFactory(table_factory, partition_columns, "abc")
+structure_factory = MSv2StructureFactory(table_factory, partition_schema, "abc")
+structure_factory2 = MSv2StructureFactory(table_factory, partition_schema, "abc")

assert structure_factory() is structure_factory2()

-structure_factory3 = MSv2StructureFactory(table_factory, partition_columns, "def")
+structure_factory3 = MSv2StructureFactory(table_factory, partition_schema, "def")

assert structure_factory() is not structure_factory3()
47 changes: 38 additions & 9 deletions xarray_ms/backend/msv2/structure.py
@@ -407,10 +407,10 @@ def resolve_key(self, key: str | PartitionKeyT | None) -> List[PartitionKeyT]:
if isinstance(key, str):
key = self.parse_partition_key(key)

-column_set = set(self._partition_columns)
+column_set = set(self._partition_columns) | set(self._subtable_partition_columns)

# Check that the key columns and values are valid
-new_key: List[Tuple[str, int]] = []
+new_key: List[Tuple[str, int | str]] = []
for column, value in key:
column = column.upper()
column = SHORT_TO_LONG_PARTITION_COLUMNS.get(column, column)
@@ -419,8 +419,11 @@
f"{column} is not valid a valid partition column "
f"{self._partition_columns}"
)
-if not isinstance(value, Integral):
-raise InvalidPartitionKey(f"{value} is not a valid partition value")
+if not isinstance(value, (str, Integral)):
+raise InvalidPartitionKey(
+f"{value} is an invalid partition key value. "
+f"Should be an integer or (rarely) a string"
+)
new_key.append((column, value))

key_set = set(new_key)
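A standalone sketch of the relaxed value check above; the real check lives in `MSv2Structure.resolve_key` and raises `InvalidPartitionKey`, while the helper name and exception type here are hypothetical:

```python
from numbers import Integral


def check_partition_value(column: str, value) -> None:
  """Reject values that cannot appear in a partition key."""
  if not isinstance(value, (str, Integral)):
    raise TypeError(
      f"{value!r} is an invalid partition key value for {column}. "
      f"Should be an integer or (rarely) a string"
    )


check_partition_value("SOURCE_ID", 3)           # integers are the common case
check_partition_value("OBS_MODE", "CALIBRATE")  # strings are rare but now allowed
```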
@@ -458,7 +461,12 @@ def par_copy(source, field):
def partition_columns_from_schema(
self, partition_schema: List[str]
) -> Tuple[List[str], List[str]]:
"""Given a partitioning schema, produce a list of partitioning columns"""
"""Given a partitioning schema, produce
1. a list of partitioning columns of the MAIN table
2. a list of subtable partitioning columns.
i.e. `FIELD.SOURCE_ID` and `STATE.OBS_MODE`
"""
schema: Set[str] = set(partition_schema)

# Always partition by these columns
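A minimal sketch of the schema split described by the docstring above; the column sets and the helper name are assumptions for illustration, not the library's definitions:

```python
from typing import List, Set, Tuple

# Assumed defaults: MAIN table columns are always partitioned on, while
# SOURCE_ID and OBS_MODE come from the FIELD and STATE subtables respectively.
MAIN_COLUMNS: Set[str] = {"FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"}
SUBTABLE_COLUMNS: Set[str] = {"SOURCE_ID", "OBS_MODE"}


def split_schema(schema: List[str]) -> Tuple[List[str], List[str]]:
  """Split a partitioning schema into MAIN table and subtable columns."""
  requested = set(schema)
  main = sorted(MAIN_COLUMNS | (requested - SUBTABLE_COLUMNS))
  subtable = sorted(requested & SUBTABLE_COLUMNS)
  return main, subtable


# The schema used in tests/test_structure.py above:
print(split_schema(["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID", "OBS_MODE"]))
# (['DATA_DESC_ID', 'FIELD_ID', 'OBSERVATION_ID'], ['OBS_MODE'])
```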
@@ -566,8 +574,16 @@ def partition_data_factory(
value: Dict[str, npt.NDArray],
pool: cf.Executor,
ncpus: int,
-) -> PartitionData:
-"""Generate a `PartitionData` object"""
+) -> Tuple[PartitionKeyT, PartitionData]:
+"""Generate an updated partition key and
+`PartitionData` object.
+The partition key is updated with subtable partitioning keys
+(primarily `FIELD.SOURCE_ID` and `STATE.OBS_MODE`).
+The `PartitionData` object represents a summary of the
+partition data passed in via arguments.
+"""
time = value["TIME"]
interval = value["INTERVAL"]
ant1 = value["ANTENNA1"]
@@ -601,6 +617,11 @@ def partition_data_factory(
fields = self._field.take(ufield_ids)
field_names = fields["NAME"].to_pylist()
source_ids = fields["SOURCE_ID"].to_pylist()

if "SOURCE_ID" in self._subtable_partition_columns:
assert len(source_ids) == 1
key += (("SOURCE_ID", source_ids[0]),)

# Select out SOURCES if we have the table
if hasattr(self, "_source") and len(self._source) > 0:
sources = self._source.take(source_ids)
@@ -626,6 +647,10 @@ def partition_data_factory(
intents = states["OBS_MODE"].to_numpy(zero_copy_only=False).tolist()
sub_scan_numbers = states["SUB_SCAN"].to_numpy(zero_copy_only=False).tolist()

if "OBS_MODE" in self._subtable_partition_columns:
assert len(intents) == 1
key += (("OBS_MODE", intents[0]),)

# Extract polarization information
pol_id = self._ddid["POLARIZATION_ID"][ddid].as_py()
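The key-extension pattern above, as a standalone sketch (the function name and inputs are illustrative): each partition is expected to contain exactly one `SOURCE_ID` and one `OBS_MODE`, which are appended to the key when they form part of the partitioning schema.

```python
def extend_key(key, subtable_partition_columns, source_ids, intents):
  """Append subtable values to a MAIN table partition key."""
  if "SOURCE_ID" in subtable_partition_columns:
    assert len(source_ids) == 1
    key += (("SOURCE_ID", source_ids[0]),)
  if "OBS_MODE" in subtable_partition_columns:
    assert len(intents) == 1
    key += (("OBS_MODE", intents[0]),)
  return tuple(sorted(key))


key = (("DATA_DESC_ID", 0), ("FIELD_ID", 1), ("OBSERVATION_ID", 0))
print(extend_key(key, {"SOURCE_ID", "OBS_MODE"}, [2], ["OBSERVE_TARGET"]))
```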

Expand Down Expand Up @@ -665,7 +690,7 @@ def gen_row_map(time_ids, ant1, ant2, rows):
s = partial(partition_args, chunk=chunk_size)
pool.map(gen_row_map, s(time_ids), s(ant1), s(ant2), s(rows))

-return PartitionData(
+partition_data = PartitionData(
time=utime,
interval=uinterval,
chan_freq=chan_freq,
@@ -684,6 +709,8 @@ def gen_row_map(time_ids, ant1, ant2, rows):
sub_scan_numbers=sub_scan_numbers,
)

+return tuple(sorted(key)), partition_data

def __init__(
self, ms: TableFactory, partition_schema: List[str], auto_corrs: bool = True
):
@@ -697,6 +724,7 @@ def __init__(

self._ms_factory = ms
self._partition_columns = partition_columns
+self._subtable_partition_columns = subtable_columns
self._auto_corrs = auto_corrs

ms_table = ms()
@@ -733,8 +761,9 @@ def __init__(
self._partitions = {}

for k, v in partitions.items():
-self._partitions[k] = self.partition_data_factory(
+key, partition = self.partition_data_factory(
name, auto_corrs, k, v, pool, ncpus
)
+self._partitions[key] = partition

logger.info("Reading %s structure in took %fs", name, modtime.time() - start)
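Putting it together, a hedged sketch of the resulting mapping (identifiers and values are hypothetical): partitions are now stored under the extended, sorted key, so subtable values are visible alongside the MAIN table columns when iterating over the structure.

```python
# Hypothetical contents of the structure's partition mapping after this change.
partitions = {
  (
    ("DATA_DESC_ID", 0),
    ("FIELD_ID", 0),
    ("OBSERVATION_ID", 0),
    ("OBS_MODE", "OBSERVE_TARGET"),
    ("SOURCE_ID", 0),
  ): "PartitionData(...)",
}

for key, data in partitions.items():
  print(dict(key)["OBS_MODE"], data)
```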
