Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use epoch argument to distinguish multiple instances of the same dataset #54

Merged
merged 2 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Changelog

X.Y.Z (DD-MM-YYYY)
------------------
* Use epoch to distinguish multiple instances of the same dataset (:pr:`54`)
* Use np.logical_or.reduce for generating diffs over more than 2 partitioning arrays (:pr:`53`)
* Improve Missing Column error (:pr:`52`)
* Fix `open_datatree` instructions in the README (:pr:`51`)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_row_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ def irregular_ms(tmp_path, request):
@pytest.mark.parametrize("auto_corrs", [True, False])
def test_row_mapping(irregular_ms, na, auto_corrs):
table_factory = TableFactory(Table.from_filename, irregular_ms)
structure_factory = MSv2StructureFactory(table_factory, auto_corrs=auto_corrs)
structure_factory = MSv2StructureFactory(
table_factory, [], "abcdef", auto_corrs=auto_corrs
)
structure = structure_factory()

ddid = table_factory().getcol("DATA_DESC_ID")
Expand Down
20 changes: 17 additions & 3 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ def test_baseline_id(na, auto_corrs):


@pytest.mark.parametrize("simmed_ms", [{"name": "proxy.ms"}], indirect=True)
def test_structure_factory(simmed_ms):
@pytest.mark.parametrize("epoch", ["abcdef"])
def test_structure_factory(simmed_ms, epoch):
partition_columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
table_factory = TableFactory(Table.from_filename, simmed_ms)
structure_factory = MSv2StructureFactory(table_factory, partition_columns)
structure_factory = MSv2StructureFactory(table_factory, partition_columns, epoch)
assert pickle.loads(pickle.dumps(structure_factory)) == structure_factory

structure_factory2 = MSv2StructureFactory(table_factory, partition_columns)
structure_factory2 = MSv2StructureFactory(table_factory, partition_columns, epoch)
assert structure_factory() is structure_factory2()

keys = tuple(k for kv in structure_factory().keys() for k, _ in kv)
Expand Down Expand Up @@ -84,3 +85,16 @@ def test_table_partitioner():
},
},
)


def test_epoch(simmed_ms):
    """Check that the epoch argument separates otherwise identical factories.

    Two factories built from the same table factory, partition columns and
    epoch must return the identical structure object, while a factory that
    differs only in its epoch must return a different object.
    """
    columns = ["FIELD_ID", "DATA_DESC_ID", "OBSERVATION_ID"]
    tables = TableFactory(Table.from_filename, simmed_ms)

    def factory_for(epoch):
        # Only the epoch varies between the factories under test.
        return MSv2StructureFactory(tables, columns, epoch)

    first = factory_for("abc")
    twin = factory_for("abc")

    assert first() is twin()

    other_epoch = factory_for("def")

    assert first() is not other_epoch()
2 changes: 1 addition & 1 deletion xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def initialise_default_args(
epoch = epoch or uuid4().hex[:8]
partition_columns = partition_columns or DEFAULT_PARTITION_COLUMNS
structure_factory = structure_factory or MSv2StructureFactory(
table_factory, partition_columns, auto_corrs=auto_corrs
table_factory, partition_columns, epoch, auto_corrs=auto_corrs
)
return epoch, table_factory, partition_columns, structure_factory

Expand Down
15 changes: 12 additions & 3 deletions xarray_ms/backend/msv2/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,22 @@ class MSv2StructureFactory:

_ms_factory: TableFactory
_partition_columns: List[str]
_epoch: str
_auto_corrs: bool
_STRUCTURE_CACHE: ClassVar[Cache] = Cache(
maxsize=100, ttl=60, on_get=on_get_keep_alive
)

def __init__(
    self,
    ms: TableFactory,
    partition_columns: List[str],
    epoch: str,
    auto_corrs: bool = True,
):
    """Store the inputs that identify a structure.

    Args:
        ms: Factory producing the Measurement Set table.
        partition_columns: Names of the columns used for partitioning.
        epoch: String that distinguishes multiple instances of the same
            dataset; it takes part in equality, hashing and pickling.
        auto_corrs: Stored as ``_auto_corrs``; defaults to ``True``.
    """
    # A single signature: the pre-epoch one-line parameter list must not be
    # kept next to this one, it would duplicate ``self`` and drop ``epoch``.
    self._ms_factory = ms
    self._partition_columns = partition_columns
    self._epoch = epoch
    self._auto_corrs = auto_corrs

def __eq__(self, other: Any) -> bool:
Expand All @@ -248,16 +254,19 @@ def __eq__(self, other: Any) -> bool:
return (
self._ms_factory == other._ms_factory
and self._partition_columns == other._partition_columns
and self._epoch == other._epoch
and self._auto_corrs == other._auto_corrs
)

def __hash__(self):
    """Return a hash over every field that ``__eq__`` compares.

    The partition columns are converted to a tuple because lists are not
    hashable. The epoch is part of the key so that factories differing only
    in epoch do not hash alike.
    """
    # Single return: an earlier epoch-less return would shadow this one.
    return hash(
        (
            self._ms_factory,
            tuple(self._partition_columns),
            self._epoch,
            self._auto_corrs,
        )
    )

def __reduce__(self):
    """Return the ``(callable, args)`` pair pickle uses to rebuild this object.

    The argument tuple follows the positional order of ``__init__``:
    table factory, partition columns, epoch, auto_corrs.
    """
    # Exactly one argument tuple, and it carries the epoch. A leftover
    # three-element tuple without the epoch would be taken by pickle as the
    # constructor arguments and the epoch would be lost.
    return (
        MSv2StructureFactory,
        (
            self._ms_factory,
            self._partition_columns,
            self._epoch,
            self._auto_corrs,
        ),
    )

def __call__(self, *args, **kw) -> MSv2Structure:
Expand Down