load_stac enhancements #284

Draft · wants to merge 19 commits into base: main
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,3 +1,3 @@
[submodule "openeo_processes_dask/specs/openeo-processes"]
path = openeo_processes_dask/specs/openeo-processes
url = git@github.com:eodcgmbh/openeo-processes.git
url = https://github.com/interTwin-eu/openeo-processes.git
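Note for existing checkouts: after a submodule URL change like this, the local clone typically needs git submodule sync --recursive followed by git submodule update --init --recursive to pick up the new remote.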
146 changes: 133 additions & 13 deletions openeo_processes_dask/process_implementations/cubes/load.py
@@ -6,10 +6,11 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from urllib.parse import unquote, urljoin, urlparse

import numpy as np
import odc.stac
import planetary_computer as pc
import pyproj
import pystac_client
import stackstac
import xarray as xr
from openeo_pg_parser_networkx.pg_schema import BoundingBox, TemporalInterval
from stac_validator import stac_validator
@@ -23,6 +24,7 @@
from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo_processes_dask.process_implementations.exceptions import (
NoDataAvailable,
OpenEOException,
TemporalExtentEmpty,
)

@@ -86,10 +88,19 @@ def load_stac(
temporal_extent: Optional[TemporalInterval] = None,
bands: Optional[list[str]] = None,
properties: Optional[dict] = None,
resolution: Optional[float] = None,
projection: Optional[Union[int, str]] = None,
resampling: Optional[str] = None,
) -> RasterCube:
asset_type = _validate_stac(url)
stac_type = _validate_stac(url)

if asset_type == "COLLECTION":
# TODO: load_stac should have a parameter to enable scale and offset?

# If the user provides the bands list as a single string, wrap it in a list:
if isinstance(bands, str):
bands = [bands]

if stac_type == "COLLECTION":
# If query parameters are passed, try to get the parent Catalog if possible/exists, to use the /search endpoint
if spatial_extent or temporal_extent or bands or properties:
# If query parameters are passed, try to get the parent Catalog if possible/exists, to use the /search endpoint
@@ -139,28 +150,137 @@ def load_stac(
raise Exception(
f"No parameters for filtering provided. Loading the whole STAC Collection is not supported yet."
)

elif asset_type == "ITEM":
elif stac_type == "ITEM":
stac_api = pystac_client.stac_api_io.StacApiIO()
stac_dict = json.loads(stac_api.read_text(url))
items = stac_api.stac_object_from_dict(stac_dict)

items = [stac_api.stac_object_from_dict(stac_dict)]
else:
raise Exception(
f"The provided URL is a STAC {asset_type}, which is not yet supported. Please provide a valid URL to a STAC Collection or Item."
f"The provided URL is a STAC {stac_type}, which is not yet supported. Please provide a valid URL to a STAC Collection or Item."
)

if bands is not None:
stack = stackstac.stack(items, assets=bands)
available_assets = {tuple(i.assets.keys()) for i in items}
if len(available_assets) > 1:
raise OpenEOException(
f"The resulting STAC Items contain two separate sets of assets: {available_assets}. We can't load them at the same time."
)
available_assets = [x for t in available_assets for x in t]
if len(set(available_assets) & set(bands)) == 0:
raise OpenEOException(
f"The provided bands: {bands} can't be found in the STAC assets: {available_assets}"
)
reference_system = None
# Check if the reference system is available under properties with the datacube extension
item_dict = items[0].to_dict()
if "properties" in item_dict:
if "cube:dimensions" in item_dict["properties"]:
for d in item_dict["properties"]["cube:dimensions"]:
if "reference_system" in item_dict["properties"]["cube:dimensions"][d]:
reference_system = item_dict["properties"]["cube:dimensions"][d][
"reference_system"
]
break

asset_scale_offset = {}
zarr_assets = False
use_xarray_open_kwargs = False
for asset in available_assets:
if asset in bands:
asset_scale = 1
asset_offset = 0
asset_nodata = None
asset_dtype = None
asset_type = None
asset_dict = items[0].assets[asset].to_dict()
if "raster:bands" in asset_dict:
asset_scale = asset_dict["raster:bands"][0].get("scale", 1)
asset_offset = asset_dict["raster:bands"][0].get("offset", 0)
asset_nodata = asset_dict["raster:bands"][0].get("nodata", None)
asset_dtype = asset_dict["raster:bands"][0].get("data_type", None)
if "type" in asset_dict:
asset_type = asset_dict["type"]
if asset_type == "application/vnd+zarr":
zarr_assets = True
if "xarray:open_kwargs" in asset_dict:
use_xarray_open_kwargs = True
asset_scale_offset[asset] = {
"scale": asset_scale,
"offset": asset_offset,
"nodata": asset_nodata,
"data_type": asset_dtype,
"type": asset_type,
}

if zarr_assets:
if use_xarray_open_kwargs:
datasets = [
xr.open_dataset(asset.href, **asset.extra_fields["xarray:open_kwargs"])
for item in items
for asset in item.assets.values()
if any(b in asset.href for b in bands)
]
else:
datasets = [
xr.open_dataset(asset.href, engine="zarr", consolidated=True, chunks={})
for item in items
for asset in item.assets.values()
if any(b in asset.href for b in bands)
]
stack = xr.combine_by_coords(
datasets, join="exact", combine_attrs="drop_conflicts"
)
stack.rio.write_crs(reference_system, inplace=True)
# TODO: for now, drop data variables that contain dates. We should probably allow them if they don't conflict with other data types.
for d in stack.data_vars:
if "datetime" in str(stack[d].dtype):
stack = stack.drop(d)
stack = stack.to_dataarray(dim="bands")
else:
stack = stackstac.stack(items)
# If at least one band has the nodata field set, we have to apply it at loading time
apply_nodata = True
nodata_set = {asset_scale_offset[k]["nodata"] for k in asset_scale_offset}
dtype_set = {asset_scale_offset[k]["data_type"] for k in asset_scale_offset}
kwargs = {}
if resolution is not None:
kwargs["resolution"] = resolution
if projection is not None:
kwargs["crs"] = projection
if resampling is not None:
kwargs["resampling"] = resampling

if len(nodata_set) == 1 and list(nodata_set)[0] is None:
apply_nodata = False
if apply_nodata:
# We can pass only a single nodata value for all the assets/variables/bands https://github.com/opendatacube/odc-stac/issues/147#issuecomment-2005315438
# Therefore, if we load multiple assets having different nodata values, the first one will be used
kwargs["nodata"] = list(nodata_set)[0]
dtype = list(dtype_set)[0]
if dtype is not None:
kwargs["nodata"] = np.dtype(dtype).type(kwargs["nodata"])
# TODO: the dimension names (like "bands") should come from the STAC metadata and not hardcoded
# Note: unfortunately, converting the dataset to a dataarray casts all variables to a common data type
if bands is not None:
stack = odc.stac.load(items, bands=bands, chunks={}, **kwargs).to_dataarray(
dim="bands"
)
else:
stack = odc.stac.load(items, chunks={}, **kwargs).to_dataarray(dim="bands")

if spatial_extent is not None:
stack = filter_bbox(stack, spatial_extent)

if temporal_extent is not None and asset_type == "ITEM":
if temporal_extent is not None and (stac_type == "ITEM" or zarr_assets):
stack = filter_temporal(stack, temporal_extent)

# If at least one band requires applying scale and/or offset, the datatype of the whole DataArray must be cast to float -> do not apply it automatically yet. See https://github.com/Open-EO/openeo-processes/issues/503
# b_dim = stack.openeo.band_dims[0]
# for b in stack[b_dim]:
# scale = asset_scale_offset[b.item(0)]["scale"]
# offset = asset_scale_offset[b.item(0)]["offset"]
# if scale != 1:
# stack.loc[{b_dim: b.item(0)}] *= scale
# if offset != 0:
# stack.loc[{b_dim: b.item(0)}] += offset

return stack


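Since scale/offset is deliberately not applied yet (see the commented-out block above and Open-EO/openeo-processes#503), a caller can still apply it manually from the raster:bands asset metadata. A rough sketch, assuming item is one of the loaded pystac Items and cube is the DataArray returned above with a "bands" dimension:

# Hypothetical post-processing, not part of this PR.
scaled = cube.astype("float32")  # cast first so scale/offset values are not truncated
for band in [str(b) for b in cube["bands"].values]:
    raster_bands = item.assets[band].to_dict().get("raster:bands", [{}])
    scale = raster_bands[0].get("scale", 1)
    offset = raster_bands[0].get("offset", 0)
    if scale != 1:
        scaled.loc[{"bands": band}] = scaled.loc[{"bands": band}] * scale
    if offset != 0:
        scaled.loc[{"bands": band}] = scaled.loc[{"bands": band}] + offset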
2 changes: 2 additions & 0 deletions openeo_processes_dask/process_implementations/cubes/merge.py
@@ -76,13 +76,15 @@ def merge_cubes(
"context": context,
}

crs = concat_both_cubes_rechunked.rio.crs
merged_cube = concat_both_cubes_rechunked.reduce(
overlap_resolver,
dim=NEW_DIM_NAME,
keep_attrs=True,
positional_parameters=positional_parameters,
named_parameters=named_parameters,
)
merged_cube.rio.write_crs(crs, inplace=True)
else:
# Example 1 & 2
dims_requiring_resolve = [
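The write_crs call is needed because the reduction can drop rioxarray's CRS information from the result. A standalone sketch of the same preserve-and-restore pattern (illustrative only, with an existing RasterCube named cube):

import numpy as np
import rioxarray  # noqa: F401  (registers the .rio accessor)

crs = cube.rio.crs  # remember the CRS before reducing
result = cube.reduce(np.nanmean, dim="time", keep_attrs=True)
if result.rio.crs is None and crs is not None:
    result = result.rio.write_crs(crs)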
2 changes: 1 addition & 1 deletion openeo_processes_dask/specs/openeo-processes
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -35,7 +35,7 @@ rioxarray = { version = ">=0.12.0,<1", optional = true }
openeo-pg-parser-networkx = { version = ">=2024.7", optional = true }
odc-geo = { version = ">=0.4.1,<1", optional = true }
stac_validator = { version = ">=3.3.1", optional = true }
stackstac = { version = ">=0.4.3", optional = true }
odc-stac = { version = ">=0.3.9", optional = true }
pystac_client = { version = ">=0.6.1", optional = true }
planetary_computer = { version = ">=0.5.1", optional = true }
scipy = "^1.11.3"
@@ -54,7 +54,7 @@ pre-commit = "^2.20.0"
pytest-cov = "^4.0.0"

[tool.poetry.extras]
implementations = ["geopandas", "xarray", "dask", "rasterio", "dask-geopandas", "rioxarray", "openeo-pg-parser-networkx", "odc-geo", "stackstac", "planetary_computer", "pystac_client", "stac_validator", "xvec", "joblib"]
implementations = ["geopandas", "xarray", "dask", "rasterio", "dask-geopandas", "rioxarray", "openeo-pg-parser-networkx", "odc-geo", "odc-stac", "planetary_computer", "pystac_client", "stac_validator", "xvec", "joblib"]
ml = ["xgboost"]

[tool.pytest.ini_options]
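With this change the optional implementations extra pulls in odc-stac instead of stackstac; assuming the usual PyPI install, the extra is enabled with something like pip install "openeo-processes-dask[implementations]" (or poetry install --extras implementations when working from the repository).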
2 changes: 1 addition & 1 deletion tests/test_indices.py
@@ -22,7 +22,7 @@ def test_ndvi(bounding_box):
).isel({"x": slice(0, 20), "y": slice(0, 20)})

# Test whether this works with different band names
input_cube = input_cube.rename({"band": "b"})
input_cube = input_cube.rename({"bands": "b"})

import dask.array as da
