Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up date time attribute setting. #159

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/grib2io/_grib2io.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ def __str__(self):
"""
return (f'{self._msgnum}:d={self.refDate}:{self.shortName}:'
f'{self.fullName} ({self.units}):{self.level}:'
f'{self.leadTime}')
f'{self.leadTime+self.duration}')


def _generate_signature(self):
Expand Down
159 changes: 85 additions & 74 deletions src/grib2io/templates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""GRIB2 section templates classes and metadata descriptor classes."""
from dataclasses import dataclass, field
from collections import defaultdict
import datetime
from typing import Union

from numpy import timedelta64, datetime64
import copy
import datetime
import numpy as np
import warnings

from . import tables
from . import utils
Expand All @@ -25,6 +26,14 @@
7:[],
8:[],}

_continuous_pdtns = [
int(k) for k, v in tables.get_table("4.0").items() if "a point in time" in v
]
_timeinterval_pdtns = [
int(k)
for k, v in tables.get_table("4.0").items()
if "continuous or non-continuous time interval" in v
]

def _calculate_scale_factor(value: float):
"""
Expand Down Expand Up @@ -168,62 +177,74 @@ class Year:
def __get__(self, obj, objtype=None):
return obj.section1[5]
def __set__(self, obj, value):
obj.section1[5] = value
rd = copy.copy(obj.section1[5:11])
rd[0] = value
# Test validity of datetime values
_ = datetime.datetime(*obj.section1[5:11])
_ = datetime.datetime(*rd)
obj.section1[5] = value

class Month:
"""Month of reference time"""
def __get__(self, obj, objtype=None):
return obj.section1[6]
def __set__(self, obj, value):
obj.section1[6] = value
rd = copy.copy(obj.section1[5:11])
rd[1] = value
# Test validity of datetime values
_ = datetime.datetime(*obj.section1[5:11])
_ = datetime.datetime(*rd)
obj.section1[6] = value

class Day:
"""Day of reference time"""
def __get__(self, obj, objtype=None):
return obj.section1[7]
def __set__(self, obj, value):
obj.section1[7] = value
rd = copy.copy(obj.section1[5:11])
rd[2] = value
# Test validity of datetime values
_ = datetime.datetime(*obj.section1[5:11])
_ = datetime.datetime(*rd)
#obj.section1[7] = value

class Hour:
"""Hour of reference time"""
def __get__(self, obj, objtype=None):
return obj.section1[8]
def __set__(self, obj, value):
obj.section1[8] = value
rd = copy.copy(obj.section1[5:11])
rd[3] = value
# Test validity of datetime values
_ = datetime.datetime(*obj.section1[5:11])
_ = datetime.datetime(*rd)
obj.section1[8] = value

class Minute:
"""Minute of reference time"""
def __get__(self, obj, objtype=None):
return obj.section1[9]
def __set__(self, obj, value):
obj.section1[9] = value
rd = copy.copy(obj.section1[5:11])
rd[4] = value
# Test validity of datetime values
_ = datetime.datetime(*obj.section1[5:11])
_ = datetime.datetime(*rd)
obj.section1[9] = value

class Second:
"""Second of reference time"""
def __get__(self, obj, objtype=None):
return obj.section1[10]
def __set__(self, obj, value):
obj.section1[10] = value
rd = copy.copy(obj.section1[5:11])
rd[5] = value
# Test validity of datetime values
_ = datetime.datetime(*obj.section1[5:11])
_ = datetime.datetime(*rd)
obj.section1[10] = value

class RefDate:
"""Reference Date. NOTE: This is a `datetime.datetime` object."""
def __get__(self, obj, objtype=None):
return datetime.datetime(*obj.section1[5:11])
def __set__(self, obj, value):
if isinstance(value, datetime64):
timestamp = (value - datetime64("1970-01-01T00:00:00")) / timedelta64(
if isinstance(value, np.datetime64):
timestamp = (value - np.datetime64("1970-01-01T00:00:00")) / np.timedelta64(
1, "s"
)
value = datetime.datetime.utcfromtimestamp(timestamp)
Expand All @@ -234,6 +255,15 @@ def __set__(self, obj, value):
obj.section1[8] = value.hour
obj.section1[9] = value.minute
obj.section1[10] = value.second
# IMPORTANT: Update validDate components when message is time interval
if obj.pdtn in _timeinterval_pdtns:
vd = value + obj.leadTime + obj.duration
obj.yearOfEndOfTimePeriod = vd.year
obj.monthOfEndOfTimePeriod = vd.month
obj.dayOfEndOfTimePeriod = vd.day
obj.hourOfEndOfTimePeriod = vd.hour
obj.minuteOfEndOfTimePeriod = vd.minute
obj.secondOfEndOfTimePeriod = vd.second
else:
msg = "Reference date must be a datetime.datetime or np.datetime64 object."
raise TypeError(msg)
Expand Down Expand Up @@ -1029,61 +1059,24 @@ def __set__(self, obj, value):

class LeadTime:
"""Forecast Lead Time. NOTE: This is a `datetime.timedelta` object."""

def __get__(self, obj, objtype=None):
return utils.get_leadtime(obj.section1, obj.section4[1], obj.section4[2:])

def __set__(self, obj, value):
pdt = obj.section4[2:]

# For the tables below, the key is the PDTN and the value is the slice
# of the PDT that contains the end date of the accumulation.
# This is only needed for PDTNs 8-12.
_key = {
8: slice(15, 21),
9: slice(22, 28),
10: slice(16, 22),
11: slice(18, 24),
12: slice(17, 23),
}

accumulation_offset = 0
if obj.pdtn in _key:
accumulation_end_date = _key[obj.pdtn]
accumulation_key = accumulation_end_date.stop + 5
accumulation_offset = int(
timedelta64(pdt[accumulation_key], "h") / timedelta64(1, "h")
)
accumulation_offset = accumulation_offset / (
tables.get_value_from_table(
pdt[accumulation_key - 1], "scale_time_hours"
)
)

refdate = datetime.datetime(*obj.section1[5:11])
pdt[_key[obj.pdtn]] = (
datetime.timedelta(hours=accumulation_offset) + refdate
).timetuple()[:6]

# All messages need the leadTime value set, but for PDTNs 8-12, the
# leadTime value has to be set to the beginning of the accumulation
# period which is done here by subtracting the already calculated
# value accumulation_offset.
lead_time_index = 8
if obj.pdtn == 48:
lead_time_index = 19

ivalue = int(timedelta64(value, "h") / timedelta64(1, "h"))

pdt[lead_time_index] = (
ivalue
/ (
tables.get_value_from_table(
pdt[lead_time_index - 1], "scale_time_hours"
)
)
- accumulation_offset
)
_key = ValueOfForecastTime._key
def __get__(self, obj, objtype=None):
return utils.get_leadtime(obj.section4[1], obj.section4[2:])
def __set__(self, obj, value):
if isinstance(value, np.timedelta64):
# Allows setting from xarray
value = datetime.timedelta(
seconds=int(value/np.timedelta64(1, 's')))
obj.section4[self._key[obj.pdtn]+2] = int(value.total_seconds()/3600)
# IMPORTANT: Update validDate components when message is time interval
if obj.pdtn in _timeinterval_pdtns:
vd = obj.refDate + value + obj.duration
obj.yearOfEndOfTimePeriod = vd.year
obj.monthOfEndOfTimePeriod = vd.month
obj.dayOfEndOfTimePeriod = vd.day
obj.hourOfEndOfTimePeriod = vd.hour
obj.minuteOfEndOfTimePeriod = vd.minute
obj.secondOfEndOfTimePeriod = vd.second

class FixedSfc1Info:
"""Information of the first fixed surface via [table 4.5](https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-5.shtml)"""
Expand Down Expand Up @@ -1422,7 +1415,25 @@ class Duration:
def __get__(self, obj, objtype=None):
return utils.get_duration(obj.section4[1],obj.section4[2:])
def __set__(self, obj, value):
pass
if obj.pdtn in _continuous_pdtns:
pass
elif obj.pdtn in _timeinterval_pdtns:
_key = TimeRangeOfStatisticalProcess._key
if isinstance(value, np.timedelta64):
# Allows setting from xarray
value = datetime.timedelta(
seconds=int(value/np.timedelta64(1, 's')))
obj.section4[_key[obj.pdtn]+2] = int(value.total_seconds()/3600)
# IMPORTANT: Update validDate components when message is time interval
if obj.pdtn in _timeinterval_pdtns:
print(obj.refDate, value, obj.leadTime)
vd = obj.refDate + value + obj.leadTime
obj.yearOfEndOfTimePeriod = vd.year
obj.monthOfEndOfTimePeriod = vd.month
obj.dayOfEndOfTimePeriod = vd.day
obj.hourOfEndOfTimePeriod = vd.hour
obj.minuteOfEndOfTimePeriod = vd.minute
obj.secondOfEndOfTimePeriod = vd.second

class ValidDate:
"""Valid Date of the forecast. NOTE: This is a `datetime.datetime` object."""
Expand All @@ -1435,7 +1446,7 @@ def __get__(self, obj, objtype=None):
except(KeyError):
return obj.refDate + obj.leadTime
def __set__(self, obj, value):
pass
warnings.warn(f"validDate attribute is read-only.")

class NumberOfTimeRanges:
"""Number of time ranges specifications describing the time intervals used to calculate the statistically-processed field"""
Expand Down
36 changes: 17 additions & 19 deletions src/grib2io/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from numpy.typing import ArrayLike

from .. import tables
from .. import templates

def int2bin(i: int, nbits: int=8, output: Union[Type[str], Type[List]]=str):
"""
Expand Down Expand Up @@ -80,17 +81,15 @@ def ieee_int_to_float(i):
return np.float32(f)


def get_leadtime(idsec: ArrayLike, pdtn: int, pdt: ArrayLike) -> datetime.timedelta:
def get_leadtime(pdtn: int, pdt: ArrayLike) -> datetime.timedelta:
"""
Compute lead time as a datetime.timedelta object.

Using information from GRIB2 Identification Section (Section 1), Product
Definition Template Number, and Product Definition Template (Section 4).
Using information from GRIB2 Product Definition Template
Number, and Product Definition Template (Section 4).

Parameters
----------
idsec
Sequence containing GRIB2 Identification Section (Section 1).
pdtn
GRIB2 Product Definition Template Number
pdt
Expand All @@ -101,15 +100,9 @@ def get_leadtime(idsec: ArrayLike, pdtn: int, pdt: ArrayLike) -> datetime.timede
leadTime
datetime.timedelta object representing the lead time of the GRIB2 message.
"""
_key = {8:slice(15,21), 9:slice(22,28), 10:slice(16,22), 11:slice(18,24), 12:slice(17,23)}
refdate = datetime.datetime(*idsec[5:11])
try:
return datetime.datetime(*pdt[_key[pdtn]])-refdate
except(KeyError):
if pdtn == 48:
return datetime.timedelta(hours=pdt[19]*(tables.get_value_from_table(pdt[18],'scale_time_hours')))
else:
return datetime.timedelta(hours=pdt[8]*(tables.get_value_from_table(pdt[7],'scale_time_hours')))
lt = tables.get_value_from_table(pdt[templates.UnitOfForecastTime._key[pdtn]], 'scale_time_hours')
lt *= pdt[templates.ValueOfForecastTime._key[pdtn]]
return datetime.timedelta(hours=int(lt))


def get_duration(pdtn: int, pdt: ArrayLike) -> datetime.timedelta:
Expand All @@ -132,11 +125,16 @@ def get_duration(pdtn: int, pdt: ArrayLike) -> datetime.timedelta:
datetime.timedelta object representing the time duration of the GRIB2
message.
"""
_key = {8:25, 9:32, 10:26, 11:28, 12:27}
try:
return datetime.timedelta(hours=pdt[_key[pdtn]+1]*tables.get_value_from_table(pdt[_key[pdtn]],'scale_time_hours'))
except(KeyError):
return datetime.timedelta(hours=0)
if pdtn in templates._timeinterval_pdtns:
ntime = pdt[templates.NumberOfTimeRanges._key[pdtn]]
duration_unit = tables.get_value_from_table(
pdt[templates.UnitOfTimeRangeOfStatisticalProcess._key[pdtn]],
'scale_time_hours')
d = ntime * duration_unit * pdt[
templates.TimeRangeOfStatisticalProcess._key[pdtn]]
else:
d = 0
return datetime.timedelta(hours=int(d))


def decode_wx_strings(lus: bytes) -> Dict[int, str]:
Expand Down
6 changes: 3 additions & 3 deletions src/grib2io/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def update_attrs(self, **kwargs):
coords_keys = [
k
for k in da.coords.keys()
if (k in AVAILABLE_NON_GEO_DIMS) and (k in da.dims)
if k in AVAILABLE_NON_GEO_DIMS
]

for grib2_name, value in kwargs.items():
Expand All @@ -1128,7 +1128,7 @@ def update_attrs(self, **kwargs):
)
if grib2_name in coords_keys:
warn(
f"Skipping attribute '{grib2_name}' because it is a dimension coordinate and cannot be updated."
f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
)
continue
if hasattr(newmsg, grib2_name):
Expand Down Expand Up @@ -1190,4 +1190,4 @@ def subset(self, lats, lons) -> xr.DataArray:
da.longitude <= newmsg.longitudeLastGridpoint
)

return da.where((mask_lon & mask_lat).compute(), drop=True)
return da.where((mask_lon & mask_lat).compute(), drop=True)
41 changes: 41 additions & 0 deletions tests/test_grib2_datetime_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest
import numpy as np
import datetime
import grib2io

def test_datetime_attrs(request):
data = request.config.rootdir / 'tests' / 'data' / 'gfs_20221107'
with grib2io.open(data / 'gfs.t00z.pgrb2.1p00.f012_subset') as f:
msg = f['TMAX'][0]

expected_refDate = datetime.datetime(2022, 11, 7, 0, 0)
expected_leadTime = datetime.timedelta(seconds=21600)
expected_duration = datetime.timedelta(seconds=21600)
expected_validDate = datetime.datetime(2022, 11, 7, 12, 0)

assert msg.refDate == expected_refDate
assert msg.leadTime == expected_leadTime
assert msg.duration == expected_duration
assert msg.validDate == expected_validDate

msg.leadTime = datetime.timedelta(hours=24)

assert msg.refDate == datetime.datetime(2022, 11, 7, 0, 0)
assert msg.leadTime == datetime.timedelta(days=1)
assert msg.duration == datetime.timedelta(seconds=21600)
assert msg.validDate == datetime.datetime(2022, 11, 8, 6, 0)

msg.duration = datetime.timedelta(hours=18)

assert msg.refDate == datetime.datetime(2022, 11, 7, 0, 0)
assert msg.leadTime == datetime.timedelta(days=1)
assert msg.duration == datetime.timedelta(seconds=64800)
assert msg.validDate == datetime.datetime(2022, 11, 8, 18, 0)

msg.leadTime = datetime.timedelta(seconds=21600)
msg.duration = datetime.timedelta(seconds=21600)

assert msg.refDate == datetime.datetime(2022, 11, 7, 0, 0)
assert msg.leadTime == datetime.timedelta(seconds=21600)
assert msg.duration == datetime.timedelta(seconds=21600)
assert msg.validDate == datetime.datetime(2022, 11, 7, 12, 0)
Loading
Loading