From 794071d487c134d4955b8b79151f5e3e50824047 Mon Sep 17 00:00:00 2001 From: ned Date: Fri, 23 Feb 2024 09:47:02 +0100 Subject: [PATCH] Use decorator for discovering input variables --- mlpp_features/accessors.py | 1 + mlpp_features/decorators.py | 25 +++++++++++++++++++++++ mlpp_features/experimental.py | 1 + mlpp_features/nwp.py | 26 ++++++++++++++++-------- mlpp_features/obs.py | 13 +++++++----- mlpp_features/terrain.py | 13 +++++++++++- mlpp_features/time.py | 20 +++++++------------ tests/test_discover.py | 37 ++++++++++++++++++++++++++++++++++- tests/test_features.py | 2 +- 9 files changed, 109 insertions(+), 29 deletions(-) diff --git a/mlpp_features/accessors.py b/mlpp_features/accessors.py index 70e6a14..ba1244e 100644 --- a/mlpp_features/accessors.py +++ b/mlpp_features/accessors.py @@ -1,4 +1,5 @@ """""" + import logging import warnings from dataclasses import dataclass, field diff --git a/mlpp_features/decorators.py b/mlpp_features/decorators.py index 2737a4b..c930a10 100644 --- a/mlpp_features/decorators.py +++ b/mlpp_features/decorators.py @@ -4,6 +4,7 @@ from functools import wraps from pathlib import Path from tempfile import TemporaryDirectory +from typing import Dict, List import xarray as xr @@ -84,3 +85,27 @@ def inner(*args, **kwargs): return out.chunk("auto").persist() return inner + + +def inputs(*vars: str): + """ + Raise a KeyError to expose all the required input data during discovery, which + is done by calling the feature function with a dictionary of empty datasets. + This decorator is currently needed for features that require multiple or no input data. + """ + + def decorator(fn): + @wraps(fn) + def inner(*args, **kwargs): + data = args[0] + is_discovery = all( + arr is not None and len(arr) == 0 for arr in data.values() + ) + if is_discovery: + raise KeyError([var.split(":")[1] for var in vars]) + else: + return fn(*args, **kwargs) + + return inner + + return decorator diff --git a/mlpp_features/experimental.py b/mlpp_features/experimental.py index ef2c8fb..5b0d50e 100644 --- a/mlpp_features/experimental.py +++ b/mlpp_features/experimental.py @@ -1,4 +1,5 @@ """This module implements experimental features""" + from typing import List, Tuple import numpy as np diff --git a/mlpp_features/nwp.py b/mlpp_features/nwp.py index 757ad80..e6a8e9d 100644 --- a/mlpp_features/nwp.py +++ b/mlpp_features/nwp.py @@ -4,7 +4,7 @@ import xarray as xr import numpy as np -from mlpp_features.decorators import cache, out_format +from mlpp_features.decorators import cache, inputs, out_format from mlpp_features import calc LOGGER = logging.getLogger(__name__) @@ -476,6 +476,7 @@ def leadtime(data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwarg return ds.astype("float32") +@inputs("nwp:HSURF", "terrain:DEM") @out_format(units="m") def model_height_difference( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -483,13 +484,6 @@ def model_height_difference( """ Difference between model height and height from the more precise DEM in m """ - # try/except block necessary to expose all the required input data - try: - data["nwp"]["HSURF"] - data["terrain"]["DEM"] - except KeyError: - raise KeyError(["HSURF", "DEM"]) - hsurf_on_poi = data["nwp"].preproc.get("HSURF").preproc.interp(stations) dem_on_poi = data["terrain"].preproc.get("DEM").preproc.interp(stations) @@ -547,6 +541,7 @@ def northward_wind_ensctrl( ) +@inputs("nwp:air_temperature", "nwp:surface_air_pressure") @out_format(units="degC") def potential_temperature_ens( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -649,6 +644,10 @@ def pressure_difference_GVE_GUT_ensctrl( ) +@inputs( + "nwp:dew_point_temperature", + "nwp:air_temperature", +) @out_format(units="%") def relative_humidity_ens( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -889,6 +888,7 @@ def surface_air_pressure_ensctrl( ) +@inputs("terrain:SX_50M_RADIUS500", "nwp:eastward_wind", "nwp:northward_wind") @out_format() def sx_500m_ensavg( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -922,6 +922,7 @@ def sx_500m_ensavg( return sx.astype("float32") +@inputs("terrain:SX_50M_RADIUS500", "nwp:eastward_wind", "nwp:northward_wind") @out_format() def sx_500m_ensctrl( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -955,6 +956,7 @@ def sx_500m_ensctrl( return sx.astype("float32") +@inputs("nwp:air_temperature", "nwp:surface_air_pressure", "nwp:dew_point_temperature") @out_format(units="g kg-1") def water_vapor_mixing_ratio_ens( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -993,6 +995,7 @@ def water_vapor_mixing_ratio_ensctrl( ) +@inputs("nwp:air_temperature", "nwp:dew_point_temperature") @out_format(units="hPa") def water_vapor_pressure_ens( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -1068,6 +1071,7 @@ def water_vapor_saturation_pressure_ensctrl( ) +@inputs("nwp:eastward_wind", "nwp:northward_wind") @out_format(units="degrees") def wind_from_direction_ens( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -1122,6 +1126,7 @@ def wind_from_direction_rank( ) +@inputs("nwp:eastward_wind", "nwp:northward_wind") @out_format(units="m s-1") def wind_speed_ens(data: Dict[str, xr.Dataset], stations, *args, **kwargs): u = eastward_wind_ens(data, stations, *args, **kwargs) @@ -1140,6 +1145,7 @@ def wind_speed_ensavg( return uv.mean("realization").to_dataset().preproc.align_time(reftimes, leadtimes) +@inputs("nwp:eastward_wind", "nwp:northward_wind", "obs:wind_speed") @out_format(units="m s-1") def wind_speed_ensavg_error( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -1209,6 +1215,7 @@ def wind_speed_ensctrl_5hmean( ) +@inputs("nwp:eastward_wind", "nwp:northward_wind", "obs:wind_speed") @out_format(units="m s-1") def wind_speed_ensctrl_error( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -1303,6 +1310,7 @@ def wind_speed_of_gust_ensavg( return ug.mean("realization").to_dataset().preproc.align_time(reftimes, leadtimes) +@inputs("nwp:wind_speed_of_gust", "obs:wind_speed_of_gust") @out_format(units="m s-1") def wind_speed_of_gust_ensavg_error( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -1372,6 +1380,7 @@ def wind_speed_of_gust_ensctrl_5hmean( ) +@inputs("nwp:wind_speed_of_gust", "obs:wind_speed_of_gust") @out_format(units="m s-1") def wind_speed_of_gust_ensctrl_error( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -1428,6 +1437,7 @@ def wind_speed_of_gust_rank( ) +@inputs("nwp:wind_speed_of_gust", "nwp:eastward_wind", "nwp:northward_wind") @out_format() def wind_gust_factor_ens( data: Dict[str, xr.Dataset], stations, *args, **kwargs diff --git a/mlpp_features/obs.py b/mlpp_features/obs.py index 36cf879..5726a32 100644 --- a/mlpp_features/obs.py +++ b/mlpp_features/obs.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from mlpp_features.decorators import out_format +from mlpp_features.decorators import inputs, out_format from mlpp_features import calc LOGGER = logging.getLogger(__name__) @@ -28,6 +28,7 @@ def air_temperature( ) +@inputs("obs:air_temperature", "obs:dew_point_temperature") @out_format(units="degC") def dew_point_depression( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -40,6 +41,7 @@ def dew_point_depression( return (t - t_d).astype("float32") +@inputs("obs:air_temperature", "obs:relative_humidity") @out_format(units="degC") def dew_point_temperature( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -83,6 +85,7 @@ def relative_humidity( ) +@inputs("obs:air_temperature", "obs:relative_humidity", "obs:surface_air_pressure") @out_format(units="g kg-1") def water_vapor_mixing_ratio( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -131,6 +134,7 @@ def sin_wind_from_direction( ) +@inputs("obs:air_temperature", "obs:relative_humidity") @out_format(units="hPa") def water_vapor_pressure( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -199,6 +203,7 @@ def wind_speed_of_gust( ) +@inputs("obs:air_temperature", "obs:surface_air_pressure") @out_format(units="degC") def potential_temperature( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -283,6 +288,7 @@ def distance_to_nearest_wind_speed_of_gust( ) +@inputs() @out_format() def weight_owner_id( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -290,8 +296,6 @@ def weight_owner_id( """ Weight the station owner. """ - if data["obs"] is not None and len(data["obs"]) == 0: - raise KeyError([]) owner_id = stations.owner_id.to_xarray() owner_weight = xr.full_like(owner_id, 1.0) owner_weight = owner_weight.where(owner_id > 1, 2) @@ -302,6 +306,7 @@ def weight_owner_id( return ds.astype("float32") +@inputs() @out_format() def measurement_height( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -309,8 +314,6 @@ def measurement_height( """ Weight the station owner. """ - if data["obs"] is not None and len(data["obs"]) == 0: - raise KeyError([]) pole_height = stations.pole_height.to_xarray() fillvalue_pole_height = pole_height.median() LOGGER.debug(f"Fill value pole height: {fillvalue_pole_height:.1f}") diff --git a/mlpp_features/terrain.py b/mlpp_features/terrain.py index dcc86ee..456b397 100644 --- a/mlpp_features/terrain.py +++ b/mlpp_features/terrain.py @@ -4,7 +4,7 @@ import numpy as np import xarray as xr -from mlpp_features.decorators import out_format +from mlpp_features.decorators import inputs, out_format from mlpp_features import experimental as exp @@ -41,6 +41,7 @@ def aspect_2000m(data: Dict[str, xr.Dataset], stations, *args, **kwargs) -> xr.D ) +@inputs("terrain:VALLEY_NORM_1000M_SMTHFACT0.5", "terrain:VALLEY_DIR_1000M_SMTHFACT0.5") @out_format() def cos_valley_index_1000m( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -61,6 +62,7 @@ def cos_valley_index_1000m( return cos_valley.preproc.interp(stations).astype("float32") +@inputs("terrain:VALLEY_NORM_2000M_SMTHFACT0.5", "terrain:VALLEY_DIR_2000M_SMTHFACT0.5") @out_format() def cos_valley_index_2000m( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -81,6 +83,9 @@ def cos_valley_index_2000m( return cos_valley.preproc.interp(stations).astype("float32") +@inputs( + "terrain:VALLEY_NORM_10000M_SMTHFACT0.5", "terrain:VALLEY_DIR_10000M_SMTHFACT0.5" +) @out_format() def cos_valley_index_10000m( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -110,6 +115,7 @@ def distance_to_alpine_ridge( **Experimental feature, use with caution!** """ + # raise KeyError during discover if all([len(ds) == 0 for ds in data.values()]): raise KeyError() alpine_crest_wgs84 = [ @@ -158,6 +164,7 @@ def elevation_50m(data: Dict[str, xr.Dataset], stations, *args, **kwargs) -> xr. return data["terrain"].preproc.get("DEM").preproc.interp(stations).astype("float32") +@inputs("terrain:VALLEY_NORM_1000M_SMTHFACT0.5", "terrain:VALLEY_DIR_1000M_SMTHFACT0.5") @out_format() def sin_valley_index_1000m( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -178,6 +185,7 @@ def sin_valley_index_1000m( return sin_valley.preproc.interp(stations).astype("float32") +@inputs("terrain:VALLEY_NORM_2000M_SMTHFACT0.5", "terrain:VALLEY_DIR_2000M_SMTHFACT0.5") @out_format() def sin_valley_index_2000m( data: Dict[str, xr.Dataset], stations, *args, **kwargs @@ -198,6 +206,9 @@ def sin_valley_index_2000m( return sin_valley.preproc.interp(stations).astype("float32") +@inputs( + "terrain:VALLEY_NORM_10000M_SMTHFACT0.5", "terrain:VALLEY_DIR_10000M_SMTHFACT0.5" +) @out_format() def sin_valley_index_10000m( data: Dict[str, xr.Dataset], stations, *args, **kwargs diff --git a/mlpp_features/time.py b/mlpp_features/time.py index 0d6f097..a349493 100644 --- a/mlpp_features/time.py +++ b/mlpp_features/time.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from mlpp_features.decorators import out_format +from mlpp_features.decorators import inputs, out_format LOGGER = logging.getLogger(__name__) @@ -25,6 +25,7 @@ def _make_time_dataset(reftimes, leadtimes): return ds.assign_coords(time=ds.forecast_reference_time + ds.t) +@inputs() @out_format() def cos_dayofyear( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -32,8 +33,6 @@ def cos_dayofyear( """ Compute the cosine of day-of-year """ - if data["nwp"] is not None and len(data["nwp"]) == 0: - raise KeyError([]) ds = xr.Dataset( None, coords={ @@ -48,6 +47,7 @@ def cos_dayofyear( return ds.pipe(np.cos).astype("float32") +@inputs() @out_format() def cos_hourofday( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -55,8 +55,6 @@ def cos_hourofday( """ Compute the cosine of hour-of-day """ - if data["nwp"] is not None and len(data["nwp"]) == 0: - raise KeyError([]) ds = xr.Dataset( None, coords={ @@ -69,6 +67,7 @@ def cos_hourofday( return ds.pipe(np.cos).astype("float32") +@inputs() @out_format() def sin_dayofyear( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -76,8 +75,6 @@ def sin_dayofyear( """ Compute the sine of day-of-year """ - if data["nwp"] is not None and len(data["nwp"]) == 0: - raise KeyError([]) ds = _make_time_dataset(reftimes, leadtimes) ds["sin_dayofyear"] = ( (ds["time.dayofyear"] + ds["time.hour"] / 24) * 2 * np.pi / 366 @@ -85,6 +82,7 @@ def sin_dayofyear( return ds.pipe(np.sin).astype("float32") +@inputs() @out_format() def sin_hourofday( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -92,13 +90,12 @@ def sin_hourofday( """ Compute the sine of hour-of-day """ - if data["nwp"] is not None and len(data["nwp"]) == 0: - raise KeyError([]) ds = _make_time_dataset(reftimes, leadtimes) ds["sin_hourofday"] = ds["time.hour"] * 2 * np.pi / 24 return ds.pipe(np.sin).astype("float32") +@inputs() @out_format() def weight_sample_age( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -106,8 +103,6 @@ def weight_sample_age( """ Compute the inverse of the sample age in years """ - if data["nwp"] is not None and len(data["nwp"]) == 0: - raise KeyError([]) ds = _make_time_dataset(reftimes, leadtimes) this_year = datetime.today().year * 365 + datetime.today().timetuple().tm_yday ds["weight_sample_age"] = ( @@ -116,6 +111,7 @@ def weight_sample_age( return ds.astype("float32") +@inputs() @out_format() def weight_leadtime( data: Dict[str, xr.Dataset], stations, reftimes, leadtimes, **kwargs @@ -123,8 +119,6 @@ def weight_leadtime( """ Weight the lead time. """ - if data["nwp"] is not None and len(data["nwp"]) == 0: - raise KeyError([]) weight_leadtime = 1.5 / (1 + leadtimes / pd.Timedelta(hours=24)) ds = xr.Dataset( {"weight_leadtime": ("t", weight_leadtime)}, diff --git a/tests/test_discover.py b/tests/test_discover.py index 872819a..1e689bb 100644 --- a/tests/test_discover.py +++ b/tests/test_discover.py @@ -3,9 +3,16 @@ import mlpp_features.discover as di -@pytest.mark.skip(reason="2022-11-08 ned: currently broken, see #16") def test_discover_inputs(): + vars = di.discover_inputs("model_height_difference") + assert vars == ["DEM", "HSURF"] + + vars = di.discover_inputs( + ["air_temperature_ensctrl", "water_vapor_mixing_ratio_ensavg"] + ) + assert vars == ["air_temperature", "dew_point_temperature", "surface_air_pressure"] + vars = di.discover_inputs("wind_speed_ensavg") assert vars == ["eastward_wind", "northward_wind"] @@ -17,3 +24,31 @@ def test_discover_inputs(): vars = di.discover_inputs(["wind_speed", "nearest_wind_speed_of_gust"]) assert vars == ["wind_speed", "wind_speed_of_gust"] # observations + + vars = di.discover_inputs(["water_vapor_mixing_ratio"]) + assert vars == ["air_temperature", "relative_humidity", "surface_air_pressure"] + + vars = di.discover_inputs("weight_owner_id") + assert vars == [] + + vars = di.discover_inputs("cos_valley_index_1000m") + assert vars == ["VALLEY_DIR_1000M_SMTHFACT0.5", "VALLEY_NORM_1000M_SMTHFACT0.5"] + + vars = di.discover_inputs( + [ + "air_temperature_ensavg", + "cos_dayofyear", + "cos_hourofday", + "boundary_layer_height_ensavg", + "pressure_difference_BAS_LUG_ensavg", + "pressure_difference_GVE_GUT_ensavg", + "wind_speed_of_gust_ensavg", + "wind_speed_of_gust_ensstd", + ] + ) + assert vars == [ + "air_temperature", + "atmosphere_boundary_layer_thickness", + "surface_air_pressure", + "wind_speed_of_gust", + ] diff --git a/tests/test_features.py b/tests/test_features.py index c72cddc..244ffd1 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -28,7 +28,7 @@ def _make_datasets( @pytest.mark.parametrize("pipeline,", pipelines) def test_raise_keyerror(self, pipeline): - """Test that all features raise a KeyError when called with empy inputs""" + """Test that all features raise a KeyError when called with empty inputs""" empty_data = { "nwp": xr.Dataset(), "terrain": xr.Dataset(),