From a9a8ac803420995ac4312a4ecf723a158b939f33 Mon Sep 17 00:00:00 2001 From: Eric Engle Date: Mon, 23 Sep 2024 14:41:00 -0400 Subject: [PATCH] Clean up date time attribute setting. This commit cleans up code for setting refDate, leadTime, and duration Grib2Message attributes. The xarray backend now ignores those attrs when calling the grib2io.update_attrs() accessor method for DataArrays. This commit references NOAA-MDL/grib2io#158 --- src/grib2io/_grib2io.py | 2 +- src/grib2io/templates.py | 159 +++++++++++++++-------------- src/grib2io/utils/__init__.py | 36 +++---- src/grib2io/xarray_backend.py | 6 +- tests/test_grib2_datetime_attrs.py | 41 ++++++++ tests/test_xarray_update_attrs.py | 60 +---------- 6 files changed, 150 insertions(+), 154 deletions(-) create mode 100755 tests/test_grib2_datetime_attrs.py diff --git a/src/grib2io/_grib2io.py b/src/grib2io/_grib2io.py index e8d5047..0f97d29 100644 --- a/src/grib2io/_grib2io.py +++ b/src/grib2io/_grib2io.py @@ -847,7 +847,7 @@ def __str__(self): """ return (f'{self._msgnum}:d={self.refDate}:{self.shortName}:' f'{self.fullName} ({self.units}):{self.level}:' - f'{self.leadTime}') + f'{self.leadTime+self.duration}') def _generate_signature(self): diff --git a/src/grib2io/templates.py b/src/grib2io/templates.py index ab9fa6a..dc2b2b7 100644 --- a/src/grib2io/templates.py +++ b/src/grib2io/templates.py @@ -1,10 +1,11 @@ """GRIB2 section templates classes and metadata descriptor classes.""" from dataclasses import dataclass, field from collections import defaultdict -import datetime from typing import Union - -from numpy import timedelta64, datetime64 +import copy +import datetime +import numpy as np +import warnings from . import tables from . 
import utils @@ -25,6 +26,14 @@ 7:[], 8:[],} +_continuous_pdtns = [ + int(k) for k, v in tables.get_table("4.0").items() if "a point in time" in v +] +_timeinterval_pdtns = [ + int(k) + for k, v in tables.get_table("4.0").items() + if "continuous or non-continuous time interval" in v +] def _calculate_scale_factor(value: float): """ @@ -168,62 +177,74 @@ class Year: def __get__(self, obj, objtype=None): return obj.section1[5] def __set__(self, obj, value): - obj.section1[5] = value + rd = copy.copy(obj.section1[5:11]) + rd[0] = value # Test validity of datetime values - _ = datetime.datetime(*obj.section1[5:11]) + _ = datetime.datetime(*rd) + obj.section1[5] = value class Month: """Month of reference time""" def __get__(self, obj, objtype=None): return obj.section1[6] def __set__(self, obj, value): - obj.section1[6] = value + rd = copy.copy(obj.section1[5:11]) + rd[1] = value # Test validity of datetime values - _ = datetime.datetime(*obj.section1[5:11]) + _ = datetime.datetime(*rd) + obj.section1[6] = value class Day: """Day of reference time""" def __get__(self, obj, objtype=None): return obj.section1[7] def __set__(self, obj, value): - obj.section1[7] = value + rd = copy.copy(obj.section1[5:11]) + rd[2] = value # Test validity of datetime values - _ = datetime.datetime(*obj.section1[5:11]) + _ = datetime.datetime(*rd) + obj.section1[7] = value class Hour: """Hour of reference time""" def __get__(self, obj, objtype=None): return obj.section1[8] def __set__(self, obj, value): - obj.section1[8] = value + rd = copy.copy(obj.section1[5:11]) + rd[3] = value # Test validity of datetime values - _ = datetime.datetime(*obj.section1[5:11]) + _ = datetime.datetime(*rd) + obj.section1[8] = value class Minute: """Minute of reference time""" def __get__(self, obj, objtype=None): return obj.section1[9] def __set__(self, obj, value): - obj.section1[9] = value + rd = copy.copy(obj.section1[5:11]) + rd[4] = value # Test validity of datetime values - _ = 
datetime.datetime(*obj.section1[5:11]) + _ = datetime.datetime(*rd) + obj.section1[9] = value class Second: """Second of reference time""" def __get__(self, obj, objtype=None): return obj.section1[10] def __set__(self, obj, value): - obj.section1[10] = value + rd = copy.copy(obj.section1[5:11]) + rd[5] = value # Test validity of datetime values - _ = datetime.datetime(*obj.section1[5:11]) + _ = datetime.datetime(*rd) + obj.section1[10] = value class RefDate: """Reference Date. NOTE: This is a `datetime.datetime` object.""" def __get__(self, obj, objtype=None): return datetime.datetime(*obj.section1[5:11]) def __set__(self, obj, value): - if isinstance(value, datetime64): - timestamp = (value - datetime64("1970-01-01T00:00:00")) / timedelta64( + if isinstance(value, np.datetime64): + timestamp = (value - np.datetime64("1970-01-01T00:00:00")) / np.timedelta64( 1, "s" ) value = datetime.datetime.utcfromtimestamp(timestamp) @@ -234,6 +255,15 @@ def __set__(self, obj, value): obj.section1[8] = value.hour obj.section1[9] = value.minute obj.section1[10] = value.second + # IMPORTANT: Update validDate components when message is time interval + if obj.pdtn in _timeinterval_pdtns: + vd = value + obj.leadTime + obj.duration + obj.yearOfEndOfTimePeriod = vd.year + obj.monthOfEndOfTimePeriod = vd.month + obj.dayOfEndOfTimePeriod = vd.day + obj.hourOfEndOfTimePeriod = vd.hour + obj.minuteOfEndOfTimePeriod = vd.minute + obj.secondOfEndOfTimePeriod = vd.second else: msg = "Reference date must be a datetime.datetime or np.datetime64 object." raise TypeError(msg) @@ -1029,61 +1059,24 @@ def __set__(self, obj, value): class LeadTime: """Forecast Lead Time. 
NOTE: This is a `datetime.timedelta` object.""" - - def __get__(self, obj, objtype=None): - return utils.get_leadtime(obj.section1, obj.section4[1], obj.section4[2:]) - - def __set__(self, obj, value): - pdt = obj.section4[2:] - - # For the tables below, the key is the PDTN and the value is the slice - # of the PDT that contains the end date of the accumulation. - # This is only needed for PDTNs 8-12. - _key = { - 8: slice(15, 21), - 9: slice(22, 28), - 10: slice(16, 22), - 11: slice(18, 24), - 12: slice(17, 23), - } - - accumulation_offset = 0 - if obj.pdtn in _key: - accumulation_end_date = _key[obj.pdtn] - accumulation_key = accumulation_end_date.stop + 5 - accumulation_offset = int( - timedelta64(pdt[accumulation_key], "h") / timedelta64(1, "h") - ) - accumulation_offset = accumulation_offset / ( - tables.get_value_from_table( - pdt[accumulation_key - 1], "scale_time_hours" - ) - ) - - refdate = datetime.datetime(*obj.section1[5:11]) - pdt[_key[obj.pdtn]] = ( - datetime.timedelta(hours=accumulation_offset) + refdate - ).timetuple()[:6] - - # All messages need the leadTime value set, but for PDTNs 8-12, the - # leadTime value has to be set to the beginning of the accumulation - # period which is done here by subtracting the already calculated - # value accumulation_offset. 
- lead_time_index = 8 - if obj.pdtn == 48: - lead_time_index = 19 - - ivalue = int(timedelta64(value, "h") / timedelta64(1, "h")) - - pdt[lead_time_index] = ( - ivalue - / ( - tables.get_value_from_table( - pdt[lead_time_index - 1], "scale_time_hours" - ) - ) - - accumulation_offset - ) + _key = ValueOfForecastTime._key + def __get__(self, obj, objtype=None): + return utils.get_leadtime(obj.section4[1], obj.section4[2:]) + def __set__(self, obj, value): + if isinstance(value, np.timedelta64): + # Allows setting from xarray + value = datetime.timedelta( + seconds=int(value/np.timedelta64(1, 's'))) + obj.section4[self._key[obj.pdtn]+2] = int(value.total_seconds()/3600) + # IMPORTANT: Update validDate components when message is time interval + if obj.pdtn in _timeinterval_pdtns: + vd = obj.refDate + value + obj.duration + obj.yearOfEndOfTimePeriod = vd.year + obj.monthOfEndOfTimePeriod = vd.month + obj.dayOfEndOfTimePeriod = vd.day + obj.hourOfEndOfTimePeriod = vd.hour + obj.minuteOfEndOfTimePeriod = vd.minute + obj.secondOfEndOfTimePeriod = vd.second class FixedSfc1Info: """Information of the first fixed surface via [table 4.5](https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-5.shtml)""" @@ -1422,7 +1415,25 @@ class Duration: def __get__(self, obj, objtype=None): return utils.get_duration(obj.section4[1],obj.section4[2:]) def __set__(self, obj, value): - pass + if obj.pdtn in _continuous_pdtns: + pass + elif obj.pdtn in _timeinterval_pdtns: + _key = TimeRangeOfStatisticalProcess._key + if isinstance(value, np.timedelta64): + # Allows setting from xarray + value = datetime.timedelta( + seconds=int(value/np.timedelta64(1, 's'))) + obj.section4[_key[obj.pdtn]+2] = int(value.total_seconds()/3600) + # IMPORTANT: Update validDate components when message is time interval + if obj.pdtn in _timeinterval_pdtns: + # End of time period = refDate + new duration + leadTime + vd = obj.refDate + value + obj.leadTime + obj.yearOfEndOfTimePeriod = vd.year + 
obj.monthOfEndOfTimePeriod = vd.month + obj.dayOfEndOfTimePeriod = vd.day + obj.hourOfEndOfTimePeriod = vd.hour + obj.minuteOfEndOfTimePeriod = vd.minute + obj.secondOfEndOfTimePeriod = vd.second class ValidDate: """Valid Date of the forecast. NOTE: This is a `datetime.datetime` object.""" @@ -1435,7 +1446,7 @@ def __get__(self, obj, objtype=None): except(KeyError): return obj.refDate + obj.leadTime def __set__(self, obj, value): - pass + warnings.warn(f"validDate attribute is read-only.") class NumberOfTimeRanges: """Number of time ranges specifications describing the time intervals used to calculate the statistically-processed field""" diff --git a/src/grib2io/utils/__init__.py b/src/grib2io/utils/__init__.py index f3d1d35..fdb28b8 100644 --- a/src/grib2io/utils/__init__.py +++ b/src/grib2io/utils/__init__.py @@ -11,6 +11,7 @@ from numpy.typing import ArrayLike from .. import tables +from .. import templates def int2bin(i: int, nbits: int=8, output: Union[Type[str], Type[List]]=str): """ @@ -80,17 +81,15 @@ def ieee_int_to_float(i): return np.float32(f) -def get_leadtime(idsec: ArrayLike, pdtn: int, pdt: ArrayLike) -> datetime.timedelta: +def get_leadtime(pdtn: int, pdt: ArrayLike) -> datetime.timedelta: """ Compute lead time as a datetime.timedelta object. - Using information from GRIB2 Identification Section (Section 1), Product - Definition Template Number, and Product Definition Template (Section 4). + Using information from GRIB2 Product Definition Template + Number, and Product Definition Template (Section 4). Parameters ---------- - idsec - Sequence containing GRIB2 Identification Section (Section 1). pdtn GRIB2 Product Definition Template Number pdt @@ -101,15 +100,9 @@ def get_leadtime(idsec: ArrayLike, pdtn: int, pdt: ArrayLike) -> datetime.timede leadTime datetime.timedelta object representing the lead time of the GRIB2 message. 
""" - _key = {8:slice(15,21), 9:slice(22,28), 10:slice(16,22), 11:slice(18,24), 12:slice(17,23)} - refdate = datetime.datetime(*idsec[5:11]) - try: - return datetime.datetime(*pdt[_key[pdtn]])-refdate - except(KeyError): - if pdtn == 48: - return datetime.timedelta(hours=pdt[19]*(tables.get_value_from_table(pdt[18],'scale_time_hours'))) - else: - return datetime.timedelta(hours=pdt[8]*(tables.get_value_from_table(pdt[7],'scale_time_hours'))) + lt = tables.get_value_from_table(pdt[templates.UnitOfForecastTime._key[pdtn]], 'scale_time_hours') + lt *= pdt[templates.ValueOfForecastTime._key[pdtn]] + return datetime.timedelta(hours=int(lt)) def get_duration(pdtn: int, pdt: ArrayLike) -> datetime.timedelta: @@ -132,11 +125,16 @@ def get_duration(pdtn: int, pdt: ArrayLike) -> datetime.timedelta: datetime.timedelta object representing the time duration of the GRIB2 message. """ - _key = {8:25, 9:32, 10:26, 11:28, 12:27} - try: - return datetime.timedelta(hours=pdt[_key[pdtn]+1]*tables.get_value_from_table(pdt[_key[pdtn]],'scale_time_hours')) - except(KeyError): - return datetime.timedelta(hours=0) + if pdtn in templates._timeinterval_pdtns: + ntime = pdt[templates.NumberOfTimeRanges._key[pdtn]] + duration_unit = tables.get_value_from_table( + pdt[templates.UnitOfTimeRangeOfStatisticalProcess._key[pdtn]], + 'scale_time_hours') + d = ntime * duration_unit * pdt[ + templates.TimeRangeOfStatisticalProcess._key[pdtn]] + else: + d = 0 + return datetime.timedelta(hours=int(d)) def decode_wx_strings(lus: bytes) -> Dict[int, str]: diff --git a/src/grib2io/xarray_backend.py b/src/grib2io/xarray_backend.py index a50f1fc..51934e2 100755 --- a/src/grib2io/xarray_backend.py +++ b/src/grib2io/xarray_backend.py @@ -1110,7 +1110,7 @@ def update_attrs(self, **kwargs): coords_keys = [ k for k in da.coords.keys() - if (k in AVAILABLE_NON_GEO_DIMS) and (k in da.dims) + if k in AVAILABLE_NON_GEO_DIMS ] for grib2_name, value in kwargs.items(): @@ -1128,7 +1128,7 @@ def update_attrs(self, 
**kwargs): ) if grib2_name in coords_keys: warn( - f"Skipping attribute '{grib2_name}' because it is a dimension coordinate and cannot be updated." + f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values." ) continue if hasattr(newmsg, grib2_name): @@ -1190,4 +1190,4 @@ def subset(self, lats, lons) -> xr.DataArray: da.longitude <= newmsg.longitudeLastGridpoint ) - return da.where((mask_lon & mask_lat).compute(), drop=True) \ No newline at end of file + return da.where((mask_lon & mask_lat).compute(), drop=True) diff --git a/tests/test_grib2_datetime_attrs.py b/tests/test_grib2_datetime_attrs.py new file mode 100755 index 0000000..5b2bd1b --- /dev/null +++ b/tests/test_grib2_datetime_attrs.py @@ -0,0 +1,41 @@ +import pytest +import numpy as np +import datetime +import grib2io + +def test_datetime_attrs(request): + data = request.config.rootdir / 'tests' / 'data' / 'gfs_20221107' + with grib2io.open(data / 'gfs.t00z.pgrb2.1p00.f012_subset') as f: + msg = f['TMAX'][0] + + expected_refDate = datetime.datetime(2022, 11, 7, 0, 0) + expected_leadTime = datetime.timedelta(seconds=21600) + expected_duration = datetime.timedelta(seconds=21600) + expected_validDate = datetime.datetime(2022, 11, 7, 12, 0) + + assert msg.refDate == expected_refDate + assert msg.leadTime == expected_leadTime + assert msg.duration == expected_duration + assert msg.validDate == expected_validDate + + msg.leadTime = datetime.timedelta(hours=24) + + assert msg.refDate == datetime.datetime(2022, 11, 7, 0, 0) + assert msg.leadTime == datetime.timedelta(days=1) + assert msg.duration == datetime.timedelta(seconds=21600) + assert msg.validDate == datetime.datetime(2022, 11, 8, 6, 0) + + msg.duration = datetime.timedelta(hours=18) + + assert msg.refDate == datetime.datetime(2022, 11, 7, 0, 0) + assert msg.leadTime == datetime.timedelta(days=1) + assert msg.duration == datetime.timedelta(seconds=64800) + assert msg.validDate == 
datetime.datetime(2022, 11, 8, 18, 0) + + msg.leadTime = datetime.timedelta(seconds=21600) + msg.duration = datetime.timedelta(seconds=21600) + + assert msg.refDate == datetime.datetime(2022, 11, 7, 0, 0) + assert msg.leadTime == datetime.timedelta(seconds=21600) + assert msg.duration == datetime.timedelta(seconds=21600) + assert msg.validDate == datetime.datetime(2022, 11, 7, 12, 0) diff --git a/tests/test_xarray_update_attrs.py b/tests/test_xarray_update_attrs.py index 1c3556e..d5df023 100755 --- a/tests/test_xarray_update_attrs.py +++ b/tests/test_xarray_update_attrs.py @@ -131,15 +131,6 @@ "", # error_message id="warning_dims", ), - pytest.param( - { - "leadTime": 4, - }, # kwargs - set, # expected_type - set(), # expected - None, - id="warning_dims", - ), pytest.param( { "zebra": 4, @@ -149,42 +140,15 @@ "", # error_message id="warning_not_found", ), - pytest.param( - { - "zebra": 4, - }, # kwargs - set, # expected_type - set(), # expected - None, # error_message - id="warning_not_found", - ), pytest.param( { "refDate": datetime.datetime(2022, 11, 7, 0, 0), }, # kwargs - set, # expected_type - set(), # expected - None, + Warning, # expected_type + UserWarning, # expected + "", # error message id="refDate", ), - pytest.param( - { - "refDate": datetime.datetime(2021, 11, 7, 0, 0), - }, # kwargs - set, # expected_type - { - ( - "GRIB2IO_section1", - "[ 7 0 2 1 1 2022 11 7 0 0 0 0 1]", - ), - ( - "GRIB2IO_section1", - "[ 7 0 2 1 1 2021 11 7 0 0 0 0 1]", - ), - }, # expected - None, - id="refDate_year", - ), pytest.param( { "year": 2021, @@ -203,24 +167,6 @@ None, id="year", ), - pytest.param( - { - "refDate": datetime.datetime(2022, 10, 7, 0, 0), - }, # kwargs - set, # expected_type - { - ( - "GRIB2IO_section1", - "[ 7 0 2 1 1 2022 11 7 0 0 0 0 1]", - ), - ( - "GRIB2IO_section1", - "[ 7 0 2 1 1 2022 10 7 0 0 0 0 1]", - ), - }, # expected - None, - id="refDate_month", - ), pytest.param( { "month": 10,