Skip to content

Commit

Permalink
Added spatial subset for both messages and grib2io xarray.
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyCera-NOAA committed Aug 4, 2024
1 parent d9552d2 commit 03c861a
Show file tree
Hide file tree
Showing 4 changed files with 348 additions and 321 deletions.
424 changes: 183 additions & 241 deletions demos/plotting_examples.ipynb

Large diffs are not rendered by default.

60 changes: 30 additions & 30 deletions src/grib2io/_grib2io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,7 +1371,9 @@ def subset(self, lats, lons):
A spatial subset of a GRIB2 message.
"""
if self.gdtn not in [0, 1, 40, 10, 20, 30, 31, 110, 32769]:
raise ValueError('Subset only works for regular lat/lon, Gaussian, mercator, stereographic, lambert conformal, albers equal-area, and azimuthal equidistant grids.')
raise ValueError(
"Subset only works for regular lat/lon, Gaussian, mercator, stereographic, lambert conformal, albers equal-area, and azimuthal equidistant grids. Grid Definition Template Numbers of 0, 1, 40, 10, 20, 30, 31, 110, and 32769 are supported."
)

newmsg = Grib2Message(
np.copy(self.section0),
Expand All @@ -1384,44 +1386,42 @@ def subset(self, lats, lons):

msglats, msglons = self.grid()

la1 = np.min(lats)
la2 = np.max(lats)
la1 = np.max(lats)
lo1 = np.min(lons)
la2 = np.min(lats)
lo2 = np.max(lons)

first_lat = np.abs(msglats - la1)
first_lon = np.abs(msglons - lo1)
max_idx = np.maximum(first_lon, first_lat)
first_i, first_j = np.where(max_idx == np.min(max_idx))
max_idx = np.maximum(first_lat, first_lon)
first_j, first_i = np.where(max_idx == np.min(max_idx))

print("first_i, first_j", first_i, first_j)
last_lat = np.abs(msglats - la2)
last_lon = np.abs(msglons - lo2)
max_idx = np.maximum(last_lon, last_lat)
last_i, last_j = np.where(max_idx == np.min(max_idx))
print("last_i, last_j", last_i, last_j)

setattr(newmsg, "latitudeFirstGridpoint" , msglats[first_i[0], first_j[0]])
print("latitudeFirstGridpoint", newmsg.latitudeFirstGridpoint)
setattr(newmsg, "longitudeFirstGridpoint" , msglons[first_i[0], first_j[0]])
print("longitudeFirstGridpoint", newmsg.longitudeFirstGridpoint)
setattr(newmsg, "nx" , np.abs(first_i[0] - last_i[0]))
setattr(newmsg, "ny" , np.abs(first_j[0] - last_j[0]))
print("newmsg.nx, newmsg.ny", newmsg.nx, newmsg.ny)
print(self._data.shape)
setattr(newmsg, "data" , np.copy(self._data[
min(first_i[0] , last_i[0]) : max(first_i[0] , last_i[0]),
min(first_j[0] , last_j[0]) : max(first_j[0] , last_j[0])]))
if self.gdtn in [0, 1, 40]:
setattr(newmsg, "latitudeLastGridpoint" , msglats[last_i[0], last_j[0]])
print("latitudeLastGridpoint", newmsg.latitudeLastGridpoint)
setattr(newmsg, "longitudeLastGridpoint" , msglons[last_i[0], last_j[0]])
print("longitudeLastGridpoint", newmsg.longitudeLastGridpoint)
if self._sha1_section3 in _latlon_datastore.keys():
del _latlon_datastore[self._sha1_section3]
max_idx = np.maximum(last_lat, last_lon)
last_j, last_i = np.where(max_idx == np.min(max_idx))

setattr(newmsg, "latitudeFirstGridpoint", msglats[first_j[0], first_i[0]])
setattr(newmsg, "longitudeFirstGridpoint", msglons[first_j[0], first_i[0]])
setattr(newmsg, "nx", np.abs(first_i[0] - last_i[0]))
setattr(newmsg, "ny", np.abs(first_j[0] - last_j[0]))

# Set *LastGridpoint attributes even if only used for gdtn=[0,1,40].
# This information is used to subset xarray datasets.
setattr(newmsg, "latitudeLastGridpoint", msglats[last_j[0], last_i[0]])
setattr(newmsg, "longitudeLastGridpoint", msglons[last_j[0], last_i[0]])

setattr(
newmsg,
"data",
self.data[
min(first_j[0], last_j[0]) : max(first_j[0], last_j[0]),
min(first_i[0], last_i[0]) : max(first_i[0], last_i[0]),
].copy(),
)

newmsg._sha1_section3 = ""
newmsg.grid()
print(newmsg.nx, newmsg.ny)
print(newmsg.grid())

return newmsg

Expand Down
64 changes: 64 additions & 0 deletions src/grib2io/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,29 @@ def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
da.grib2io.to_grib2(filename, mode=mode)
mode = "a"

def subset(self, lats, lons) -> xr.Dataset:
    """
    Return a spatial subset of the wrapped Dataset.

    Every name yielded by iterating the Dataset is subset through that
    variable's own ``grib2io`` DataArray accessor, and a copy of the result
    is stored under the same name in a newly created Dataset.

    Parameters
    ----------
    lats
        Latitude bounds of the region.
    lons
        Longitude bounds of the region.

    Returns
    -------
    subset
        New Dataset holding the subsetted variables.
    """
    source = self._obj
    result = xr.Dataset()

    for var_name in source:
        # Delegate the actual spatial selection to the DataArray accessor.
        regional = source[var_name].grib2io.subset(lats, lons)
        result[var_name] = regional.copy()

    return result

@xr.register_dataarray_accessor("grib2io")
class Grib2ioDataArray:
Expand Down Expand Up @@ -1014,3 +1037,44 @@ def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
with grib2io.open(filename, mode=mode) as f:
f.write(newmsg)
mode = "a"

def subset(self, lats, lons) -> xr.DataArray:
    """
    Return a spatial subset of the wrapped DataArray.

    A ``Grib2Message`` is rebuilt from the ``GRIB2IO_section0`` ..
    ``GRIB2IO_section5`` attributes and a copy of the array values, then
    subset with ``Grib2Message.subset``. The resulting section 3 replaces
    the ``GRIB2IO_section3`` attribute on a deep copy of the DataArray, and
    the first/last grid point coordinates of the subsetted message are used
    to mask that copy.

    Parameters
    ----------
    lats
        Latitude bounds of the region.
    lons
        Longitude bounds of the region.

    Returns
    -------
    subset
        DataArray restricted to the region; points outside it are dropped.
    """
    # Work on a deep copy so the wrapped object is left untouched.
    work = self._obj.copy(deep=True)

    sections = (
        work.attrs["GRIB2IO_section0"],
        work.attrs["GRIB2IO_section1"],
        work.attrs["GRIB2IO_section2"],
        work.attrs["GRIB2IO_section3"],
        work.attrs["GRIB2IO_section4"],
        work.attrs["GRIB2IO_section5"],
    )
    msg = Grib2Message(*sections)
    msg.data = np.copy(work.values)

    # Let the message-level subset work out the new grid definition.
    msg = msg.subset(lats, lons)
    work.attrs["GRIB2IO_section3"] = msg.section3

    # The mask uses the last grid point latitude as the lower bound and
    # the first grid point latitude as the upper bound; for longitude the
    # first grid point is the lower bound and the last is the upper bound.
    lat_low = msg.latitudeLastGridpoint
    lat_high = msg.latitudeFirstGridpoint
    lon_low = msg.longitudeFirstGridpoint
    lon_high = msg.longitudeLastGridpoint

    mask_lat = (work.latitude >= lat_low) & (work.latitude <= lat_high)
    mask_lon = (work.longitude >= lon_low) & (work.longitude <= lon_high)

    inside = (mask_lon & mask_lat).compute()
    return work.where(inside, drop=True)
121 changes: 71 additions & 50 deletions tests/test_subset.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,86 @@
import itertools
from pathlib import Path

import grib2io
import pytest
import xarray as xr
from numpy.testing import assert_allclose, assert_array_equal
from numpy.testing import assert_array_equal

import grib2io

def _del_list_inplace(input_list, indices):
for index in sorted(indices, reverse=True):
del input_list[index]
return input_list

@pytest.fixture()
def inp_ds(request):
datadir = request.config.rootdir / "tests" / "data" / "gfs_20221107"

def _test_any_differences(da1, da2, atol=0.005, rtol=0):
"""Test if two DataArrays are equal, including most attributes."""
assert_array_equal(
da1.attrs["GRIB2IO_section0"][:-1], da2.attrs["GRIB2IO_section0"][:-1]
)
assert_array_equal(da1.attrs["GRIB2IO_section1"], da2.attrs["GRIB2IO_section1"])
assert_array_equal(da1.attrs["GRIB2IO_section2"], da2.attrs["GRIB2IO_section2"])
assert_array_equal(da1.attrs["GRIB2IO_section3"], da2.attrs["GRIB2IO_section3"])
assert_array_equal(da1.attrs["GRIB2IO_section4"], da2.attrs["GRIB2IO_section4"])
skip = [2, 9, 10, 11, 16, 17]
assert_array_equal(
_del_list_inplace(list(da1.attrs["GRIB2IO_section5"]), skip),
_del_list_inplace(list(da2.attrs["GRIB2IO_section5"]), skip),
filters = {
"typeOfFirstFixedSurface": 103,
"valueOfFirstFixedSurface": 2,
"productDefinitionTemplateNumber": 0,
"shortName": "TMP",
}

ids = xr.open_mfdataset(
[
datadir / "gfs.t00z.pgrb2.1p00.f009_subset",
datadir / "gfs.t00z.pgrb2.1p00.f012_subset",
],
combine="nested",
concat_dim="leadTime",
engine="grib2io",
filters=filters,
)
assert_allclose(da1.data, da2.data, atol=atol, rtol=rtol)

yield ids

def test_da_write(tmp_path, request):
"""Test writing a single DataArray to a single grib2 message."""
target_dir = tmp_path / "test_to_grib2"
target_dir.mkdir()
target_file = target_dir / "test_to_grib2_da.grib2"

@pytest.fixture()
def inp_msgs(request):
datadir = request.config.rootdir / "tests" / "data" / "gfs_20221107"

with grib2io.open(datadir / "gfs.t00z.pgrb2.1p00.f012_subset") as inp:
print(inp[0].section3)
newmsg = inp[0].subset(lats=(43, 32.7), lons=(117, 79))
with grib2io.open(datadir / "gfs.t00z.pgrb2.1p00.f012_subset") as imsgs:
yield imsgs


print(inp[0])
print(newmsg)
print(newmsg.section0)
print(inp[0].section0)
print(newmsg.section1)
print(inp[0].section1)
print(newmsg.section2)
print(inp[0].section2)
print(newmsg.section3)
print(inp[0].section3)
print(newmsg.section4)
print(inp[0].section4)
print(newmsg.section5)
print(inp[0].section5)
@pytest.mark.parametrize(
"lats, lons, expected_section3",
[
pytest.param(
(43, 32.7),
(117, 79),
[
0,
380,
0,
0,
0,
6,
0,
0,
0,
0,
0,
0,
38,
10,
0,
-1,
43000000,
79000000,
48,
33000000,
117000000,
1000000,
1000000,
0,
],
id="subset_1",
),
],
)
def test_message_subset(inp_msgs, inp_ds, lats, lons, expected_section3):
"""Test subsetting a single DataArray to a single grib2 message."""
newmsg = inp_msgs[0].subset(lats=lats, lons=lons)
assert_array_equal(newmsg.section3, expected_section3)

print(inp[0].data.shape)
print(newmsg.data.shape)
newds = inp_ds["TMP"].grib2io.subset(lats=lats, lons=lons)
assert_array_equal(newds.attrs["GRIB2IO_section3"], expected_section3)

with grib2io.open(target_file, mode="w") as out:
out.write(newmsg)
assert False
newds = inp_ds.grib2io.subset(lats=lats, lons=lons)
assert_array_equal(newds["TMP"].attrs["GRIB2IO_section3"], expected_section3)

0 comments on commit 03c861a

Please sign in to comment.