diff --git a/altair/datasets/_readers.py b/altair/datasets/_readers.py index f76cc5a0a..0a18c1e61 100644 --- a/altair/datasets/_readers.py +++ b/altair/datasets/_readers.py @@ -60,6 +60,7 @@ from typing import TypeAlias else: from typing_extensions import TypeAlias + from packaging.requirements import Requirement from altair.datasets._typing import Dataset, Extension, Metadata from altair.vegalite.v5.schema._typing import OneOrSeq @@ -379,7 +380,7 @@ class _PyArrowReader(_Reader["pa.Table", "pa.Table"]): def _maybe_fn(self, meta: Metadata, /) -> Callable[..., pa.Table]: fn = super()._maybe_fn(meta) - if fn is self._read_json_polars: + if fn == self._read_json_polars: return fn elif meta["is_json"]: if meta["is_tabular"]: @@ -550,7 +551,7 @@ def _requirements(s: _ConcreteT, /) -> _ConcreteT: ... def _requirements(s: Literal["pandas[pyarrow]"], /) -> tuple[_Pandas, _PyArrow]: ... -def _requirements(s: _Backend, /): +def _requirements(s: Any, /) -> Any: concrete: set[Literal[_Polars, _Pandas, _PyArrow]] = {"polars", "pandas", "pyarrow"} if s in concrete: return s @@ -559,12 +560,13 @@ def _requirements(s: _Backend, /): req = Requirement(s) supports_extras: set[Literal[_Pandas]] = {"pandas"} - if req.name in supports_extras: - name = req.name - if (extras := req.extras) and extras == {"pyarrow"}: - extra = "pyarrow" - return name, extra - else: - raise NotImplementedError(s) - else: - raise NotImplementedError(s) + if req.name in supports_extras and req.extras == {"pyarrow"}: + return req.name, "pyarrow" + return _requirements_unknown(req) + + +def _requirements_unknown(req: Requirement | str, /) -> Any: + from packaging.requirements import Requirement + + req = Requirement(req) if isinstance(req, str) else req + return (req.name, *req.extras) diff --git a/pyproject.toml b/pyproject.toml index 5ac95f190..03e33cc36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -262,16 +262,18 @@ cwd = "." 
[tool.taskipy.tasks] lint = "ruff check" format = "ruff format --diff --check" +ruff-check = "task lint && task format" ruff-fix = "task lint && ruff format" type-check = "mypy altair tests" -pytest = "pytest" -test = "task lint && task format && task type-check && task pytest" -test-fast = "task ruff-fix && pytest -m \"not slow\"" -test-slow = "task ruff-fix && pytest -m \"slow\"" -test-datasets = "task ruff-fix && pytest tests -k test_datasets -m \"\"" -test-min = "task lint && task format && task type-check && hatch test --python 3.9" -test-all = "task lint && task format && task type-check && hatch test --all" +pytest-serial = "pytest -m \"no_xdist\" --numprocesses=1" +pytest = "pytest && task pytest-serial" +test = "task ruff-check && task type-check && task pytest" +test-fast = "task ruff-fix && pytest -m \"not slow and not datasets_debug and not no_xdist\"" +test-slow = "task ruff-fix && pytest -m \"slow and not datasets_debug and not no_xdist\"" +test-datasets = "task ruff-fix && pytest tests -k test_datasets -m \"not no_xdist\" && task pytest-serial" +test-min = "task ruff-check && task type-check && hatch test --python 3.9" +test-all = "task ruff-check && task type-check && hatch test --all" generate-schema-wrapper = "mypy tools && python tools/generate_schema_wrapper.py && task test" @@ -303,12 +305,13 @@ addopts = [ "tests", "altair", "tools", - "-m not datasets_debug", + "-m not datasets_debug and not no_xdist", ] # https://docs.pytest.org/en/stable/how-to/mark.html#registering-marks markers = [ "slow: Label tests as slow (deselect with '-m \"not slow\"')", - "datasets_debug: Disabled by default due to high number of requests" + "datasets_debug: Disabled by default due to high number of requests", + "no_xdist: Unsafe to run in parallel" ] [tool.mypy] diff --git a/tests/__init__.py b/tests/__init__.py index 5d78dce0d..80c27fc2c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -60,6 +60,16 @@ def windows_has_tzdata() -> bool: >>> hatch run test-slow --durations=25 # doctest: +SKIP """ +no_xdist: pytest.MarkDecorator = pytest.mark.no_xdist() +""" +Custom ``pytest.mark`` decorator. + +Each marked test will run **serially**, after all other selected tests. + +.. tip:: + Use as a last resort when a test depends on manipulating global state. +""" + skip_requires_ipython: pytest.MarkDecorator = pytest.mark.skipif( find_spec("IPython") is None, reason="`IPython` not installed." 
) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 3ccdba273..b212d79ce 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,6 +1,5 @@ from __future__ import annotations -import contextlib import datetime as dt import re import sys @@ -15,18 +14,14 @@ from narwhals.stable import v1 as nw from narwhals.stable.v1 import dependencies as nw_dep -from altair.datasets import Loader, url +from altair.datasets import Loader from altair.datasets._exceptions import AltairDatasetsError from altair.datasets._typing import Dataset, Extension, Metadata, is_ext_read -from tests import skip_requires_pyarrow, slow - -if sys.version_info >= (3, 14): - from typing import TypedDict -else: - from typing_extensions import TypedDict +from tests import no_xdist, skip_requires_pyarrow +from tools import fs if TYPE_CHECKING: - from collections.abc import Container, Iterator + from collections.abc import Callable, Container, Iterator, Mapping from pathlib import Path from typing import Literal @@ -34,7 +29,7 @@ import polars as pl from _pytest.mark.structures import ParameterSet - from altair.datasets._readers import _Backend, _PandasAny, _Polars + from altair.datasets._readers import _Backend, _PandasAny, _Polars, _PyArrow from altair.vegalite.v5.schema._typing import OneOrSeq from tests import MarksType @@ -45,46 +40,24 @@ PolarsLoader: TypeAlias = Loader[pl.DataFrame, pl.LazyFrame] CACHE_ENV_VAR: Literal["ALTAIR_DATASETS_DIR"] = "ALTAIR_DATASETS_DIR" - - -class DatasetSpec(TypedDict, total=False): - """Exceptional cases which cannot rely on defaults.""" - - name: Dataset - suffix: Extension - marks: MarksType - - -requires_pyarrow: pytest.MarkDecorator = skip_requires_pyarrow() - -_b_params = { +_backend_params: Mapping[_Backend, ParameterSet] = { "polars": pytest.param("polars"), - "pandas": pytest.param( - "pandas", - marks=pytest.mark.xfail( - find_spec("pyarrow") is None, - reason=( - "`pandas` supports backends other than `pyarrow` for `.parquet`.\n" - "However, none of these are currently an `altair` dependency." 
- ), - ), - ), - "pandas[pyarrow]": pytest.param("pandas[pyarrow]", marks=requires_pyarrow), - "pyarrow": pytest.param("pyarrow", marks=requires_pyarrow), + "pandas": pytest.param("pandas"), + "pandas[pyarrow]": pytest.param("pandas[pyarrow]", marks=skip_requires_pyarrow()), + "pyarrow": pytest.param("pyarrow", marks=skip_requires_pyarrow()), } -backends: pytest.MarkDecorator = pytest.mark.parametrize("backend", _b_params.values()) +backends: pytest.MarkDecorator = pytest.mark.parametrize( + "backend", _backend_params.values() +) backends_no_polars: pytest.MarkDecorator = pytest.mark.parametrize( - "backend", [v for k, v in _b_params.items() if k != "polars"] + "backend", [v for k, v in _backend_params.items() if k != "polars"] ) backends_pandas_any: pytest.MarkDecorator = pytest.mark.parametrize( - "backend", [v for k, v in _b_params.items() if "pandas" in k] -) -backends_single: pytest.MarkDecorator = pytest.mark.parametrize( - "backend", [v for k, v in _b_params.items() if "[" not in k] + "backend", [v for k, v in _backend_params.items() if "pandas" in k] ) -backends_multi: pytest.MarkDecorator = pytest.mark.parametrize( - "backend", [v for k, v in _b_params.items() if "[" in k] +backends_pyarrow: pytest.MarkDecorator = pytest.mark.parametrize( + "backend", [v for k, v in _backend_params.items() if k == "pyarrow"] ) datasets_debug: pytest.MarkDecorator = pytest.mark.datasets_debug() @@ -100,24 +73,12 @@ class DatasetSpec(TypedDict, total=False): """ -@pytest.fixture -def is_flaky_datasets(request: pytest.FixtureRequest) -> bool: - mark_filter = request.config.getoption("-m", None) # pyright: ignore[reportArgumentType] - if mark_filter is None: - return False - elif mark_filter == "": - return True - elif isinstance(mark_filter, str): - return False - else: - raise TypeError(mark_filter) - - @pytest.fixture(scope="session") -def polars_loader(tmp_path_factory: pytest.TempPathFactory) -> PolarsLoader: - data = Loader.from_backend("polars") - data.cache.path = tmp_path_factory.mktemp("loader-cache-polars") - return data +def polars_loader() -> PolarsLoader: + load = Loader.from_backend("polars") + if load.cache.is_not_active(): + load.cache.path = load.cache._XDG_CACHE + return load @pytest.fixture( @@ -127,17 +88,6 @@ def spatial_datasets(request: pytest.FixtureRequest) -> Dataset: return request.param -@backends_no_polars -def test_spatial(spatial_datasets, backend: _Backend) -> None: - load = Loader.from_backend(backend) - pattern = re.compile( - rf"{spatial_datasets}.+geospatial.+native.+{re.escape(backend)}.+url", - flags=re.DOTALL | re.IGNORECASE, - ) - with pytest.raises(NotImplementedError, match=pattern): - load(spatial_datasets) - - @pytest.fixture def metadata_columns() -> frozenset[str]: """ @@ -158,25 +108,65 @@ def metadata_columns() -> frozenset[str]: ) -def match_url(name: Dataset, url: str) -> bool: +def is_frame_backend(frame: Any, backend: _Backend, /) -> bool: + pandas_any: set[_PandasAny] = {"pandas", "pandas[pyarrow]"} + if backend in pandas_any: + return nw_dep.is_pandas_dataframe(frame) + elif backend == "pyarrow": + return nw_dep.is_pyarrow_table(frame) + elif backend == "polars": + return nw_dep.is_polars_dataframe(frame) + else: + raise TypeError(backend) + + +def is_loader_backend(loader: Loader[Any, Any], backend: _Backend, /) -> bool: + return repr(loader) == f"{type(loader).__name__}[{backend}]" + + +def is_url(name: Dataset, fn_url: Callable[..., str], /) -> bool: pattern = rf".+/vega-datasets@.+/data/{name}\..+" + url = fn_url(name) return re.match(pattern, 
url) is not None +def is_polars_backed_pyarrow(loader: Loader[Any, Any], /) -> bool: + """ + User requested ``pyarrow``, but also has ``polars`` installed. + + Notes + ----- + - Currently, defers to ``polars`` only for ``.json``. + """ + return bool( + is_loader_backend(loader, "pyarrow") + and (fn := getattr(loader._reader, "_read_json_polars", None)) + and fn == loader._reader.read_fn("dummy.json") + ) + + +@backends +def test_metadata_columns(backend: _Backend, metadata_columns: frozenset[str]) -> None: + """Ensure all backends will query the same column names.""" + load = Loader.from_backend(backend) + schema_columns = load._reader._scan_metadata().collect().columns + assert set(schema_columns) == metadata_columns + + @backends def test_loader_from_backend(backend: _Backend) -> None: - data = Loader.from_backend(backend) - assert data._reader._name == backend + load = Loader.from_backend(backend) + assert is_loader_backend(load, backend) @backends def test_loader_url(backend: _Backend) -> None: - data = Loader.from_backend(backend) - dataset_name: Dataset = "volcano" - assert match_url(dataset_name, data.url(dataset_name)) + load = Loader.from_backend(backend) + assert is_url("volcano", load.url) -def test_load(monkeypatch: pytest.MonkeyPatch) -> None: +@no_xdist +def test_load_infer_priority(monkeypatch: pytest.MonkeyPatch) -> None: """ Inferring the best backend available. @@ -187,7 +177,7 @@ def test_load(monkeypatch: pytest.MonkeyPatch) -> None: import altair.datasets._loader from altair.datasets import load - assert load._reader._name == "polars" + assert is_loader_backend(load, "polars") monkeypatch.delattr(altair.datasets._loader, "load", raising=False) monkeypatch.setitem(sys.modules, "polars", None) @@ -196,20 +186,20 @@ def test_load(monkeypatch: pytest.MonkeyPatch) -> None: if find_spec("pyarrow") is None: # NOTE: We can end the test early for the CI job that removes `pyarrow` - assert load._reader._name == "pandas" + assert is_loader_backend(load, "pandas") monkeypatch.delattr(altair.datasets._loader, "load") monkeypatch.setitem(sys.modules, "pandas", None) with pytest.raises(AltairDatasetsError, match=r"no.+backend"): from altair.datasets import load else: - assert load._reader._name == "pandas[pyarrow]" + assert is_loader_backend(load, "pandas[pyarrow]") monkeypatch.delattr(altair.datasets._loader, "load") monkeypatch.setitem(sys.modules, "pyarrow", None) from altair.datasets import load - assert load._reader._name == "pandas" + assert is_loader_backend(load, "pandas") monkeypatch.delattr(altair.datasets._loader, "load") monkeypatch.setitem(sys.modules, "pandas", None) @@ -217,7 +207,7 @@ def test_load(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setitem(sys.modules, "pyarrow", import_module("pyarrow")) from altair.datasets import load - assert load._reader._name == "pyarrow" + assert is_loader_backend(load, "pyarrow") monkeypatch.delattr(altair.datasets._loader, "load") monkeypatch.setitem(sys.modules, "pyarrow", None) @@ -225,40 +215,22 @@ def test_load(monkeypatch: pytest.MonkeyPatch) -> None: from altair.datasets import load -# HACK: Using a fixture to get a command line option -# https://docs.pytest.org/en/stable/example/simple.html#pass-different-values-to-a-test-function-depending-on-command-line-options -@pytest.mark.xfail( - is_flaky_datasets, # type: ignore - reason=( - "'pandas[pyarrow]' seems to break locally when running:\n" - ">>> pytest -p no:randomly -n logical tests -k test_datasets -m ''\n\n" - "Possibly related:\n" - " 
https://github.com/modin-project/modin/issues/951\n" - " https://github.com/pandas-dev/pandas/blob/1c986d6213904fd7d9acc5622dc91d029d3f1218/pandas/io/parquet.py#L164\n" - " https://github.com/pandas-dev/pandas/blob/1c986d6213904fd7d9acc5622dc91d029d3f1218/pandas/io/parquet.py#L257\n" - ), - raises=AttributeError, -) -@requires_pyarrow -def test_load_call(monkeypatch: pytest.MonkeyPatch) -> None: +@backends +def test_load_call(backend: _Backend, monkeypatch: pytest.MonkeyPatch) -> None: import altair.datasets._loader monkeypatch.delattr(altair.datasets._loader, "load", raising=False) from altair.datasets import load - assert load._reader._name == "polars" + assert is_loader_backend(load, "polars") default = load("cars") - df_pyarrow = load("cars", backend="pyarrow") - df_pandas = load("cars", backend="pandas[pyarrow]") + df = load("cars", backend=backend) default_2 = load("cars") - df_polars = load("cars", backend="polars") assert nw_dep.is_polars_dataframe(default) - assert nw_dep.is_pyarrow_table(df_pyarrow) - assert nw_dep.is_pandas_dataframe(df_pandas) + assert is_frame_backend(df, backend) assert nw_dep.is_polars_dataframe(default_2) - assert nw_dep.is_polars_dataframe(df_polars) @pytest.mark.parametrize( @@ -296,41 +268,36 @@ def test_load_call(monkeypatch: pytest.MonkeyPatch) -> None: def test_url(name: Dataset) -> None: from altair.datasets import url - assert match_url(name, url(name)) + assert is_url(name, url) def test_url_no_backend(monkeypatch: pytest.MonkeyPatch) -> None: - import altair.datasets from altair.datasets._cache import csv_cache + from altair.datasets._readers import infer_backend - monkeypatch.setitem(sys.modules, "polars", None) - monkeypatch.setitem(sys.modules, "pandas", None) - monkeypatch.setitem(sys.modules, "pyarrow", None) + priority: Any = ("fake_mod_1", "fake_mod_2", "fake_mod_3", "fake_mod_4") assert csv_cache._mapping == {} - - with contextlib.suppress(AltairDatasetsError): - monkeypatch.delattr(altair.datasets._loader, "load", raising=False) with pytest.raises(AltairDatasetsError): - from altair.datasets import load as load - - assert match_url("jobs", url("jobs")) + infer_backend(priority=priority) + url = csv_cache.url + assert is_url("jobs", url) assert csv_cache._mapping != {} - assert match_url("cars", url("cars")) - assert match_url("stocks", url("stocks")) - assert match_url("countries", url("countries")) - assert match_url("crimea", url("crimea")) - assert match_url("disasters", url("disasters")) - assert match_url("driving", url("driving")) - assert match_url("earthquakes", url("earthquakes")) - assert match_url("flare", url("flare")) - assert match_url("flights-10k", url("flights-10k")) - assert match_url("flights-200k", url("flights-200k")) + assert is_url("cars", url) + assert is_url("stocks", url) + assert is_url("countries", url) + assert is_url("crimea", url) + assert is_url("disasters", url) + assert is_url("driving", url) + assert is_url("earthquakes", url) + assert is_url("flare", url) + assert is_url("flights-10k", url) + assert is_url("flights-200k", url) if find_spec("vegafusion"): - assert match_url("flights-3m", url("flights-3m")) + assert is_url("flights-3m", url) with monkeypatch.context() as mp: mp.setitem(sys.modules, "vegafusion", None) @@ -344,51 +311,14 @@ def test_url_no_backend(monkeypatch: pytest.MonkeyPatch) -> None: @backends -def test_loader_call(backend: _Backend, monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv(CACHE_ENV_VAR, raising=False) - - data = Loader.from_backend(backend) - frame = 
data("stocks", ".csv") +def test_loader_call(backend: _Backend) -> None: + load = Loader.from_backend(backend) + frame = load("stocks", ".csv") assert nw_dep.is_into_dataframe(frame) nw_frame = nw.from_native(frame) assert set(nw_frame.columns) == {"symbol", "date", "price"} -@backends_single -def test_missing_dependency_single( - backend: _Backend, monkeypatch: pytest.MonkeyPatch -) -> None: - monkeypatch.setitem(sys.modules, backend, None) - - with pytest.raises( - ModuleNotFoundError, - match=re.compile( - rf"{backend}.+requires.+{backend}.+but.+{backend}.+not.+found.+pip install {backend}", - flags=re.DOTALL, - ), - ): - Loader.from_backend(backend) - - -@backends_multi -@skip_requires_pyarrow -def test_missing_dependency_multi( - backend: _Backend, monkeypatch: pytest.MonkeyPatch -) -> None: - secondary = "pyarrow" - primary = backend.removesuffix(f"[{secondary}]") - monkeypatch.setitem(sys.modules, secondary, None) - - with pytest.raises( - ModuleNotFoundError, - match=re.compile( - rf"{re.escape(backend)}.+requires.+'{primary}', '{secondary}'.+but.+{secondary}.+not.+found.+pip install {secondary}", - flags=re.DOTALL, - ), - ): - Loader.from_backend(backend) - - @backends def test_dataset_not_found(backend: _Backend) -> None: """ @@ -396,7 +326,7 @@ def test_dataset_not_found(backend: _Backend) -> None: ``Loader.url`` is used since it doesn't require a remote connection. """ - data = Loader.from_backend(backend) + load = Loader.from_backend(backend) real_name: Literal["disasters"] = "disasters" invalid_name: Literal["fake name"] = "fake name" invalid_suffix: Literal["fake suffix"] = "fake suffix" @@ -411,7 +341,7 @@ def test_dataset_not_found(backend: _Backend) -> None: ERR_NO_RESULT, match=re.compile(rf"{MSG_NO_RESULT}.+{NAME}.+{invalid_name}", re.DOTALL), ): - data.url(invalid_name) + load.url(invalid_name) with pytest.raises( TypeError, @@ -420,7 +350,7 @@ def test_dataset_not_found(backend: _Backend) -> None: re.DOTALL, ), ): - data.url(real_name, invalid_suffix) # type: ignore[arg-type] + load.url(real_name, invalid_suffix) # type: ignore[arg-type] with pytest.raises( ERR_NO_RESULT, @@ -429,7 +359,44 @@ def test_dataset_not_found(backend: _Backend) -> None: re.DOTALL, ), ): - data.url(real_name, incorrect_suffix) + load.url(real_name, incorrect_suffix) + + +def test_reader_missing_dependencies() -> None: + from packaging.requirements import Requirement + + from altair.datasets._readers import _Reader + + class MissingDeps(_Reader): + def __init__(self, name) -> None: + self._name = name + reqs = Requirement(name) + for req in (reqs.name, *reqs.extras): + self._import(req) + + self._read_fn = {} + self._scan_fn = {} + + fake_name = "not_a_real_package" + real_name = "altair" + fake_extra = "AnotherFakePackage" + backend = f"{real_name}[{fake_extra}]" + with pytest.raises( + ModuleNotFoundError, + match=re.compile( + rf"{fake_name}.+requires.+{fake_name}.+but.+{fake_name}.+not.+found.+pip install {fake_name}", + flags=re.DOTALL, + ), + ): + MissingDeps(fake_name) + with pytest.raises( + ModuleNotFoundError, + match=re.compile( + rf"{re.escape(backend)}.+requires.+'{real_name}', '{fake_extra}'.+but.+{fake_extra}.+not.+found.+pip install {fake_extra}", + flags=re.DOTALL, + ), + ): + MissingDeps(backend) @backends @@ -451,97 +418,112 @@ def test_reader_cache( monkeypatch.setenv(CACHE_ENV_VAR, str(tmp_path)) - data = Loader.from_backend(backend) - assert data.cache.is_active() - cache_dir = data.cache.path + load = Loader.from_backend(backend) + assert load.cache.is_active() + 
cache_dir = load.cache.path assert cache_dir == tmp_path - assert tuple(data.cache) == () + assert tuple(load.cache) == () # smallest csvs - lookup_groups = data("lookup_groups") - data("lookup_people") - data("iowa-electricity") - data("global-temp") + lookup_groups = load("lookup_groups") + load("lookup_people") + load("iowa-electricity") + load("global-temp") - cached_paths = tuple(data.cache) + cached_paths = tuple(load.cache) assert len(cached_paths) == 4 if nw_dep.is_polars_dataframe(lookup_groups): left, right = ( lookup_groups, - cast("pl.DataFrame", data("lookup_groups", ".csv")), + cast("pl.DataFrame", load("lookup_groups", ".csv")), ) else: left, right = ( pl.DataFrame(lookup_groups), - pl.DataFrame(data("lookup_groups", ".csv")), + pl.DataFrame(load("lookup_groups", ".csv")), ) assert_frame_equal(left, right) - assert len(tuple(data.cache)) == 4 - assert cached_paths == tuple(data.cache) + assert len(tuple(load.cache)) == 4 + assert cached_paths == tuple(load.cache) - data("iowa-electricity", ".csv") - data("global-temp", ".csv") - data("global-temp.csv") + load("iowa-electricity", ".csv") + load("global-temp", ".csv") + load("global-temp.csv") - assert len(tuple(data.cache)) == 4 - assert cached_paths == tuple(data.cache) + assert len(tuple(load.cache)) == 4 + assert cached_paths == tuple(load.cache) - data("lookup_people") - data("lookup_people.csv") - data("lookup_people", ".csv") - data("lookup_people") + load("lookup_people") + load("lookup_people.csv") + load("lookup_people", ".csv") + load("lookup_people") - assert len(tuple(data.cache)) == 4 - assert cached_paths == tuple(data.cache) + assert len(tuple(load.cache)) == 4 + assert cached_paths == tuple(load.cache) -@slow @datasets_debug @backends def test_reader_cache_exhaustive( - backend: _Backend, monkeypatch: pytest.MonkeyPatch, tmp_path: Path + backend: _Backend, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + polars_loader: PolarsLoader, ) -> None: """ Fully populate and then purge the cache for all backends. 
- Does not attempt to read the files - Checking we can support pre-downloading and safely deleting + + Notes + ----- + - Requests work the same for all backends + - The logic for detecting the cache contents uses ``narwhals`` + - Here, we're testing that these ``narwhals`` ops are consistent + - `DatasetCache.download_all` is expensive for CI, so aiming for it to run at most once + - 34-45s per call (4x backends) """ + polars_loader.cache.download_all() + CLONED: Path = tmp_path / "clone" + fs.mkdir(CLONED) + fs.copytree(polars_loader.cache.path, CLONED) + monkeypatch.setenv(CACHE_ENV_VAR, str(tmp_path)) - data = Loader.from_backend(backend) - assert data.cache.is_active() - cache_dir = data.cache.path + load = Loader.from_backend(backend) + assert load.cache.is_active() + cache_dir = load.cache.path assert cache_dir == tmp_path - assert tuple(data.cache) == () + assert tuple(load.cache) == (CLONED,) - data.cache.download_all() - cached_paths = tuple(data.cache) + load.cache.path = CLONED + cached_paths = tuple(load.cache) assert cached_paths != () # NOTE: Approximating all datasets downloaded assert len(cached_paths) >= 40 assert all( bool(fp.exists() and is_ext_read(fp.suffix) and fp.stat().st_size) - for fp in data.cache + for fp in load.cache ) # NOTE: Confirm this is a no-op - data.cache.download_all() - assert len(cached_paths) == len(tuple(data.cache)) + load.cache.download_all() + assert len(cached_paths) == len(tuple(load.cache)) # NOTE: Ensure unrelated files in the directory are not removed dummy: Path = tmp_path / "dummy.json" dummy.touch(exist_ok=False) - data.cache.clear() + load.cache.clear() remaining = tuple(tmp_path.iterdir()) - assert len(remaining) == 1 - assert remaining[0] == dummy - dummy.unlink() + assert set(remaining) == {dummy, CLONED} + fs.rm(dummy, CLONED) +@no_xdist def test_reader_cache_disable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: from altair.datasets import load @@ -572,68 +554,66 @@ def test_reader_cache_disable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) - assert not load.cache.is_empty() -movies_fail: ParameterSet = pytest.param( - "movies", - marks=pytest.mark.xfail( - reason="Only working for `polars`.\n" - "`pyarrow` isn't happy with the mixed `int`/`str` column." - ), -) -earthquakes_fail: ParameterSet = pytest.param( - "earthquakes", - marks=pytest.mark.xfail( - reason="Only working for `polars`.\nGeoJSON fails on native `pyarrow`" - ), -) - - +# TODO: Investigate adding schemas for `pyarrow`. @pytest.mark.parametrize( - "name", + ("name", "fallback"), [ - "cars", - movies_fail, - "wheat", - "barley", - "gapminder", - "income", - "burtin", - earthquakes_fail, + ("cars", "polars"), + ("movies", "polars"), + ("wheat", "polars"), + ("barley", "polars"), + ("gapminder", "polars"), + ("income", "polars"), + ("burtin", "polars"), + ("cars", None), + pytest.param( + "movies", + None, + marks=pytest.mark.xfail( + True, + raises=TypeError, + reason=( + "msg: `Expected bytes, got a 'int' object`\n" + "Isn't happy with the mixed `int`/`str` column."
+ ), + strict=True, + ), + ), + ("wheat", None), + ("barley", None), + ("gapminder", None), + ("income", None), + ("burtin", None), ], ) -@pytest.mark.parametrize("fallback", ["polars", None]) -@skip_requires_pyarrow +@backends_pyarrow def test_pyarrow_read_json( - fallback: _Polars | None, name: Dataset, monkeypatch: pytest.MonkeyPatch + backend: _PyArrow, + fallback: _Polars | None, + name: Dataset, + monkeypatch: pytest.MonkeyPatch, ) -> None: - monkeypatch.delenv(CACHE_ENV_VAR, raising=False) - monkeypatch.delitem(sys.modules, "pandas", raising=False) if fallback is None: monkeypatch.setitem(sys.modules, "polars", None) - - data = Loader.from_backend("pyarrow") - - data(name, ".json") + load = Loader.from_backend(backend) + assert load(name, ".json") -@pytest.mark.parametrize( - ("spec", "column"), - [ - (DatasetSpec(name="cars"), "Year"), - (DatasetSpec(name="unemployment-across-industries"), "date"), - (DatasetSpec(name="flights-10k"), "date"), - (DatasetSpec(name="football"), "date"), - (DatasetSpec(name="crimea"), "date"), - (DatasetSpec(name="ohlc"), "date"), - ], -) -def test_polars_read_json_roundtrip( - polars_loader: PolarsLoader, spec: DatasetSpec, column: str -) -> None: - frame = polars_loader(spec["name"], ".json") - tp = frame.schema.to_python()[column] - assert tp is dt.date or issubclass(tp, dt.date) +@backends_no_polars +def test_spatial(spatial_datasets, backend: _Backend) -> None: + load = Loader.from_backend(backend) + if is_polars_backed_pyarrow(load): + assert nw_dep.is_pyarrow_table(load(spatial_datasets)) + else: + pattern = re.compile( + rf"{spatial_datasets}.+geospatial.+native.+{re.escape(backend)}.+try.+polars.+url", + flags=re.DOTALL | re.IGNORECASE, + ) + with pytest.raises(NotImplementedError, match=pattern): + load(spatial_datasets) +# TODO: Adapt into something useful or simplify into just param name def _dataset_params(*, skip: Container[str] = ()) -> Iterator[ParameterSet]: """Temp way of excluding datasets that were removed.""" names: tuple[Dataset, ...] 
= get_args(Dataset) @@ -646,9 +626,8 @@ def _dataset_params(*, skip: Container[str] = ()) -> Iterator[ParameterSet]: yield pytest.param(*args, marks=marks) -@slow -@datasets_debug @pytest.mark.parametrize(("name", "suffix"), list(_dataset_params())) +@datasets_debug def test_all_datasets( polars_loader: PolarsLoader, name: Dataset, suffix: Extension ) -> None: @@ -668,51 +647,62 @@ def _raise_exception(e: type[Exception], *args: Any, **kwds: Any): def test_no_remote_connection(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: from polars.testing import assert_frame_equal - data = Loader.from_backend("polars") - data.cache.path = tmp_path + load = Loader.from_backend("polars") + load.cache.path = tmp_path - data("londonCentroids") - data("stocks") - data("driving") + load("londonCentroids") + load("stocks") + load("driving") cached_paths = tuple(tmp_path.iterdir()) assert len(cached_paths) == 3 raiser = partial(_raise_exception, URLError) with monkeypatch.context() as mp: - mp.setattr(data._reader._opener, "open", raiser) + mp.setattr(load._reader._opener, "open", raiser) # Existing cache entries don't trigger an error - data("londonCentroids") - data("stocks") - data("driving") + load("londonCentroids") + load("stocks") + load("driving") # Mocking cache-miss without remote conn with pytest.raises(URLError): - data("birdstrikes") + load("birdstrikes") assert len(tuple(tmp_path.iterdir())) == 3 # Now we can get a cache-hit - frame = data("birdstrikes") + frame = load("birdstrikes") assert nw_dep.is_polars_dataframe(frame) assert len(tuple(tmp_path.iterdir())) == 4 with monkeypatch.context() as mp: - mp.setattr(data._reader._opener, "open", raiser) + mp.setattr(load._reader._opener, "open", raiser) # Here, the remote conn isn't considered - we already have the file - frame_from_cache = data("birdstrikes") + frame_from_cache = load("birdstrikes") assert len(tuple(tmp_path.iterdir())) == 4 assert_frame_equal(frame, frame_from_cache) -@backends -def test_metadata_columns(backend: _Backend, metadata_columns: frozenset[str]) -> None: - """Ensure all backends will query the same column names.""" - data = Loader.from_backend(backend) - schema_columns = data._reader._scan_metadata().collect().columns - assert set(schema_columns) == metadata_columns +@pytest.mark.parametrize( + ("name", "column"), + [ + ("cars", "Year"), + ("unemployment-across-industries", "date"), + ("flights-10k", "date"), + ("football", "date"), + ("crimea", "date"), + ("ohlc", "date"), + ], +) +def test_polars_date_read_json_roundtrip( + polars_loader: PolarsLoader, name: Dataset, column: str +) -> None: + """Ensure ``date`` columns are inferred using the roundtrip json -> csv method.""" + frame = polars_loader(name, ".json") + tp = frame.schema.to_python()[column] + assert tp is dt.date or issubclass(tp, dt.date) -@skip_requires_pyarrow @backends_pandas_any @pytest.mark.parametrize( ("name", "columns"),