Skip to content

Commit

Permalink
Add or edit multiple __set__ functions. (#150)
Browse files Browse the repository at this point in the history
* Created a _calculate_scale_factor function in templates.py to
  be used in the new ValueOf* __set__ functions.
* Added __set__ functions to all the ValueOf* classes that behind the
  scenes use the ScaledValue* and ScaleFactor* setters.
* Changed the ValueOf* __get__ to behind the scenes use the
  ScaledValue* and ScaleFactor
* Fixed the LeadTime __set__ for messages that are an accumulating
  statistic to set both LeadTime to the beginning of the accumulating
  period and the accumulating end date to the appropriate parts of
  section4.
* Now when writing with to_grib2 the messages will be sorted, if writing
  from a DataSet will sort on shortName, then for each DataArray will
  sort on dimension names, then sort dimension values.
  • Loading branch information
TimothyCera-NOAA authored Jun 5, 2024
1 parent b4c4266 commit e67a9be
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 40 deletions.
117 changes: 84 additions & 33 deletions src/grib2io/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@
8:[],}


def _calculate_scale_factor(value: float):
"""
Calculate the scale factor for a given value.
Parameters
----------
value : float
Value for which to calculate the scale factor.
Returns
-------
int
Scale factor for the value.
"""
return len(f"{value}".split(".")[1].rstrip("0"))


class Grib2Metadata:
"""
Class to hold GRIB2 metadata.
Expand Down Expand Up @@ -988,34 +1005,61 @@ def __set__(self, obj, value):

class LeadTime:
"""Forecast Lead Time. NOTE: This is a `datetime.timedelta` object."""

def __get__(self, obj, objtype=None):
return utils.get_leadtime(obj.section1,obj.section4[1],
obj.section4[2:])
return utils.get_leadtime(obj.section1, obj.section4[1], obj.section4[2:])

def __set__(self, obj, value):
pdt = obj.section4[2:]

# For the tables below, the key is the PDTN and the value is the slice
# of the PDT that contains the end date of the accumulation.
# This is only needed for PDTNs 8-12.
_key = {
8: slice(15, 21),
9: slice(22, 28),
10: slice(16, 22),
11: slice(18, 24),
12: slice(17, 23),
}
refdate = datetime.datetime(*obj.section1[5:11])
ivalue = int(timedelta64(value, "h") / timedelta64(1, "h"))
try:

accumulation_offset = 0
if obj.pdtn in _key:
accumulation_end_date = _key[obj.pdtn]
accumulation_key = accumulation_end_date.stop + 5
accumulation_offset = int(
timedelta64(pdt[accumulation_key], "h") / timedelta64(1, "h")
)
accumulation_offset = accumulation_offset / (
tables.get_value_from_table(
pdt[accumulation_key - 1], "scale_time_hours"
)
)

refdate = datetime.datetime(*obj.section1[5:11])
pdt[_key[obj.pdtn]] = (
datetime.timedelta(hours=ivalue) + refdate
datetime.timedelta(hours=accumulation_offset) + refdate
).timetuple()[:6]
except KeyError:
if obj.pdtn == 48:
pdt[19] = ivalue / (
tables.get_value_from_table(pdt[18], "scale_time_hours")
)
else:
pdt[8] = ivalue / (
tables.get_value_from_table(pdt[7], "scale_time_hours")

# All messages need the leadTime value set, but for PDTNs 8-12, the
# leadTime value has to be set to the beginning of the accumulation
# period which is done here by subtracting the already calculated
# value accumulation_offset.
lead_time_index = 8
if obj.pdtn == 48:
lead_time_index = 19

ivalue = int(timedelta64(value, "h") / timedelta64(1, "h"))

pdt[lead_time_index] = (
ivalue
/ (
tables.get_value_from_table(
pdt[lead_time_index - 1], "scale_time_hours"
)
)
- accumulation_offset
)

class FixedSfc1Info:
"""Information of the first fixed surface via [table 4.5](https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-5.shtml)"""
Expand Down Expand Up @@ -1076,13 +1120,13 @@ def __set__(self, obj, value):
class ValueOfFirstFixedSurface:
"""Value of First Fixed Surface"""
def __get__(self, obj, objtype=None):
return obj.section4[ScaledValueOfFirstFixedSurface._key[obj.pdtn] + 2] / (
10.0 ** obj.section4[ScaleFactorOfFirstFixedSurface._key[obj.pdtn] + 2]
)
scale_factor = getattr(obj, "scaleFactorOfFirstFixedSurface")
scaled_value = getattr(obj, "scaledValueOfFirstFixedSurface")
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
obj.section4[ScaledValueOfFirstFixedSurface._key[obj.pdtn] + 2] = value * (
10.0 ** obj.section4[ScaleFactorOfFirstFixedSurface._key[obj.pdtn] + 2]
)
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfFirstFixedSurface", scale)
setattr(obj, "scaledValueOfFirstFixedSurface", value * 10**scale)

class TypeOfSecondFixedSurface:
"""[Type of Second Fixed Surface](https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-5.shtml)"""
Expand Down Expand Up @@ -1121,10 +1165,13 @@ def __set__(self, obj, value):
class ValueOfSecondFixedSurface:
"""Value of Second Fixed Surface"""
def __get__(self, obj, objtype=None):
return obj.section4[ScaledValueOfFirstFixedSurface._key[obj.pdtn]+2]/\
(10.**obj.section4[ScaleFactorOfFirstFixedSurface._key[obj.pdtn]+2])
scale_factor = getattr(obj, "scaleFactorOfSecondFixedSurface")
scaled_value = getattr(obj, "scaledValueOfSecondFixedSurface")
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
pass
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfSecondFixedSurface", scale)
setattr(obj, "scaledValueOfSecondFixedSurface", value * 10**scale)

class Level:
"""Level (same as provided by [wgrib2](https://github.com/NOAA-EMC/NCEPLIBS-wgrib2/blob/develop/wgrib2/Level.c))"""
Expand Down Expand Up @@ -1246,24 +1293,28 @@ def __set__(self, obj, value):
class ThresholdLowerLimit:
"""Threshold Lower Limit"""
def __get__(self, obj, objtype=None):
if obj.section4[18+2] == -127 and \
obj.section4[19+2] == 255:
scale_factor = getattr(obj, "scaleFactorOfThresholdLowerLimit")
scaled_value = getattr(obj, "scaledValueOfThresholdLowerLimit")
if scale_factor == -127 and scaled_value == 255:
return 0.0
else:
return obj.section4[19+2]/(10.**obj.section4[18+2])
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
pass
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfThresholdLowerLimit", scale)
setattr(obj, "scaledValueOfThresholdLowerLimit", value * 10**scale)

class ThresholdUpperLimit:
"""Threshold Upper Limit"""
def __get__(self, obj, objtype=None):
if obj.section4[20+2] == -127 and \
obj.section4[21+2] == 255:
scale_factor = getattr(obj, "scaleFactorOfThresholdUpperLimit")
scaled_value = getattr(obj, "scaledValueOfThresholdUpperLimit")
if scale_factor == -127 and scaled_value == 255:
return 0.0
else:
return obj.section4[21+2]/(10.**obj.section4[20+2])
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
pass
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfThresholdUpperLimit", scale)
setattr(obj, "scaledValueOfThresholdUpperLimit", value * 10**scale)

class Threshold:
"""Threshold string (same as [wgrib2](https://github.com/NOAA-EMC/NCEPLIBS-wgrib2/blob/develop/wgrib2/Prob.c))"""
Expand Down
4 changes: 2 additions & 2 deletions src/grib2io/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def get_leadtime(idsec: ArrayLike, pdtn: int, pdt: ArrayLike) -> datetime.timede
Parameters
----------
idsec
Seqeunce containing GRIB2 Identification Section (Section 1).
Sequence containing GRIB2 Identification Section (Section 1).
pdtn
GRIB2 Product Definition Template Number
pdt
Seqeunce containing GRIB2 Product Definition Template (Section 4).
Sequence containing GRIB2 Product Definition Template (Section 4).
Returns
-------
Expand Down
10 changes: 5 additions & 5 deletions src/grib2io/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None, num
ds = da.to_dataset(dim='variable')
return ds


def to_grib2(self, filename):
"""
Write a DataSet to a grib2 file.
Expand All @@ -729,7 +730,7 @@ def to_grib2(self, filename):
"""
ds = self._obj

for shortName in ds:
for shortName in sorted(ds):
# make a DataArray from the "Data Variables" in the DataSet
da = ds[shortName]

Expand Down Expand Up @@ -896,6 +897,7 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None, num
new_da.name = da.name
return new_da


def to_grib2(self, filename, mode="w"):
"""
Write a DataArray to a grib2 file.
Expand All @@ -919,13 +921,13 @@ def to_grib2(self, filename, mode="w"):
k for k in index_keys if k not in ["latitude", "longitude", "validDate"]
]
indexes = []
for index in index_keys:
for index in sorted(index_keys):
values = da.coords[index].values
if not isinstance(values, np.ndarray):
continue
if values.ndim != 1:
continue
listeach = [{index: value} for value in list(set(values))]
listeach = [{index: value} for value in sorted(set(values))]
indexes.append(listeach)

for selectors in itertools.product(*indexes):
Expand All @@ -950,8 +952,6 @@ def to_grib2(self, filename, mode="w"):
for index, value in filters.items():
setattr(newmsg, index, value)

newmsg.pack()

# write the message to file
with grib2io.open(filename, mode=mode) as f:
f.write(newmsg)
Expand Down

0 comments on commit e67a9be

Please sign in to comment.