Skip to content

Commit

Permalink
feat(era5_extract): dataset filename as pipeline parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
yannforget committed Dec 20, 2024
1 parent 53f7c55 commit cc2fb08
Showing 1 changed file with 32 additions and 8 deletions.
40 changes: 32 additions & 8 deletions era5_extract/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
pipeline,
workspace,
)
from openhexa.sdk.datasets import DatasetFile
from openhexa.toolbox.era5.cds import CDS, VARIABLES


Expand Down Expand Up @@ -46,6 +47,14 @@
help="Input dataset containing boundaries geometries",
required=True,
)
@parameter(
"boundaries_file",
name="Boundaries filename in dataset",
type=str,
help="Filename of the boundaries file to use in the boundaries dataset",
required=False,
default="district.parquet",
)
@parameter(
"variable",
name="Variable",
Expand Down Expand Up @@ -82,13 +91,14 @@ def era5_extract(
boundaries_dataset: Dataset,
variable: str,
output_dir: str,
boundaries_file: str | None = None,
time: list[int] | None = None,
) -> None:
"""Download ERA5 products from the Climate Data Store."""
cds = CDS(key=cds_connection.key)
current_run.log_info("Successfully connected to the Climate Data Store")

boundaries = read_boundaries(boundaries_dataset)
boundaries = read_boundaries(boundaries_dataset, filename=boundaries_file)
bounds = get_bounds(boundaries)
current_run.log_info(f"Using area of interest: {bounds}")

Expand Down Expand Up @@ -122,13 +132,18 @@ def era5_extract(
)


def read_boundaries(boundaries_dataset: Dataset) -> gpd.GeoDataFrame:
def read_boundaries(
boundaries_dataset: Dataset, filename: str | None = None
) -> gpd.GeoDataFrame:
"""Read boundaries geographic file from input dataset.
Parameters
----------
boundaries_dataset : Dataset
Input dataset containing a "*district*.parquet" geoparquet file
filename : str
Filename of the boundaries file to read if there are several.
If set to None, the 1st parquet file found will be loaded.
Return
------
Expand All @@ -140,16 +155,25 @@ def read_boundaries(boundaries_dataset: Dataset) -> gpd.GeoDataFrame:
FileNotFoundError
If the boundaries file is not found
"""
boundaries: gpd.GeoDataFrame = None
ds = boundaries_dataset.latest_version

ds_file: DatasetFile | None = None
for f in ds.files:
if f.filename.endswith(".parquet") and "district" in f.filename:
boundaries = gpd.read_parquet(BytesIO(f.read()))
if boundaries is None:
msg = "Boundaries file not found"
if f.filename == filename:
if f.filename.endswith(".parquet"):
ds_file = f
if f.filename.endswith(".geojson") or f.filename.endswith(".gpkg"):
ds_file = f

if ds_file is None:
msg = f"File {filename} not found in dataset {ds.name}"
current_run.log_error(msg)
raise FileNotFoundError(msg)
return boundaries

if ds_file.filename.endswith(".parquet"):
return gpd.read_parquet(BytesIO(ds_file.read()))

return gpd.read_file(BytesIO(ds_file.read()))


def get_bounds(boundaries: gpd.GeoDataFrame) -> tuple[int]:
Expand Down

0 comments on commit cc2fb08

Please sign in to comment.