From 482873a41f2d79a4c26d5e1816fb433bf28351ce Mon Sep 17 00:00:00 2001
From: Simon Perkins
Date: Fri, 11 Nov 2022 19:32:29 +0200
Subject: [PATCH] Support optional CASA columns (#270)

---
 .github/workflows/ci.yml                    |  61 ++++-
 HISTORY.rst                                 |   4 +
 conftest.py                                 |  33 +--
 daskms/apps/conftest.py                     |  67 +++++
 daskms/apps/convert.py                      | 275 ++-----------------
 daskms/apps/formats.py                      | 278 ++++++++++++++++++++
 daskms/apps/tests/test_convert.py           |  57 ++++
 daskms/descriptors/ms.py                    |   4 +-
 daskms/descriptors/ms_subtable.py           |   4 +-
 daskms/experimental/zarr/tests/test_zarr.py |   1 -
 daskms/reads.py                             |   8 +
 daskms/tests/test_optional.py               |   4 +-
 daskms/writes.py                            |   6 +-
 13 files changed, 521 insertions(+), 281 deletions(-)
 create mode 100644 daskms/apps/conftest.py
 create mode 100644 daskms/apps/formats.py
 create mode 100644 daskms/apps/tests/test_convert.py

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 04bf4ae1..c9ac37b0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -44,11 +44,6 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
 
-      - name: Checkout source
-        uses: actions/checkout@v2
-        with:
-          fetch-depth: 1
-
       - name: Cache Installations
         id: cache-installs
         uses: actions/cache@v3
@@ -103,8 +98,62 @@ jobs:
       #   if: ${{ failure() }}
      #   uses: mxschmitt/action-tmate@v3
 
+  test_apps:
+    needs: check_skip
+    runs-on: ubuntu-latest
+    if: "!contains(github.event.head_commit.message, '[skip ci]')"
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+
+    steps:
+      - name: Create Cache Hash
+        run: |
+          export HASH=$(sha256sum <> $GITHUB_ENV
+
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Cache Installations
+        id: cache-installs
+        uses: actions/cache@v3
+        with:
+          path: ~/.local
+          key: install-${{ env.INSTALL_CACHE_HASH }}-0
+
+      - name: Install Poetry
+        if: steps.cache-installs.outputs.cache-hit != 'true'
+        run: |
+          curl -sSL https://install.python-poetry.org | python3 - --version ${{ env.POETRY_VERSION }}
+
+      - name: Test poetry run
+        run: poetry --version
+
+      - name: Checkout source
+        uses: actions/checkout@v2
+        with:
+          fetch-depth: 1
+
+      - name: Install dask-ms complete
+        run: poetry install --extras "testing complete"
+
+      - name: Test dask-ms applications
+        run: poetry run py.test -s -vvv --applications daskms/
+
+      # - name: Debug with tmate on failure
+      #   if: ${{ failure() }}
+      #   uses: mxschmitt/action-tmate@v3
+
   deploy:
-    needs: [test]
+    needs: [test, test_apps]
     runs-on: ubuntu-latest
     # Run on a push to a tag or master
     if: >

diff --git a/HISTORY.rst b/HISTORY.rst
index 722e1efa..fb8278e5 100644
--- a/HISTORY.rst
+++ b/HISTORY.rst
@@ -2,6 +2,10 @@ History
 =======
 
+X.Y.Z (YYYY-MM-DD)
+------------------
+* Support optional CASA columns (:pr:`270`)
+
 0.2.15 (2022-10-19)
 -------------------
 * Fix poetry install and cache hit detection on CI (:pr:`266`)

diff --git a/conftest.py b/conftest.py
index a4765711..d22915a7 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,11 +1,4 @@
 # -*- coding: utf-8 -*-
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from os.path import join as pjoin
-
 collect_ignore = ["setup.py"]
 
@@ -24,24 +17,26 @@ def pytest_addoption(parser):
         default=False,
         help="Enable optional tests",
     )
+    parser.addoption(
+        "--applications",
+        action="store_true",
+        dest="applications",
+        default=False,
+        help="Enable application tests",
+    )
 
 
 def pytest_configure(config):
     # Add non-standard markers
     config.addinivalue_line("markers", "stress: long running stress tests")
     config.addinivalue_line("markers", "optional: optional tests")
+    config.addinivalue_line("markers", "applications: application tests")
 
-    # Enable/disable them based on parsed config
-    disable_str = []
-
-    if not config.option.stress:
-        disable_str.append("not stress")
-
-    if not config.option.optional:
-        disable_str.append("not optional")
+    markexpr = [config.option.markexpr] if config.option.markexpr else []
 
-    disable_str = " and ".join(disable_str)
+    for mark in ("stress", "optional", "applications"):
+        test = "" if getattr(config.option, mark, False) else "not "
+        markexpr.append(f"{test}{mark}")
 
-    if disable_str != "":
-        print(disable_str)
-        setattr(config.option, "markexpr", disable_str)
+    config.option.markexpr = " and ".join(markexpr)
+    print(config.option.markexpr)
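
Note: the rewritten pytest_configure composes a single mark expression from the three
custom marks instead of the previous two-case string building. A minimal sketch of the
resulting behaviour (compose_markexpr is a hypothetical stand-in that mirrors the hook
above; the Namespace stands in for pytest's parsed config):

    from argparse import Namespace

    def compose_markexpr(option):
        # Start from any user-supplied -m expression, as pytest_configure does
        markexpr = [option.markexpr] if option.markexpr else []

        for mark in ("stress", "optional", "applications"):
            test = "" if getattr(option, mark, False) else "not "
            markexpr.append(f"{test}{mark}")

        return " and ".join(markexpr)

    # py.test --applications with no -m expression:
    opts = Namespace(markexpr="", stress=False, optional=False, applications=True)
    assert compose_markexpr(opts) == "not stress and not optional and applications"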
diff --git a/daskms/apps/conftest.py b/daskms/apps/conftest.py
new file mode 100644
index 00000000..2b0dc971
--- /dev/null
+++ b/daskms/apps/conftest.py
@@ -0,0 +1,67 @@
+from appdirs import user_cache_dir
+from hashlib import sha256
+import logging
+from pathlib import Path
+import tarfile
+
+import pytest
+
+log = logging.getLogger(__file__)
+
+TAU_MS = "HLTau_B6cont.calavg.tav300s"
+TAU_MS_TAR = f"{TAU_MS}.tar.xz"
+TAU_MS_TAR_HASH = "fc2ce9261817dfd88bbdd244c8e9e58ae0362173938df6ef2a587b1823147f70"
+DATA_URL = f"s3://ratt-public-data/test-data/{TAU_MS_TAR}"
+
+
+def download_tau_ms(tau_ms_tar):
+    if tau_ms_tar.exists():
+        with open(tau_ms_tar, "rb") as f:
+            digest = sha256()
+
+            while data := f.read(2**20):
+                digest.update(data)
+
+            if digest.hexdigest() == TAU_MS_TAR_HASH:
+                return
+
+            tau_ms_tar.unlink(missing_ok=True)
+            raise ValueError(
+                f"sha256 digest '{digest.hexdigest()}' "
+                f"of {tau_ms_tar} does not match "
+                f"{TAU_MS_TAR_HASH}"
+            )
+    else:
+        s3fs = pytest.importorskip("s3fs")
+        s3 = s3fs.S3FileSystem(anon=True)
+
+        for attempt in range(3):
+            with s3.open(DATA_URL, "rb") as fin, open(tau_ms_tar, "wb") as fout:
+                digest = sha256()
+
+                while data := fin.read(2**20):
+                    digest.update(data)
+                    fout.write(data)
+
+            if digest.hexdigest() == TAU_MS_TAR_HASH:
+                return
+
+            log.warning("Download of %s failed on attempt %d", DATA_URL, attempt)
+            tau_ms_tar.unlink(missing_ok=True)
+
+        raise ValueError(f"Download of {DATA_URL} failed after {attempt + 1} attempts")
+
+
+@pytest.fixture(scope="function")
+def tau_ms(tmp_path_factory):
+    cache_dir = Path(user_cache_dir("dask-ms")) / "test-data"
+    cache_dir.mkdir(parents=True, exist_ok=True)
+    tau_ms_tar = cache_dir / TAU_MS_TAR
+
+    download_tau_ms(tau_ms_tar)
+    msdir = tmp_path_factory.mktemp("taums")
+
+    with tarfile.open(tau_ms_tar) as tar:
+        tar.extractall(msdir)
+
+    yield msdir / TAU_MS
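
Note: a test under daskms/apps/ can simply request the tau_ms fixture by name; pytest
runs the cached download, sha256 verification and extraction, and hands the test a
pathlib.Path to a fresh copy of the Measurement Set. A hypothetical smoke test,
assuming only the fixture above and the public xds_from_ms:

    import pytest
    from daskms import xds_from_ms

    @pytest.mark.applications
    def test_tau_ms_smoke(tau_ms):
        # tau_ms is a Path to a freshly extracted Measurement Set
        datasets = xds_from_ms(str(tau_ms))
        assert len(datasets) > 0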
diff --git a/daskms/apps/convert.py b/daskms/apps/convert.py
index be122854..fc489eeb 100644
--- a/daskms/apps/convert.py
+++ b/daskms/apps/convert.py
@@ -1,14 +1,12 @@
-import abc
 import ast
 from argparse import ArgumentTypeError
 from collections import defaultdict
-from functools import partial
 import logging
-from pathlib import Path
 
 import dask.array as da
 
 from daskms.apps.application import Application
+from daskms.apps.formats import TableFormat, CasaFormat
 from daskms.fsspec_store import DaskMSStore
 
 log = logging.getLogger(__name__)
@@ -41,236 +39,6 @@ def visit_Constant(self, node):
         return node.n
 
 
-class TableFormat(abc.ABC):
-    @abc.abstractproperty
-    def version(self):
-        raise NotImplementedError
-
-    @abc.abstractproperty
-    def subtables(self):
-        raise NotImplementedError
-
-    @abc.abstractclassmethod
-    def reader(self):
-        raise NotImplementedError
-
-    @abc.abstractclassmethod
-    def writer(self):
-        raise NotImplementedError
-
-    @staticmethod
-    def from_store(store):
-        typ = store.type()
-
-        if typ == "casa":
-            from daskms.table_proxy import TableProxy
-            import pyrap.tables as pt
-
-            table_proxy = TableProxy(
-                pt.table, str(store.casa_path()), readonly=True, ack=False
-            )
-            keywords = table_proxy.getkeywords().result()
-
-            try:
-                version = str(keywords["MS_VERSION"])
-            except KeyError:
-                typ = "plain"
-                version = ""
-            else:
-                typ = "measurementset"
-
-            subtables = CasaFormat.find_subtables(keywords)
-            return CasaFormat(version, subtables, typ)
-        elif typ == "zarr":
-            subtables = ZarrFormat.find_subtables(store)
-            return ZarrFormat("0.1", subtables)
-        elif typ == "parquet":
-            subtables = ParquetFormat.find_subtables(store)
-            return ParquetFormat("0.1", subtables)
-        else:
-            raise ValueError(f"Unexpected table type {typ}")
-
-    @staticmethod
-    def from_type(typ):
-        if typ == "ms":
-            return CasaFormat("2.0", [], "measurementset")
-        if typ == "casa":
-            return CasaFormat("", [], "plain")
-        elif typ == "zarr":
-            return ZarrFormat("0.1", [])
-        elif typ == "parquet":
-            return ParquetFormat("0.1", [])
-        else:
-            raise ValueError(f"Unexpected table type {typ}")
-
-
-class BaseTableFormat(TableFormat):
-    def __init__(self, version):
-        self._version = version
-
-    @property
-    def version(self):
-        return self._version
-
-    def check_unused_kwargs(self, fn_name, **kwargs):
-        if kwargs:
-            raise NotImplementedError(
-                f"The following kwargs: "
-                f"{list(kwargs.keys())} "
-                f"were not consumed by "
-                f"{self.__class__.__name__}."
-                f"{fn_name}(**kw)"
-            )
-
-
-class CasaFormat(BaseTableFormat):
-    TABLE_PREFIX = "Table: "
-    TABLE_TYPES = set(["plain", "measurementset"])
-
-    def __init__(self, version, subtables, type="plain"):
-        super().__init__(version)
-
-        if type not in self.TABLE_TYPES:
-            raise ValueError(f"{type} is not in {self.TABLE_TYPES}")
-
-        self._subtables = subtables
-        self._type = type
-
-    @classmethod
-    def find_subtables(cls, keywords):
-        return [k for k, v in keywords.items() if cls.is_subtable(v)]
-
-    @classmethod
-    def is_subtable(cls, keyword):
-        if not isinstance(keyword, str):
-            return False
-
-        if not keyword.startswith(cls.TABLE_PREFIX):
-            return False
-
-        path = Path(keyword[len(cls.TABLE_PREFIX) :])
-        return path.exists() and path.is_dir() and (path / "table.dat").exists()
-
-    def is_measurement_set(self):
-        return self._type == "measurementset"
-
-    def reader(self, **kw):
-        try:
-            group_cols = kw.pop("group_columns", None)
-            index_cols = kw.pop("index_columns", None)
-            taql_where = kw.pop("taql_where", "")
-
-            if self.is_measurement_set():
-                from daskms import xds_from_ms
-
-                return partial(
-                    xds_from_ms,
-                    group_cols=group_cols,
-                    index_cols=index_cols,
-                    taql_where=taql_where,
-                )
-            else:
-                from daskms import xds_from_table
-
-                return xds_from_table
-        finally:
-            self.check_unused_kwargs("reader", **kw)
-
-    def writer(self):
-        from daskms import xds_to_table
-
-        if self.is_measurement_set():
-            return partial(xds_to_table, descriptor="ms")
-        else:
-            return xds_to_table
-
-    @property
-    def subtables(self):
-        return self._subtables
-
-    def __str__(self):
-        return "casa" if not self.is_measurement_set() else self._type
-
-
-class ZarrFormat(BaseTableFormat):
-    def __init__(self, version, subtables):
-        self._subtables = subtables
-
-    @classmethod
-    def find_subtables(cls, store):
-        paths = (p.relative_to(store.path) for p in map(Path, store.subdirectories()))
-
-        return [
-            str(p)
-            for p in paths
-            if p.stem != "MAIN" and store.exists(str(p / ".zgroup"))
-        ]
-
-    @property
-    def subtables(self):
-        return self._subtables
-
-    def reader(self, **kw):
-        for arg in Convert.CASA_INPUT_ONLY_ARGS:
-            if kw.pop(arg, False):
-                raise ValueError(f'"{arg}" is not supported for zarr inputs')
-
-        try:
-            from daskms.experimental.zarr import xds_from_zarr
-
-            return xds_from_zarr
-        finally:
-            self.check_unused_kwargs("reader", **kw)
-
-    def writer(self):
-        from daskms.experimental.zarr import xds_to_zarr
-
-        return xds_to_zarr
-
-    def __str__(self):
-        return "zarr"
-
-
-class ParquetFormat(BaseTableFormat):
-    def __init__(self, version, subtables):
-        super().__init__(version)
-        self._subtables = subtables
-
-    @classmethod
-    def find_subtables(cls, store):
-        paths = (p.relative_to(store.path) for p in map(Path, store.subdirectories()))
-
-        return [
-            str(p)
-            for p in paths
-            if p.stem != "MAIN" and store.exists(str(p / ".zgroup"))
-        ]
-
-    @property
-    def subtables(self):
-        return self._subtables
-
-    def reader(self, **kw):
-        for arg in Convert.CASA_INPUT_ONLY_ARGS:
-            if kw.pop(arg, False):
-                raise ValueError(f'"{arg}" is not supported for parquet inputs')
-
-        try:
-            from daskms.experimental.arrow.reads import xds_from_parquet
-
-            return xds_from_parquet
-        finally:
-            self.check_unused_kwargs("reader", **kw)
-
-    def writer(self):
-        from daskms.experimental.arrow.writes import xds_to_parquet
-
-        return xds_to_parquet
-
-    def __str__(self):
-        return "parquet"
-
-
 NONUNIFORM_SUBTABLES = ["SPECTRAL_WINDOW", "POLARIZATION", "FEED", "SOURCE"]
 
 
@@ -288,7 +56,10 @@ def _check_output_path(output: str):
 
 
 def _check_exclude_columns(columns: str):
-    outputs = defaultdict(list)
+    if not columns:
+        return {}
+
+    outputs = defaultdict(set)
 
     for column in (c.strip() for c in columns.split(",")):
         bits = column.split("::")
 
         if len(bits) == 2:
             table, column = bits
         elif len(bits) == 1:
-            table, column = "MAIN", bits
+            table, column = "MAIN", bits[0]
         else:
-            raise ValueError(
+            raise ArgumentTypeError(
                 f"Excluded columns must be of the form "
                 f"COLUMN or SUBTABLE::COLUMN. "
                 f"Received {column}"
             )
 
-        outputs[table].append(column)
+        outputs[table].add(column)
+
+    outputs = {
+        table: "*" if "*" in columns else columns for table, columns in outputs.items()
+    }
+
+    if outputs.get("MAIN", "") == "*":
+        raise ValueError("Excluding all columns in the MAIN table is not supported")
 
     return outputs
@@ -314,9 +92,6 @@ def parse_chunks(chunks: str):
 
 
 class Convert(Application):
-    TABLE_KEYWORD_PREFIX = "Table: "
-    CASA_INPUT_ONLY_ARGS = ("group_columns", "index_columns", "taql_where")
-
     def __init__(self, args, log):
         self.log = log
         self.args = args
@@ -340,10 +115,11 @@ def setup_parser(cls, parser):
         help="Comma-separated list of columns to exclude. "
         "For example 'CORRECTED_DATA,"
         "SPECTRAL_WINDOW::EFFECTIVE_BW' "
-        "will exclude 'CORRECTED_DATA "
+        "will exclude CORRECTED_DATA "
         "from the main table and "
         "EFFECTIVE_BW from the SPECTRAL_WINDOW "
-        "subtable",
+        "subtable. SPECTRAL_WINDOW::* will exclude "
+        "the entire SPECTRAL_WINDOW subtable",
     )
     parser.add_argument(
         "-g",
@@ -456,9 +232,9 @@ def convert_table(self, args):
         # Drop any ROWID columns
         datasets = [ds.drop_vars("ROWID", errors="ignore") for ds in datasets]
 
-        for exclude_column in args.exclude.get("MAIN", []):
+        if exclude_columns := args.exclude.get("MAIN", False):
             datasets = [
-                ds.drop_vars(exclude_column, errors="ignore") for ds in datasets
+                ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
             ]
 
         if isinstance(out_fmt, CasaFormat):
@@ -472,14 +248,17 @@ def convert_table(self, args):
 
         # Now do the subtables
         for table in list(in_fmt.subtables):
-            if table in {"SORTED_TABLE", "SOURCE"}:
+            if (
+                table in {"SORTED_TABLE", "SOURCE"}
+                or args.exclude.get(table, "") == "*"
+            ):
                 log.warning(f"Ignoring {table}")
                 continue
 
             in_store = args.input.subtable_store(table)
             in_fmt = TableFormat.from_store(in_store)
             out_store = args.output.subtable_store(table)
-            out_fmt = TableFormat.from_type(args.format)
+            out_fmt = TableFormat.from_type(args.format, subtable=table)
 
             reader = in_fmt.reader()
             writer = out_fmt.writer()
@@ -489,9 +268,9 @@ def convert_table(self, args):
             else:
                 datasets = reader(in_store)
 
-            for exclude_column in args.exclude.get(table, []):
+            if exclude_columns := args.exclude.get(table, False):
                 datasets = [
-                    ds.drop_vars(exclude_column, errors="ignore") for ds in datasets
+                    ds.drop_vars(exclude_columns, errors="ignore") for ds in datasets
                 ]
 
             if isinstance(in_fmt, CasaFormat):
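
Note: to make the exclusion semantics concrete, here is how _check_exclude_columns
treats a mixed argument. This follows directly from the function above (it is
module-private, so the import is for illustration only):

    from daskms.apps.convert import _check_exclude_columns

    excluded = _check_exclude_columns(
        "CORRECTED_DATA,SPECTRAL_WINDOW::EFFECTIVE_BW,ASDM_ANTENNA::*"
    )

    # Bare columns are attributed to MAIN, SUBTABLE::COLUMN entries to their
    # subtable, and a "*" entry collapses a subtable to "exclude everything"
    assert excluded == {
        "MAIN": {"CORRECTED_DATA"},
        "SPECTRAL_WINDOW": {"EFFECTIVE_BW"},
        "ASDM_ANTENNA": "*",
    }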
diff --git a/daskms/apps/formats.py b/daskms/apps/formats.py
new file mode 100644
index 00000000..40b61bf8
--- /dev/null
+++ b/daskms/apps/formats.py
@@ -0,0 +1,278 @@
+import abc
+from functools import partial
+from pathlib import Path
+
+CASA_INPUT_ONLY_ARGS = ("group_columns", "index_columns", "taql_where")
+
+
+class TableFormat(abc.ABC):
+    @abc.abstractproperty
+    def version(self):
+        raise NotImplementedError
+
+    @abc.abstractproperty
+    def subtables(self):
+        raise NotImplementedError
+
+    @abc.abstractclassmethod
+    def reader(self):
+        raise NotImplementedError
+
+    @abc.abstractclassmethod
+    def writer(self):
+        raise NotImplementedError
+
+    @staticmethod
+    def from_store(store):
+        typ = store.type()
+
+        if typ == "casa":
+            from daskms.table_proxy import TableProxy
+            import pyrap.tables as pt
+
+            table_proxy = TableProxy(pt.table, store.root, readonly=True, ack=False)
+            keywords = table_proxy.getkeywords().result()
+            subtables = CasaFormat.find_subtables(keywords)
+
+            try:
+                version = str(keywords["MS_VERSION"])
+            except KeyError:
+                cls = CasaMainFormat
+                version = ""
+            else:
+                cls = MeasurementSetFormat
+
+            main_fmt = cls(version, subtables)
+
+            if store.table:
+                return main_fmt.subtable_format(store.table)
+
+            return main_fmt
+
+        elif typ == "zarr":
+            subtables = ZarrFormat.find_subtables(store)
+            return ZarrFormat("0.1", subtables)
+        elif typ == "parquet":
+            subtables = ParquetFormat.find_subtables(store)
+            return ParquetFormat("0.1", subtables)
+        else:
+            raise ValueError(f"Unexpected table type {typ}")
+
+    @staticmethod
+    def from_type(typ, subtable=""):
+        if typ == "ms":
+            if subtable:
+                return MSSubtableFormat("2.0", subtable)
+            else:
+                return MeasurementSetFormat("2.0", [])
+        elif typ == "casa":
+            if subtable:
+                return CasaSubtableFormat("", subtable)
+            else:
+                return CasaMainFormat("", [])
+        elif typ == "zarr":
+            return ZarrFormat("0.1", [])
+        elif typ == "parquet":
+            return ParquetFormat("0.1", [])
+        else:
+            raise ValueError(f"Unexpected table type {typ}")
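
Note: the two factory methods drive the conversion loop. from_store classifies an
existing input, while from_type builds the output format, dispatching to a subtable
variant when one is named. A short sketch of from_type, assuming SPECTRAL_WINDOW is
among the standard subtables in daskms.table_schemas.SUBTABLES:

    from daskms.apps.formats import TableFormat, MSSubtableFormat

    out_fmt = TableFormat.from_type("ms")  # main table format
    spw_fmt = TableFormat.from_type("ms", subtable="SPECTRAL_WINDOW")

    assert str(out_fmt) == "MeasurementSet"
    assert isinstance(spw_fmt, MSSubtableFormat)

    # The subtable writer bakes a CASA subtable descriptor into xds_to_table
    writer = spw_fmt.writer()
    assert writer.keywords["descriptor"] == "mssubtable('SPECTRAL_WINDOW')"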
+
+
+class BaseTableFormat(TableFormat):
+    def __init__(self, version):
+        self._version = version
+
+    @property
+    def version(self):
+        return self._version
+
+    def check_unused_kwargs(self, fn_name, **kwargs):
+        if kwargs:
+            raise NotImplementedError(
+                f"The following kwargs: "
+                f"{list(kwargs.keys())} "
+                f"were not consumed by "
+                f"{self.__class__.__name__}."
+                f"{fn_name}(**kw)"
+            )
+
+
+class CasaFormat(BaseTableFormat):
+    TABLE_PREFIX = "Table: "
+
+    @classmethod
+    def find_subtables(cls, keywords):
+        return [k for k, v in keywords.items() if cls.is_subtable(v)]
+
+    @classmethod
+    def is_subtable(cls, keyword: str):
+        if not isinstance(keyword, str):
+            return False
+
+        if not keyword.startswith(cls.TABLE_PREFIX):
+            return False
+
+        path = Path(keyword[len(cls.TABLE_PREFIX) :])
+        return path.exists() and path.is_dir() and (path / "table.dat").exists()
+
+
+class CasaMainFormat(CasaFormat):
+    def __init__(self, version, subtables):
+        super().__init__(version)
+        self._subtables = subtables
+
+    def subtable_format(self, subtable: str):
+        if subtable not in self._subtables:
+            raise ValueError(f"{subtable} is not a valid subtable")
+
+        return CasaSubtableFormat(self.version, subtable)
+
+    @property
+    def subtables(self):
+        return self._subtables
+
+    def __str__(self):
+        return "casa"
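
Note: CasaMainFormat.subtable_format validates a requested subtable against those
discovered in the table keywords, which is what from_store relies on when a store
addresses a subtable directly. A minimal sketch, with illustrative subtable names:

    from daskms.apps.formats import CasaMainFormat

    fmt = CasaMainFormat("", ["SPECTRAL_WINDOW", "ANTENNA"])

    sub = fmt.subtable_format("ANTENNA")  # fine: a discovered subtable
    assert sub.version == fmt.version

    try:
        fmt.subtable_format("GAIN_CURVE")  # not discovered: rejected
    except ValueError as e:
        assert "not a valid subtable" in str(e)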
+
+
+class CasaSubtableFormat(CasaFormat):
+    def __init__(self, version, subtable):
+        super().__init__(version)
+        self._subtable = subtable
+
+    @property
+    def subtables(self):
+        return []
+
+    def reader(self, **kw):
+        self.check_unused_kwargs("CasaSubtableFormat.reader", **kw)
+        from daskms import xds_from_table
+
+        return xds_from_table
+
+    def writer(self):
+        from daskms import xds_to_table
+
+        return xds_to_table
+
+
+class MSSubtableFormat(CasaSubtableFormat):
+    def writer(self):
+        from daskms import xds_to_table
+        from daskms.table_schemas import SUBTABLES
+
+        if self._subtable in SUBTABLES:
+            descriptor = f"mssubtable('{self._subtable}')"
+        else:
+            descriptor = None
+
+        return partial(xds_to_table, descriptor=descriptor)
+
+
+class MeasurementSetFormat(CasaMainFormat):
+    def __init__(self, version, subtables):
+        super().__init__(version, subtables)
+
+    def __str__(self):
+        return "MeasurementSet"
+
+    def reader(self, **kw):
+        group_cols = kw.pop("group_columns", None)
+        index_cols = kw.pop("index_columns", None)
+        taql_where = kw.pop("taql_where", "")
+        self.check_unused_kwargs("reader", **kw)
+
+        from daskms import xds_from_ms
+
+        return partial(
+            xds_from_ms,
+            group_cols=group_cols,
+            index_cols=index_cols,
+            taql_where=taql_where,
+        )
+
+    def writer(self):
+        from daskms import xds_to_table
+
+        return partial(xds_to_table, descriptor="ms")
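
Note: MeasurementSetFormat.reader consumes the CASA-only keyword arguments and returns
a partially applied xds_from_ms, so the conversion loop can call every reader with just
a store. A sketch of both paths (row_chunks is a hypothetical unconsumed kwarg):

    from functools import partial
    from daskms import xds_from_ms
    from daskms.apps.formats import MeasurementSetFormat

    fmt = MeasurementSetFormat("2.0", [])
    reader = fmt.reader(group_columns=["FIELD_ID", "DATA_DESC_ID"],
                        index_columns=["TIME"])

    # The CASA-only kwargs are baked into the returned callable
    assert isinstance(reader, partial) and reader.func is xds_from_ms

    # Unconsumed kwargs are flagged loudly rather than silently ignored
    try:
        fmt.reader(row_chunks=10000)
    except NotImplementedError:
        pass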
+
+
+class ZarrFormat(BaseTableFormat):
+    def __init__(self, version, subtables):
+        super().__init__(version)
+        self._subtables = subtables
+
+    @classmethod
+    def find_subtables(cls, store):
+        paths = (p.relative_to(store.path) for p in map(Path, store.subdirectories()))
+
+        return [
+            str(p)
+            for p in paths
+            if p.stem != "MAIN" and store.exists(str(p / ".zgroup"))
+        ]
+
+    @property
+    def subtables(self):
+        return self._subtables
+
+    def reader(self, **kw):
+        for arg in CASA_INPUT_ONLY_ARGS:
+            if kw.pop(arg, False):
+                raise ValueError(f'"{arg}" is not supported for zarr inputs')
+
+        self.check_unused_kwargs("reader", **kw)
+
+        from daskms.experimental.zarr import xds_from_zarr
+
+        return xds_from_zarr
+
+    def writer(self):
+        from daskms.experimental.zarr import xds_to_zarr
+
+        return xds_to_zarr
+
+    def __str__(self):
+        return "zarr"
+
+
+class ParquetFormat(BaseTableFormat):
+    def __init__(self, version, subtables):
+        super().__init__(version)
+        self._subtables = subtables
+
+    @classmethod
+    def find_subtables(cls, store):
+        paths = (p.relative_to(store.path) for p in map(Path, store.subdirectories()))
+
+        return [
+            str(p)
+            for p in paths
+            if p.stem != "MAIN" and store.exists(str(p / ".zgroup"))
+        ]
+
+    @property
+    def subtables(self):
+        return self._subtables
+
+    def reader(self, **kw):
+        for arg in CASA_INPUT_ONLY_ARGS:
+            if kw.pop(arg, False):
+                raise ValueError(f'"{arg}" is not supported for parquet inputs')
+
+        self.check_unused_kwargs("reader", **kw)
+
+        from daskms.experimental.arrow.reads import xds_from_parquet
+
+        return xds_from_parquet
+
+    def writer(self):
+        from daskms.experimental.arrow.writes import xds_to_parquet
+
+        return xds_to_parquet
+
+    def __str__(self):
+        return "parquet"

diff --git a/daskms/apps/tests/test_convert.py b/daskms/apps/tests/test_convert.py
new file mode 100644
index 00000000..795ee834
--- /dev/null
+++ b/daskms/apps/tests/test_convert.py
@@ -0,0 +1,57 @@
+from argparse import ArgumentParser
+import logging
+
+from daskms.apps.convert import Convert
+from daskms import xds_from_storage_ms, xds_from_storage_table
+
+import pytest
+
+log = logging.getLogger(__file__)
+
+
+@pytest.mark.applications
+@pytest.mark.parametrize("format", ["ms", "zarr", "parquet"])
+def test_convert_application(tau_ms, format, tmp_path_factory):
+    OUTPUT = tmp_path_factory.mktemp(f"convert_{format}") / f"output.{format}"
+
+    exclude_columns = [
+        "ASDM_ANTENNA::*",
+        "ASDM_CALATMOSPHERE::*",
+        "ASDM_CALWVR::*",
+        "ASDM_RECEIVER::*",
+        "ASDM_SOURCE::*",
+        "ASDM_STATION::*",
+        "POINTING::OVER_THE_TOP",
+        "MODEL_DATA",
+    ]
+
+    args = [
+        str(tau_ms),
+        # "-g",
+        # "FIELD_ID,DATA_DESC_ID,SCAN_NUMBER",
+        "-x",
+        ",".join(exclude_columns),
+        "-o",
+        str(OUTPUT),
+        "--format",
+        format,
+        "--force",
+    ]
+
+    p = ArgumentParser()
+    Convert.setup_parser(p)
+    args = p.parse_args(args)
+    app = Convert(args, log)
+    app.execute()
+
+    datasets = xds_from_storage_ms(OUTPUT)
+
+    for ds in datasets:
+        assert "MODEL_DATA" not in ds.data_vars
+        assert "FLAG" in ds.data_vars
+
+    datasets = xds_from_storage_table(f"{str(OUTPUT)}::POINTING")
+
+    for ds in datasets:
+        assert "OVER_THE_TOP" not in ds.data_vars
+        assert "NAME" in ds.data_vars

diff --git a/daskms/descriptors/ms.py b/daskms/descriptors/ms.py
index 5880c3c8..1b7dd2a6 100644
--- a/daskms/descriptors/ms.py
+++ b/daskms/descriptors/ms.py
@@ -30,8 +30,8 @@ class MSDescriptorBuilder(AbstractDescriptorBuilder):
     def __init__(self, fixed=True):
         super(AbstractDescriptorBuilder, self).__init__()
-        self.DEFAULT_MS_DESC = pt.required_ms_desc()
-        self.REQUIRED_FIELDS = set(self.DEFAULT_MS_DESC.keys())
+        self.DEFAULT_MS_DESC = pt.complete_ms_desc()
+        self.REQUIRED_FIELDS = set(pt.required_ms_desc().keys())
         self.fixed = fixed
         self.ms_dims = None

diff --git a/daskms/descriptors/ms_subtable.py b/daskms/descriptors/ms_subtable.py
index a7ae7410..fe0b4526 100644
--- a/daskms/descriptors/ms_subtable.py
+++ b/daskms/descriptors/ms_subtable.py
@@ -18,8 +18,8 @@ def __init__(self, subtable):
         )
 
         self.subtable = subtable
-        self.DEFAULT_TABLE_DESC = pt.required_ms_desc(subtable)
-        self.REQUIRED_FIELDS = set(self.DEFAULT_TABLE_DESC.keys())
+        self.DEFAULT_TABLE_DESC = pt.complete_ms_desc(subtable)
+        self.REQUIRED_FIELDS = set(pt.required_ms_desc(subtable).keys())
 
     def default_descriptor(self):
         return self.DEFAULT_TABLE_DESC.copy()
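
Note: the descriptor change is the heart of the PR. The default table description now
comes from the complete MS description (required plus optional columns), while
REQUIRED_FIELDS still tracks only the required subset. A sketch of the distinction,
assuming a python-casacore recent enough to provide complete_ms_desc (the exact
optional columns, e.g. WEIGHT_SPECTRUM, vary with the casacore version):

    import pyrap.tables as pt

    complete = set(pt.complete_ms_desc().keys())
    required = set(pt.required_ms_desc().keys())

    # Optional MAIN table columns are those in the complete description only
    optional = complete - required
    assert required <= complete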
diff --git a/daskms/experimental/zarr/tests/test_zarr.py b/daskms/experimental/zarr/tests/test_zarr.py
index 85695856..e7864d56 100644
--- a/daskms/experimental/zarr/tests/test_zarr.py
+++ b/daskms/experimental/zarr/tests/test_zarr.py
@@ -234,7 +234,6 @@ def test_multiprocess_create(ms, tmp_path_factory):
             assert_array_equal(v, getattr(zds, k))
 
 
-@pytest.mark.optional
 @pytest.mark.skipif(xarray is None, reason="depends on xarray")
 def test_xarray_to_zarr(ms, tmp_path_factory):
     store = tmp_path_factory.mktemp("zarr_store")

diff --git a/daskms/reads.py b/daskms/reads.py
index df227026..1e476050 100644
--- a/daskms/reads.py
+++ b/daskms/reads.py
@@ -154,6 +154,10 @@ def getter_wrapper(row_orders, *args):
         blc, trc = zip(*args[:nextent_args])
         shape = tuple(t - b + 1 for b, t in zip(blc, trc))
         result = np.empty((np.sum(row_runs[:, 1]),) + shape, dtype=dtype)
+
+        if result.size == 0:
+            return result
+
         io_fn = object_getcolslice if np.dtype == object else ndarray_getcolslice
 
         # Submit table I/O on executor
@@ -164,6 +168,10 @@ def getter_wrapper(row_orders, *args):
     # for each row is requested, so we defer to getcol
     else:
         result = np.empty((np.sum(row_runs[:, 1]),) + col_shape, dtype=dtype)
+
+        if result.size == 0:
+            return result
+
         io_fn = object_getcol if dtype == object else ndarray_getcol
 
         # Submit table I/O on executor

diff --git a/daskms/tests/test_optional.py b/daskms/tests/test_optional.py
index f3695fdb..85aa23e7 100644
--- a/daskms/tests/test_optional.py
+++ b/daskms/tests/test_optional.py
@@ -181,12 +181,12 @@ def test_only_row_shape(tmp_path, column, dtype):
     assert T.getcol(column).shape == (10,)
 
     # Must be ndim == 2
-    err_str = "Vector: ndim of other array > 1 ndim 1 differs from 2"
+    err_str = "Vector: ndim of other array > 1 -- ndim 1 differs from 2"
 
     with pytest.raises(RuntimeError, match=err_str):
         T.putcol(column, np.zeros((5, 40), dtype=dtype))
 
     # shape != (20, 30)
-    err_str = "Vector: ndim of other array > 1 ndim 1 differs from 3"
+    err_str = "Vector: ndim of other array > 1 -- ndim 1 differs from 3"
 
     with pytest.raises(RuntimeError, match=err_str):
        T.putcol(column, np.zeros((5, 40, 30), dtype=dtype), startrow=0, nrow=5)

diff --git a/daskms/writes.py b/daskms/writes.py
index 3d8e6365..f772ce6f 100644
--- a/daskms/writes.py
+++ b/daskms/writes.py
@@ -328,13 +328,17 @@ def _updated_table(table, datasets, columns, descriptor):
     # Original Data Manager Groups
     odminfo = {g["NAME"] for g in table_proxy.getdminfo().result().values()}
 
+    SENTINEL = object()
+
     # NOTE(JSKenyon): Add columns one at a time - this avoids issues when
     # adding multiple columns with different managers.
     for col in missing:
         _table_desc = {col: table_desc[col]}
         _dminfo = builder.dminfo(_table_desc)
-        _dminfo = {} if _dminfo["*1"]["NAME"] in odminfo else _dminfo
+        in_odminfo = any(
+            v.get("NAME", SENTINEL) in odminfo for v in _dminfo.values()
+        )
+        _dminfo = {} if in_odminfo else _dminfo
         table_proxy.addcols(_table_desc, dminfo=_dminfo).result()
 
     return table_proxy
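
Note: the rewritten data-manager check in writes.py no longer assumes the dminfo dict
contains a "*1" group; it scans every group, and the SENTINEL guarantees that groups
without a NAME key can never match an existing group. A condensed illustration of the
two behaviours (the dminfo structures are simplified, not real casacore output):

    SENTINEL = object()

    odminfo = {"StandardStMan", "TiledDATA"}

    # A dminfo whose single group reuses a name that already exists
    dminfo_clash = {"*1": {"NAME": "StandardStMan", "TYPE": "StandardStMan"}}
    # A dminfo group without a NAME key at all
    dminfo_unnamed = {"*1": {"TYPE": "StandardStMan"}}

    clash = any(v.get("NAME", SENTINEL) in odminfo for v in dminfo_clash.values())
    no_clash = any(v.get("NAME", SENTINEL) in odminfo for v in dminfo_unnamed.values())

    assert clash and not no_clash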