Commit

WIP
sjperkins committed Jan 30, 2025
1 parent 23d3ceb commit 8e8aa0d
Showing 2 changed files with 30 additions and 10 deletions.
28 changes: 18 additions & 10 deletions xarray_ms/backend/msv2/structure.py
@@ -163,13 +163,13 @@ class PartitionData:
chan_freq: npt.NDArray[np.float64]
chan_width: npt.NDArray[np.float64]
corr_type: npt.NDArray[np.int32]
- field_names: List[str] | None
+ field_names: List[str]
spw_name: str
spw_freq_group_name: str
spw_ref_freq: float
spw_frame: str
- scan_numbers: List[int] | None
- sub_scan_numbers: List[int] | None
+ scan_numbers: List[int]
+ sub_scan_numbers: List[int]

row_map: npt.NDArray[np.int64]

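The three fields above switch from Optional lists to plain lists, so downstream code can iterate them without a None check. A minimal sketch of the pattern, assuming an empty list now stands in for an absent value (the field subset and defaults below are illustrative, not the actual PartitionData definition):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class PartitionDataSketch:
        # default_factory gives each instance its own empty list,
        # avoiding the shared-mutable-default pitfall
        field_names: List[str] = field(default_factory=list)
        scan_numbers: List[int] = field(default_factory=list)
        sub_scan_numbers: List[int] = field(default_factory=list)

    pd = PartitionDataSketch()
    for name in pd.field_names:  # safe even when empty
        print(name)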
@@ -208,8 +208,16 @@ def partition(
nworkers = pool._max_workers
chunk = (nrow + nworkers - 1) // nworkers

- all_columns = index.column_names
+ # Order columns by
+ #
+ # 1. Partitioning columns
+ # 2. Sorting columns
+ # 3. Others (such as row and INTERVAL)
+ # 4. Remaining columns
+ #
+ # 4 is needed for merge_np_partitions to work
+ ordered_columns = self._partitionby + self._sortby + self._other
+ ordered_columns += list(set(index.column_names) - set(ordered_columns))

# Create a dictionary out of the pyarrow table
table_dict = {k: index[k].to_numpy() for k in ordered_columns}
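The ordering logic can be exercised standalone. A sketch with illustrative column names rather than the real MSv2 schema; note that the set difference leaves the trailing columns in arbitrary order, which only matters if downstream code expects a fixed position for them:

    partitionby = ["DATA_DESC_ID", "FIELD_ID"]
    sortby = ["TIME", "ANTENNA1", "ANTENNA2"]
    other = ["row", "INTERVAL"]
    column_names = ["ANTENNA1", "ANTENNA2", "DATA_DESC_ID", "FEED1",
                    "FIELD_ID", "INTERVAL", "TIME", "row"]

    # Partitioning, sorting and "other" columns first, then the rest
    ordered_columns = partitionby + sortby + other
    ordered_columns += list(set(column_names) - set(ordered_columns))
    print(ordered_columns)
    # ['DATA_DESC_ID', 'FIELD_ID', 'TIME', 'ANTENNA1', 'ANTENNA2',
    #  'row', 'INTERVAL', 'FEED1']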
@@ -252,8 +260,7 @@ def find_edges(p, s):

for start, end in zip(group_offsets[:-1], group_offsets[1:]):
key = tuple(sorted((k, merged[k][start].item()) for k in self._partitionby))
- data = {k: merged[k][start:end] for k in self._sortby + self._other}
- groups[key] = data
+ groups[key] = {k: v[start:end] for k, v in merged.items()}

return groups

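A self-contained sketch of the grouping step above, assuming merged holds the sorted per-column numpy arrays and group_offsets marks the rows where the partition key changes (toy values below):

    import numpy as np

    merged = {
        "FIELD_ID": np.array([0, 0, 0, 1, 1]),
        "TIME": np.array([0.0, 1.0, 2.0, 0.0, 1.0]),
    }
    partitionby = ["FIELD_ID"]
    group_offsets = np.array([0, 3, 5])

    groups = {}
    for start, end in zip(group_offsets[:-1], group_offsets[1:]):
        key = tuple(sorted((k, merged[k][start].item()) for k in partitionby))
        # slice every merged column, not just the sort/other columns,
        # so each group carries the complete column set downstream
        groups[key] = {k: v[start:end] for k, v in merged.items()}

    print(groups[(("FIELD_ID", 0),)]["TIME"])  # [0. 1. 2.]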
@@ -602,17 +609,17 @@ def partition_data_factory(
source_id = value.get("SOURCE_ID")
state_id = value.get("STATE_ID")

- field_names: List[str] | None = None
+ field_names: List[str] = []

if field_id is not None and len(self._field) > 0:
ufield_ids = self.par_unique(pool, ncpus, field_id)
fields = self._field.take(ufield_ids)
- field_names = fields["NAME"].to_numpy()
+ field_names = fields["NAME"].unique().to_numpy(zero_copy_only=False).tolist()

- scan_numbers: List[int] | None = None
+ scan_numbers: List[int] = []

if scan_number is not None:
- scan_numbers = list(self.par_unique(pool, ncpus, scan_number))
+ scan_numbers = self.par_unique(pool, ncpus, scan_number).tolist()

corr_type = Polarisations.from_values(self._pol["CORR_TYPE"][pol_id].as_py())
chan_freq = self._spw["CHAN_FREQ"][spw_id].as_py()
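The field_names change relies on a common pyarrow idiom: unique() deduplicates, zero_copy_only=False is needed because string arrays cannot be converted to numpy zero-copy, and tolist() yields the plain List[str] the dataclass now expects. A hedged sketch with made-up field names:

    import pyarrow as pa

    names = pa.chunked_array([["3C286", "3C286", "PKS1934-638"]])
    field_names = names.unique().to_numpy(zero_copy_only=False).tolist()
    print(field_names)  # ['3C286', 'PKS1934-638']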
@@ -705,5 +712,6 @@ def __init__(
self._partitions[k] = self.partition_data_factory(
name, auto_corrs, k, v, pool, ncpus
)
+ print(dataclasses.asdict(self._partitions[k]))

logger.info("Reading %s structure in took %fs", name, modtime.time() - start)
12 changes: 12 additions & 0 deletions xarray_ms/testing/simulator.py
@@ -103,6 +103,7 @@ class MSStructureSimulator:

ntime: int
nantenna: int
+ nfield: int
auto_corrs: bool
dump_rate: float
time_chunks: int
@@ -175,6 +176,7 @@ def __init__(

self.ntime = ntime
self.nantenna = nantenna
+ self.nfield = nfield
self.auto_corrs = auto_corrs
self.dump_rate = dump_rate
self.time_chunks = time_chunks
@@ -197,6 +199,7 @@ def nfeed(self) -> int:
def simulate_ms(self, output_ms: str) -> None:
"""Simulate data into the given measurement set name"""
table_desc = ADDITIONAL_COLUMNS if self.simulate_data else {}
+
# Generate descriptors, create simulated data from the descriptors
# and write simulated data to the main Measurement Set
with Table.ms_from_descriptor(output_ms, "MAIN", table_desc) as T:
@@ -279,6 +282,15 @@ def simulate_ms(self, output_ms: str) -> None:
T.putcol("MOUNT", np.asarray(["ALT-AZ" for _ in range(self.nantenna)]))
T.putcol("STATION", np.asarray([f"STATION-{i}" for i in range(self.nantenna)]))

+ with Table.from_filename(f"{output_ms}::FIELD", **kw) as T:
+   T.addrows(self.nfield)
+   T.putcol("NAME", np.asarray([f"FIELD-{i}" for i in range(self.nfield)]))
+   T.putcol("SOURCE_ID", np.arange(self.nfield))
+
+ with Table.ms_from_descriptor(output_ms, "SOURCE") as T:
+   T.addrows(self.nfield)
+   T.putcol("NAME", np.asarray([f"SOURCE-{i}" for i in range(self.nfield)]))

def generate_descriptors(self) -> Generator[PartitionDescriptor, None, None]:
"""Generates a sequence of descriptors, each corresponding to a partition"""
chunk_id = 0
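For context, a hypothetical end-to-end use of the simulator; the constructor keywords are an assumption inferred from the __init__ assignments in this diff, and the argument values are arbitrary:

    simulator = MSStructureSimulator(
        ntime=10,
        nantenna=4,
        nfield=2,  # exercises the new FIELD/SOURCE subtable rows
        auto_corrs=True,
        dump_rate=8.0,
        time_chunks=2,
    )
    simulator.simulate_ms("/tmp/test.ms")  # writes MAIN plus the subtables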