Migrate to earthkit-data (#45)
* Migrate to earthkit-data

* fix indexes

* cleanup

* cleanup

* invalidate raw data

* handle sources

* switch off prompt

* remove markdown

* restore markdown

* Update environment.yml

* set conftest

* use cdsapi

* fix indexes

* Update environment.yml

* update env

* fix use of preprocess

* handle gribs

* add nocache

* cleanup NOCACHE

* csv reader

* fix pip deps

* add ecmwflibs

* fix installation

* use xr directly

* fix extension check

* fix memory issues

* fix shapefile

* pin packages
malmans2 authored Aug 29, 2024
1 parent 680c19a commit 12de045
Showing 6 changed files with 136 additions and 104 deletions.
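At its core, the migration swaps the retrieval backend: `cads_toolbox.catalogue.retrieve(...)` becomes `earthkit.data.from_source("cds", ...)`, as seen in download.py below. A minimal sketch of the new pattern, assuming valid CDS credentials (the collection ID and request values here are illustrative):

import earthkit.data

# Illustrative collection and request; any CDS dataset follows the same pattern.
request = {
    "variable": "2m_temperature",
    "year": "2020",
    "month": "01",
    "day": "01",
    "time": "12:00",
}
# prompt=False matches the new download.py code: never prompt for credentials.
ds = earthkit.data.from_source(
    "cds", "reanalysis-era5-single-levels", request, prompt=False
)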
152 changes: 88 additions & 64 deletions c3s_eqc_automatic_quality_control/download.py
@@ -19,7 +19,7 @@

import calendar
import contextlib
import fnmatch
import datetime
import functools
import itertools
import os
@@ -28,25 +28,23 @@
from typing import Any

import cacholote
import cads_toolbox
import cf_xarray # noqa: F401
import cgul
import emohawk.readers.directory
import emohawk.readers.shapefile
import earthkit.data
import fsspec
import fsspec.implementations.local
import joblib
import pandas as pd
import tqdm
import xarray as xr

cads_toolbox.config.USE_CACHE = True
from earthkit.data.readers.csv import CSVReader
from earthkit.data.readers.grib.index import GribFieldList
from earthkit.data.readers.shapefile import ShapeFileReader
from earthkit.data.sources.file import File

N_JOBS = 1
INVALIDATE_CACHE = False
# TODO: These kwargs should somehow be handled upstream by the toolbox.
TO_XARRAY_KWARGS: dict[str, Any] = {
    "pandas_read_csv_kwargs": {"comment": "#"},
}

NOCACHE = False
_SORTED_REQUEST_PARAMETERS = ("area", "grid")


@@ -309,29 +307,47 @@ def ensure_request_gets_cached(request: dict[str, Any]) -> dict[str, Any]:
    return cacheable_request


def _cached_retrieve(collection_id: str, request: dict[str, Any]) -> emohawk.Data:
    with cacholote.config.set(return_cache_entry=False):
        return cads_toolbox.catalogue.retrieve(collection_id, request).data
def get_paths(sources: list[Any]) -> list[str]:
    paths = []
    for source in sources:
        indexes = getattr(source, "_indexes", [source])
        paths.extend([index.path for index in indexes])
    return paths


@cacholote.cacheable
def _cached_retrieve(
    collection_id: str, request: dict[str, Any]
) -> list[fsspec.implementations.local.LocalFileOpener]:
    if NOCACHE:
        request = request | {"nocache": datetime.datetime.now().isoformat()}
    ds = earthkit.data.from_source("cds", collection_id, request, prompt=False)
    if isinstance(ds, ShapeFileReader) and hasattr(ds._parent, "_path_and_parts"):
        # Do not unzip vector data
        sources = [ds._parent._path_and_parts]
    else:
        sources = ds.sources if hasattr(ds, "sources") else [ds]
    fs = fsspec.filesystem("file")
    return [fs.open(path) for path in get_paths(sources)]


def retrieve(collection_id: str, request: dict[str, Any]) -> list[str]:
    with cacholote.config.set(
        return_cache_entry=False,
        io_delete_original=True,
    ):
        return [file.path for file in _cached_retrieve(collection_id, request)]


def get_sources(
    collection_id: str,
    request_list: list[dict[str, Any]],
    exclude: list[str] = ["*.png", "*.json"],
) -> list[str]:
    source: set[str] = set()

    sources: set[str] = set()
    disable = os.getenv("TQDM_DISABLE", "False") == "True"
    for request in tqdm.tqdm(request_list, disable=disable):
        data = _cached_retrieve(collection_id, request)
        if content := getattr(data, "_content", None):
            source.update(map(str, content))
        else:
            source.add(str(data.source))

    for pattern in exclude:
        source -= set(fnmatch.filter(source, pattern))
    return list(source)
        sources.update(retrieve(collection_id, request))
    return list(sources)


def _set_bound_coords(ds: xr.Dataset) -> xr.Dataset:
@@ -410,57 +426,57 @@ def _preprocess(
    return harmonise(ds, collection_id)


def get_data(source: list[str]) -> Any:
    if len(source) == 1:
        return emohawk.open(source[0])

    # TODO: emohawk is not able to open a list of files
    emohwak_dir = emohawk.readers.directory.DirectoryReader("")
    emohwak_dir._content = source
    return emohwak_dir


def _download_and_transform_requests(
    collection_id: str,
    request_list: list[dict[str, Any]],
    transform_func: Callable[..., xr.Dataset] | None,
    transform_func_kwargs: dict[str, Any],
    **open_mfdataset_kwargs: Any,
) -> xr.Dataset:
    # TODO: Ideally, we would always use emohawk.
    # However, the behavior is not consistent across backends.
    # For example, GRIB silently ignores open_mfdataset_kwargs
    sources = get_sources(collection_id, request_list)
    try:
        engine = open_mfdataset_kwargs.get(
            "engine",
            {xr.backends.plugins.guess_engine(source) for source in sources},
        )
        use_emohawk = len(engine) != 1
    except ValueError:
        use_emohawk = True

    open_mfdataset_kwargs["preprocess"] = functools.partial(
    preprocess = functools.partial(
        _preprocess,
        collection_id=collection_id,
        preprocess=open_mfdataset_kwargs.get("preprocess", None),
        preprocess=open_mfdataset_kwargs.pop("preprocess", None),
    )

    if use_emohawk:
        data = get_data(sources)
        if isinstance(data, emohawk.readers.shapefile.ShapefileReader):
            # FIXME: emohawk NotImplementedError
            ds: xr.Dataset = data.to_pandas().to_xarray()
    grib_ext = (".grib", ".grb", ".grb1", ".grb2")
    ext_to_skip = (".png", ".json")
    if all(
        isinstance(source, str) and source.endswith(grib_ext + ext_to_skip)
        for source in sources
    ):
        # TODO: Avoid memory issues
        # https://github.com/ecmwf/earthkit-data/issues/378
        # https://github.com/ecmwf/earthkit-data/issues/400
        open_mfdataset_kwargs["preprocess"] = preprocess
        ds = xr.open_mfdataset(
            [source for source in sources if not source.endswith(ext_to_skip)],
            **open_mfdataset_kwargs,
        )
    else:
        ek_ds = earthkit.data.from_source("file", sources)
        if isinstance(ek_ds, GribFieldList):
            # TODO: squeeze=True is cfgrib default
            # https://github.com/ecmwf/earthkit-data/issues/374
            open_dataset_kwargs = {
                "chunks": {},
                "squeeze": True,
            } | open_mfdataset_kwargs
            ds = ek_ds.to_xarray(xarray_open_dataset_kwargs=open_dataset_kwargs)
            ds = preprocess(ds)
        elif (
            isinstance(ek_ds, File) and isinstance(ek_ds._reader, CSVReader)
        ) or isinstance(ek_ds, ShapeFileReader):
            assert not open_mfdataset_kwargs
            ds = preprocess(ek_ds.to_xarray())
        else:
            ds = data.to_xarray(
                xarray_open_mfdataset_kwargs=open_mfdataset_kwargs,
                **TO_XARRAY_KWARGS,
            )
            open_mfdataset_kwargs["preprocess"] = preprocess
            ds = ek_ds.to_xarray(xarray_open_mfdataset_kwargs=open_mfdataset_kwargs)
        if not isinstance(ds, xr.Dataset):
            # When emohawk fails to concat, it silently returns a list
            raise TypeError(f"`emohawk` returned {type(ds)} instead of a xr.Dataset")
    else:
        ds = xr.open_mfdataset(sources, **open_mfdataset_kwargs)
            raise TypeError(
                f"`earthkit.data` returned {type(ds)} instead of a xr.Dataset"
            )

    if transform_func is not None:
        with cacholote.config.set(return_cache_entry=False):
@@ -481,7 +497,7 @@ def _delayed_download(
    collection_id: str, request: dict[str, Any], config: cacholote.config.Settings
) -> None:
    with cacholote.config.set(**dict(config)):
        _cached_retrieve(collection_id, request)
        retrieve(collection_id, request)


def download_and_transform(
@@ -568,7 +584,15 @@ def download_and_transform(
    for request in ensure_list(requests):
        request_list.extend(split_request(request, chunks, split_all))

    if invalidate_cache and not use_cache:
        # Delete raw data
        for request in request_list:
            cacholote.delete(
                _cached_retrieve, collection_id=collection_id, request=request
            )

    if n_jobs != 1:
        assert not NOCACHE, "n_jobs must be 1 when NOCACHE is True"
        # Download all data in parallel
        joblib.Parallel(n_jobs=n_jobs)(
            _delayed_download(collection_id, request, cacholote.config.get())
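Two patterns in the new download.py are worth noting. First, when `NOCACHE` is set, `_cached_retrieve` salts the request with a timestamp so `cacholote` never reuses a cached entry. Second, the public entry point keeps its old signature; a sketch of a typical call after this change (the collection ID, request values, and transform are illustrative; `chunks` and `transform_func` are the keyword arguments exercised by the tests):

import xarray as xr

from c3s_eqc_automatic_quality_control import download

def spatial_mean(ds: xr.Dataset) -> xr.Dataset:
    # Runs after the _preprocess harmonisation; its result is cached by cacholote.
    return ds.mean(("longitude", "latitude"))

ds = download.download_and_transform(
    "reanalysis-era5-single-levels",  # illustrative collection ID
    {"variable": "2m_temperature", "year": "2020", "month": ["01", "02"]},
    chunks={"month": 1},  # split into one request (and one cache entry) per month
    transform_func=spatial_mean,
)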
7 changes: 3 additions & 4 deletions environment.yml
@@ -9,11 +9,12 @@ channels:
# DO NOT EDIT ABOVE THIS LINE, ADD DEPENDENCIES BELOW AS SHOWN IN THE EXAMPLE
dependencies:
  - cartopy
  - cdsapi
  - cdsapi >= 0.7.1
  - cfgrib
  - cf-units
  - cf_xarray
  - dask
  - earthkit-data >= 0.9.4
  - fsspec
  - geopandas
  - joblib
@@ -37,11 +38,9 @@ dependencies:
  - tqdm
  - typing_extensions
  - xarray
  - xesmf
  - xesmf >= 0.8.7
  - xskillscore
  - pip:
      - cacholote
      - cads-toolbox
      - cgul
      - emohawk
      - kaleido
9 changes: 5 additions & 4 deletions pyproject.toml
@@ -15,11 +15,11 @@ classifiers = [
]
dependencies = [
    "cacholote",
    "cads-toolbox",
    "cartopy",
    "cf-xarray",
    "cgul",
    "emohawk",
    "earthkit-data[cds]",
    "ecmwflibs",
    "fsspec",
    "joblib",
    "matplotlib",
@@ -53,9 +53,10 @@ ignore_missing_imports = true
module = [
    "cads_toolbox",
    "cartopy.*",
    "cdsapi",
    "cgul",
    "emohawk.*",
    "fsspec",
    "earthkit.*",
    "fsspec.*",
    "joblib",
    "plotly.*",
    "shapely",
39 changes: 39 additions & 0 deletions tests/conftest.py
@@ -1,8 +1,47 @@
import pathlib
import tempfile
from collections.abc import Generator
from typing import Any

import cacholote
import cdsapi
import fsspec
import pytest
import xarray as xr


class MockResult:
    def __init__(self, name: str, request: dict[str, Any]) -> None:
        self.name = name
        self.request = request

    @property
    def location(self) -> str:
        return tempfile.NamedTemporaryFile(suffix=".nc", delete=False).name

    def download(self, target: str | pathlib.Path | None = None) -> str | pathlib.Path:
        ds = xr.tutorial.open_dataset(self.name).sel(**self.request)
        ds.to_netcdf(path := target or self.location)
        return path


def mock_retrieve(
    self: cdsapi.Client,
    name: str,
    request: dict[str, Any],
    target: str | pathlib.Path | None = None,
) -> fsspec.spec.AbstractBufferedFile:
    result = MockResult(name, request)
    if target is None:
        return result
    return result.download(target)


@pytest.fixture(autouse=True)
def mock_download(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setenv("CDSAPI_URL", "")
    monkeypatch.setenv("CDSAPI_KEY", "123456:1123e4567-e89b-12d3-a456-42665544000")
    monkeypatch.setattr(cdsapi.Client, "retrieve", mock_retrieve)


@pytest.fixture(autouse=True)
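Because the fixture above is autouse, every test exercises the download stack against `xr.tutorial` data instead of the real CDS. A hypothetical test showing the mock in action (the test name, dataset name, and `.sel` argument are assumptions derived from `MockResult.download`):

import pathlib

import cdsapi

def test_retrieve_is_mocked(tmp_path: pathlib.Path) -> None:
    target = tmp_path / "air.nc"
    # cdsapi.Client.retrieve is monkeypatched, so this writes a subset of the
    # "air_temperature" xarray tutorial dataset instead of contacting the CDS.
    cdsapi.Client().retrieve("air_temperature", {"time": "2013-01-01"}, target)
    assert target.exists()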
14 changes: 1 addition & 13 deletions tests/test_10_download.py
@@ -1,11 +1,9 @@
import datetime
from typing import Any

import cads_toolbox
import pandas as pd
import pytest
import xarray as xr
from utils import mock_download

from c3s_eqc_automatic_quality_control import download

@@ -271,12 +269,9 @@ def test_ensure_request_gets_cached() -> None:
    ],
)
def test_download_no_transform(
    monkeypatch: pytest.MonkeyPatch,
    chunks: dict[str, int],
    dask_chunks: dict[str, tuple[int, ...]],
) -> None:
    monkeypatch.setattr(cads_toolbox.catalogue, "_download", mock_download)

    ds = download.download_and_transform(*AIR_TEMPERATURE_REQUEST, chunks=chunks)
    assert dict(ds.chunks) == dask_chunks

@@ -289,12 +284,9 @@ def test_download_no_transform(
    ],
)
def test_download_and_transform(
    monkeypatch: pytest.MonkeyPatch,
    transform_chunks: bool,
    dask_chunks: dict[str, tuple[int, ...]],
) -> None:
    monkeypatch.setattr(cads_toolbox.catalogue, "_download", mock_download)

    def transform_func(ds: xr.Dataset) -> xr.Dataset:
        return ds.round().mean(("longitude", "latitude"))

@@ -310,11 +302,7 @@ def transform_func(ds: xr.Dataset) -> xr.Dataset:

@pytest.mark.parametrize("transform_chunks", [True, False])
@pytest.mark.parametrize("invalidate_cache", [True, False])
def test_invalidate_cache(
    monkeypatch: pytest.MonkeyPatch, transform_chunks: bool, invalidate_cache: bool
) -> None:
    monkeypatch.setattr(cads_toolbox.catalogue, "_download", mock_download)

def test_invalidate_cache(transform_chunks: bool, invalidate_cache: bool) -> None:
    def transform_func(ds: xr.Dataset) -> xr.Dataset:
        return ds * 0

19 changes: 0 additions & 19 deletions tests/utils.py

This file was deleted.
