Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use decorator for discovering input variables #24

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlpp_features/accessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""""""

import logging
import warnings
from dataclasses import dataclass, field
Expand Down
25 changes: 25 additions & 0 deletions mlpp_features/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import wraps
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, List

import xarray as xr

Expand Down Expand Up @@ -84,3 +85,27 @@ def inner(*args, **kwargs):
return out.chunk("auto").persist()

return inner


def inputs(*variables: str):
    """
    Declare the input variables required by a feature function.

    During discovery the feature function is called with a dictionary of empty
    datasets. The decorated function then raises a KeyError that lists all the
    required input variables instead of computing the feature. This decorator is
    currently needed for features that require multiple or no input data.

    Parameters
    ----------
    *variables : str
        Required inputs written as ``"<source>:<variable>"`` (for example
        ``"nwp:HSURF"``). Only the part following the first colon is reported.
        May be empty for features that need no input data.

    Returns
    -------
    A decorator that wraps the feature function.

    Raises
    ------
    ValueError
        When the decorator is created with a specification lacking a ``:``.
    KeyError
        By the wrapped function during discovery; its single argument is the
        list of required variable names.
    """
    # Resolve the names once, up front, so that a malformed specification is
    # reported where the decorator is applied rather than deep inside discovery.
    names = []
    for spec in variables:
        parts = spec.split(":")
        if len(parts) < 2:
            raise ValueError(
                f"Input '{spec}' must be given as '<source>:<variable>'"
            )
        names.append(parts[1])

    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            # The data dictionary is the first positional argument by
            # convention, but it may also be passed by keyword.
            data = args[0] if args else kwargs.get("data")
            if data is None:
                # Nothing to inspect: let the feature function handle the call.
                return fn(*args, **kwargs)

            # Discovery mode: every entry is present but empty.
            is_discovery = all(
                arr is not None and len(arr) == 0 for arr in data.values()
            )
            if is_discovery:
                # Hand out a copy so callers cannot mutate the shared list.
                raise KeyError(list(names))
            return fn(*args, **kwargs)

        return inner

    return decorator
1 change: 1 addition & 0 deletions mlpp_features/experimental.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module implements experimental features"""

from typing import List, Tuple

import numpy as np
Expand Down
26 changes: 18 additions & 8 deletions mlpp_features/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import xarray as xr
import numpy as np

from mlpp_features.decorators import cache, out_format
from mlpp_features.decorators import cache, inputs, out_format
from mlpp_features import calc

LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -476,20 +476,14 @@ def leadtime(data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwarg
return ds.astype("float32")


@inputs("nwp:HSURF", "terrain:DEM")
@out_format(units="m")
def model_height_difference(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
) -> xr.DataArray:
"""
Difference between model height and height from the more precise DEM in m
"""
# try/except block necessary to expose all the required input data
try:
data["nwp"]["HSURF"]
data["terrain"]["DEM"]
except KeyError:
raise KeyError(["HSURF", "DEM"])

hsurf_on_poi = data["nwp"].preproc.get("HSURF").preproc.interp(stations)
dem_on_poi = data["terrain"].preproc.get("DEM").preproc.interp(stations)

Expand Down Expand Up @@ -547,6 +541,7 @@ def northward_wind_ensctrl(
)


@inputs("nwp:air_temperature", "nwp:surface_air_pressure")
@out_format(units="degC")
def potential_temperature_ens(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down Expand Up @@ -649,6 +644,10 @@ def pressure_difference_GVE_GUT_ensctrl(
)


@inputs(
"nwp:dew_point_temperature",
"nwp:air_temperature",
)
@out_format(units="%")
def relative_humidity_ens(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down Expand Up @@ -889,6 +888,7 @@ def surface_air_pressure_ensctrl(
)


@inputs("terrain:SX_50M_RADIUS500", "nwp:eastward_wind", "nwp:northward_wind")
@out_format()
def sx_500m_ensavg(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -922,6 +922,7 @@ def sx_500m_ensavg(
return sx.astype("float32")


@inputs("terrain:SX_50M_RADIUS500", "nwp:eastward_wind", "nwp:northward_wind")
@out_format()
def sx_500m_ensctrl(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -955,6 +956,7 @@ def sx_500m_ensctrl(
return sx.astype("float32")


@inputs("nwp:air_temperature", "nwp:surface_air_pressure", "nwp:dew_point_temperature")
@out_format(units="g kg-1")
def water_vapor_mixing_ratio_ens(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down Expand Up @@ -993,6 +995,7 @@ def water_vapor_mixing_ratio_ensctrl(
)


@inputs("nwp:air_temperature", "nwp:dew_point_temperature")
@out_format(units="hPa")
def water_vapor_pressure_ens(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down Expand Up @@ -1068,6 +1071,7 @@ def water_vapor_saturation_pressure_ensctrl(
)


@inputs("nwp:eastward_wind", "nwp:northward_wind")
@out_format(units="degrees")
def wind_from_direction_ens(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down Expand Up @@ -1122,6 +1126,7 @@ def wind_from_direction_rank(
)


@inputs("nwp:eastward_wind", "nwp:northward_wind")
@out_format(units="m s-1")
def wind_speed_ens(data: Dict[str, xr.Dataset], stations, *args, **kwargs):
u = eastward_wind_ens(data, stations, *args, **kwargs)
Expand All @@ -1140,6 +1145,7 @@ def wind_speed_ensavg(
return uv.mean("realization").to_dataset().preproc.align_time(reftimes, leadtimes)


@inputs("nwp:eastward_wind", "nwp:northward_wind", "obs:wind_speed")
@out_format(units="m s-1")
def wind_speed_ensavg_error(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -1209,6 +1215,7 @@ def wind_speed_ensctrl_5hmean(
)


@inputs("nwp:eastward_wind", "nwp:northward_wind", "obs:wind_speed")
@out_format(units="m s-1")
def wind_speed_ensctrl_error(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -1303,6 +1310,7 @@ def wind_speed_of_gust_ensavg(
return ug.mean("realization").to_dataset().preproc.align_time(reftimes, leadtimes)


@inputs("nwp:wind_speed_of_gust", "obs:wind_speed_of_gust")
@out_format(units="m s-1")
def wind_speed_of_gust_ensavg_error(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -1372,6 +1380,7 @@ def wind_speed_of_gust_ensctrl_5hmean(
)


@inputs("nwp:wind_speed_of_gust", "obs:wind_speed_of_gust")
@out_format(units="m s-1")
def wind_speed_of_gust_ensctrl_error(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -1428,6 +1437,7 @@ def wind_speed_of_gust_rank(
)


@inputs("nwp:wind_speed_of_gust", "nwp:eastward_wind", "nwp:northward_wind")
@out_format()
def wind_gust_factor_ens(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down
13 changes: 8 additions & 5 deletions mlpp_features/obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import xarray as xr

from mlpp_features.decorators import out_format
from mlpp_features.decorators import inputs, out_format
from mlpp_features import calc

LOGGER = logging.getLogger(__name__)
Expand All @@ -28,6 +28,7 @@ def air_temperature(
)


@inputs("obs:air_temperature", "obs:dew_point_temperature")
@out_format(units="degC")
def dew_point_depression(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand All @@ -40,6 +41,7 @@ def dew_point_depression(
return (t - t_d).astype("float32")


@inputs("obs:air_temperature", "obs:relative_humidity")
@out_format(units="degC")
def dew_point_temperature(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -83,6 +85,7 @@ def relative_humidity(
)


@inputs("obs:air_temperature", "obs:relative_humidity", "obs:surface_air_pressure")
@out_format(units="g kg-1")
def water_vapor_mixing_ratio(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -131,6 +134,7 @@ def sin_wind_from_direction(
)


@inputs("obs:air_temperature", "obs:relative_humidity")
@out_format(units="hPa")
def water_vapor_pressure(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -199,6 +203,7 @@ def wind_speed_of_gust(
)


@inputs("obs:air_temperature", "obs:surface_air_pressure")
@out_format(units="degC")
def potential_temperature(
data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs
Expand Down Expand Up @@ -283,15 +288,14 @@ def distance_to_nearest_wind_speed_of_gust(
)


@inputs()
@out_format()
def weight_owner_id(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
) -> xr.Dataset:
"""
Weight the station owner.
"""
if data["obs"] is not None and len(data["obs"]) == 0:
raise KeyError([])
owner_id = stations.owner_id.to_xarray()
owner_weight = xr.full_like(owner_id, 1.0)
owner_weight = owner_weight.where(owner_id > 1, 2)
Expand All @@ -302,15 +306,14 @@ def weight_owner_id(
return ds.astype("float32")


@inputs()
@out_format()
def measurement_height(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
) -> xr.Dataset:
"""
Measurement height of the stations, derived from the station pole height.
"""
if data["obs"] is not None and len(data["obs"]) == 0:
raise KeyError([])
pole_height = stations.pole_height.to_xarray()
fillvalue_pole_height = pole_height.median()
LOGGER.debug(f"Fill value pole height: {fillvalue_pole_height:.1f}")
Expand Down
13 changes: 12 additions & 1 deletion mlpp_features/terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import xarray as xr

from mlpp_features.decorators import out_format
from mlpp_features.decorators import inputs, out_format
from mlpp_features import experimental as exp


Expand Down Expand Up @@ -41,6 +41,7 @@ def aspect_2000m(data: Dict[str, xr.Dataset], stations, *args, **kwargs) -> xr.D
)


@inputs("terrain:VALLEY_NORM_1000M_SMTHFACT0.5", "terrain:VALLEY_DIR_1000M_SMTHFACT0.5")
@out_format()
def cos_valley_index_1000m(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand All @@ -61,6 +62,7 @@ def cos_valley_index_1000m(
return cos_valley.preproc.interp(stations).astype("float32")


@inputs("terrain:VALLEY_NORM_2000M_SMTHFACT0.5", "terrain:VALLEY_DIR_2000M_SMTHFACT0.5")
@out_format()
def cos_valley_index_2000m(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand All @@ -81,6 +83,9 @@ def cos_valley_index_2000m(
return cos_valley.preproc.interp(stations).astype("float32")


@inputs(
"terrain:VALLEY_NORM_10000M_SMTHFACT0.5", "terrain:VALLEY_DIR_10000M_SMTHFACT0.5"
)
@out_format()
def cos_valley_index_10000m(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down Expand Up @@ -110,6 +115,7 @@ def distance_to_alpine_ridge(

**Experimental feature, use with caution!**
"""
# Raise KeyError during discovery
if all([len(ds) == 0 for ds in data.values()]):
raise KeyError()
alpine_crest_wgs84 = [
Expand Down Expand Up @@ -158,6 +164,7 @@ def elevation_50m(data: Dict[str, xr.Dataset], stations, *args, **kwargs) -> xr.
return data["terrain"].preproc.get("DEM").preproc.interp(stations).astype("float32")


@inputs("terrain:VALLEY_NORM_1000M_SMTHFACT0.5", "terrain:VALLEY_DIR_1000M_SMTHFACT0.5")
@out_format()
def sin_valley_index_1000m(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand All @@ -178,6 +185,7 @@ def sin_valley_index_1000m(
return sin_valley.preproc.interp(stations).astype("float32")


@inputs("terrain:VALLEY_NORM_2000M_SMTHFACT0.5", "terrain:VALLEY_DIR_2000M_SMTHFACT0.5")
@out_format()
def sin_valley_index_2000m(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand All @@ -198,6 +206,9 @@ def sin_valley_index_2000m(
return sin_valley.preproc.interp(stations).astype("float32")


@inputs(
"terrain:VALLEY_NORM_10000M_SMTHFACT0.5", "terrain:VALLEY_DIR_10000M_SMTHFACT0.5"
)
@out_format()
def sin_valley_index_10000m(
data: Dict[str, xr.Dataset], stations, *args, **kwargs
Expand Down
Loading
Loading