Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add or edit multiple __set__ functions. #150

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 84 additions & 33 deletions src/grib2io/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@
8:[],}


def _calculate_scale_factor(value: float):
"""
Calculate the scale factor for a given value.

Parameters
----------
value : float
Value for which to calculate the scale factor.

Returns
-------
int
Scale factor for the value.
"""
return len(f"{value}".split(".")[1].rstrip("0"))


class Grib2Metadata:
"""
Class to hold GRIB2 metadata.
Expand Down Expand Up @@ -988,34 +1005,61 @@ def __set__(self, obj, value):

class LeadTime:
"""Forecast Lead Time. NOTE: This is a `datetime.timedelta` object."""

def __get__(self, obj, objtype=None):
return utils.get_leadtime(obj.section1,obj.section4[1],
obj.section4[2:])
return utils.get_leadtime(obj.section1, obj.section4[1], obj.section4[2:])

def __set__(self, obj, value):
pdt = obj.section4[2:]

# For the tables below, the key is the PDTN and the value is the slice
# of the PDT that contains the end date of the accumulation.
# This is only needed for PDTNs 8-12.
_key = {
8: slice(15, 21),
9: slice(22, 28),
10: slice(16, 22),
11: slice(18, 24),
12: slice(17, 23),
}
refdate = datetime.datetime(*obj.section1[5:11])
ivalue = int(timedelta64(value, "h") / timedelta64(1, "h"))
try:

accumulation_offset = 0
if obj.pdtn in _key:
accumulation_end_date = _key[obj.pdtn]
accumulation_key = accumulation_end_date.stop + 5
accumulation_offset = int(
timedelta64(pdt[accumulation_key], "h") / timedelta64(1, "h")
)
accumulation_offset = accumulation_offset / (
tables.get_value_from_table(
pdt[accumulation_key - 1], "scale_time_hours"
)
)

refdate = datetime.datetime(*obj.section1[5:11])
pdt[_key[obj.pdtn]] = (
datetime.timedelta(hours=ivalue) + refdate
datetime.timedelta(hours=accumulation_offset) + refdate
).timetuple()[:6]
except KeyError:
if obj.pdtn == 48:
pdt[19] = ivalue / (
tables.get_value_from_table(pdt[18], "scale_time_hours")
)
else:
pdt[8] = ivalue / (
tables.get_value_from_table(pdt[7], "scale_time_hours")

# All messages need the leadTime value set, but for PDTNs 8-12, the
# leadTime value has to be set to the beginning of the accumulation
# period which is done here by subtracting the already calculated
# value accumulation_offset.
lead_time_index = 8
if obj.pdtn == 48:
lead_time_index = 19

ivalue = int(timedelta64(value, "h") / timedelta64(1, "h"))

pdt[lead_time_index] = (
ivalue
/ (
tables.get_value_from_table(
pdt[lead_time_index - 1], "scale_time_hours"
)
)
- accumulation_offset
)

class FixedSfc1Info:
"""Information of the first fixed surface via [table 4.5](https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-5.shtml)"""
Expand Down Expand Up @@ -1076,13 +1120,13 @@ def __set__(self, obj, value):
class ValueOfFirstFixedSurface:
"""Value of First Fixed Surface"""
def __get__(self, obj, objtype=None):
return obj.section4[ScaledValueOfFirstFixedSurface._key[obj.pdtn] + 2] / (
10.0 ** obj.section4[ScaleFactorOfFirstFixedSurface._key[obj.pdtn] + 2]
)
scale_factor = getattr(obj, "scaleFactorOfFirstFixedSurface")
scaled_value = getattr(obj, "scaledValueOfFirstFixedSurface")
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
obj.section4[ScaledValueOfFirstFixedSurface._key[obj.pdtn] + 2] = value * (
10.0 ** obj.section4[ScaleFactorOfFirstFixedSurface._key[obj.pdtn] + 2]
)
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfFirstFixedSurface", scale)
setattr(obj, "scaledValueOfFirstFixedSurface", value * 10**scale)

class TypeOfSecondFixedSurface:
"""[Type of Second Fixed Surface](https://www.nco.ncep.noaa.gov/pmb/docs/grib2/grib2_doc/grib2_table4-5.shtml)"""
Expand Down Expand Up @@ -1121,10 +1165,13 @@ def __set__(self, obj, value):
class ValueOfSecondFixedSurface:
"""Value of Second Fixed Surface"""
def __get__(self, obj, objtype=None):
return obj.section4[ScaledValueOfFirstFixedSurface._key[obj.pdtn]+2]/\
(10.**obj.section4[ScaleFactorOfFirstFixedSurface._key[obj.pdtn]+2])
scale_factor = getattr(obj, "scaleFactorOfSecondFixedSurface")
scaled_value = getattr(obj, "scaledValueOfSecondFixedSurface")
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
pass
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfSecondFixedSurface", scale)
setattr(obj, "scaledValueOfSecondFixedSurface", value * 10**scale)

class Level:
"""Level (same as provided by [wgrib2](https://github.com/NOAA-EMC/NCEPLIBS-wgrib2/blob/develop/wgrib2/Level.c))"""
Expand Down Expand Up @@ -1246,24 +1293,28 @@ def __set__(self, obj, value):
class ThresholdLowerLimit:
"""Threshold Lower Limit"""
def __get__(self, obj, objtype=None):
if obj.section4[18+2] == -127 and \
obj.section4[19+2] == 255:
scale_factor = getattr(obj, "scaleFactorOfThresholdLowerLimit")
scaled_value = getattr(obj, "scaledValueOfThresholdLowerLimit")
if scale_factor == -127 and scaled_value == 255:
return 0.0
else:
return obj.section4[19+2]/(10.**obj.section4[18+2])
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
pass
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfThresholdLowerLimit", scale)
setattr(obj, "scaledValueOfThresholdLowerLimit", value * 10**scale)

class ThresholdUpperLimit:
"""Threshold Upper Limit"""
def __get__(self, obj, objtype=None):
if obj.section4[20+2] == -127 and \
obj.section4[21+2] == 255:
scale_factor = getattr(obj, "scaleFactorOfThresholdUpperLimit")
scaled_value = getattr(obj, "scaledValueOfThresholdUpperLimit")
if scale_factor == -127 and scaled_value == 255:
return 0.0
else:
return obj.section4[21+2]/(10.**obj.section4[20+2])
return scaled_value / (10.**scale_factor)
def __set__(self, obj, value):
pass
scale = _calculate_scale_factor(value)
setattr(obj, "scaleFactorOfThresholdUpperLimit", scale)
setattr(obj, "scaledValueOfThresholdUpperLimit", value * 10**scale)

class Threshold:
"""Threshold string (same as [wgrib2](https://github.com/NOAA-EMC/NCEPLIBS-wgrib2/blob/develop/wgrib2/Prob.c))"""
Expand Down
4 changes: 2 additions & 2 deletions src/grib2io/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ def get_leadtime(idsec: ArrayLike, pdtn: int, pdt: ArrayLike) -> datetime.timede
Parameters
----------
idsec
Seqeunce containing GRIB2 Identification Section (Section 1).
Sequence containing GRIB2 Identification Section (Section 1).
pdtn
GRIB2 Product Definition Template Number
pdt
Seqeunce containing GRIB2 Product Definition Template (Section 4).
Sequence containing GRIB2 Product Definition Template (Section 4).

Returns
-------
Expand Down
10 changes: 5 additions & 5 deletions src/grib2io/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,7 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None, num
ds = da.to_dataset(dim='variable')
return ds


def to_grib2(self, filename):
"""
Write a DataSet to a grib2 file.
Expand All @@ -729,7 +730,7 @@ def to_grib2(self, filename):
"""
ds = self._obj

for shortName in ds:
for shortName in sorted(ds):
# make a DataArray from the "Data Variables" in the DataSet
da = ds[shortName]

Expand Down Expand Up @@ -896,6 +897,7 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None, num
new_da.name = da.name
return new_da


def to_grib2(self, filename, mode="w"):
"""
Write a DataArray to a grib2 file.
Expand All @@ -919,13 +921,13 @@ def to_grib2(self, filename, mode="w"):
k for k in index_keys if k not in ["latitude", "longitude", "validDate"]
]
indexes = []
for index in index_keys:
for index in sorted(index_keys):
values = da.coords[index].values
if not isinstance(values, np.ndarray):
continue
if values.ndim != 1:
continue
listeach = [{index: value} for value in list(set(values))]
listeach = [{index: value} for value in sorted(set(values))]
indexes.append(listeach)

for selectors in itertools.product(*indexes):
Expand All @@ -950,8 +952,6 @@ def to_grib2(self, filename, mode="w"):
for index, value in filters.items():
setattr(newmsg, index, value)

newmsg.pack()

# write the message to file
with grib2io.open(filename, mode=mode) as f:
f.write(newmsg)
Expand Down
Loading