Skip to content

Commit

Permalink
Write open to "x" new, "w" overwrite, "a" append. (#151)
Browse files Browse the repository at this point in the history
Added "x" to available open modes and set that to the new default
where "x" is new file that will not overwrite, "w" is a new file
that will overwrite and "a" is for appending.  This mimics the
allowable modes for builtins.open used in grib2io.open.
  • Loading branch information
TimothyCera-NOAA authored Jun 6, 2024
1 parent e67a9be commit 68e2103
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 24 deletions.
33 changes: 19 additions & 14 deletions src/grib2io/_grib2io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,22 @@
"""

from dataclasses import dataclass, field
from numpy.typing import NDArray
from typing import Union, Optional
from typing import Literal, Optional, Union
import builtins
import collections
import copy
import datetime
import hashlib
import numpy as np
import os
import pyproj
import re
import struct
import sys
import warnings

from numpy.typing import NDArray
import numpy as np
import pyproj

from . import g2clib
from . import tables
from . import templates
Expand Down Expand Up @@ -91,10 +92,12 @@ class `grib2io.open`, the file named `filename` is opened for reading (`mode
Tuple containing a unique list of variable short names (i.e. GRIB2
abbreviation names).
"""

__slots__ = ('_fileid', '_filehandle', '_hasindex', '_index', '_nodata',
'_pos', 'closed', 'current_message', 'messages', 'mode',
'name', 'size')
def __init__(self, filename: str, mode: str="r", **kwargs):

def __init__(self, filename: str, mode: Literal["r", "w", "x"] = "r", **kwargs):
"""
Initialize GRIB2 File object instance.
Expand All @@ -104,18 +107,20 @@ def __init__(self, filename: str, mode: str="r", **kwargs):
File name containing GRIB2 messages.
mode: default="r"
File access mode where "r" opens the files for reading only; "w"
opens the file for writing.
opens the file for overwriting and "x" for writing to a new file.
"""
# Manage keywords
if "_xarray_backend" not in kwargs:
kwargs["_xarray_backend"] = False
self._nodata = False
else:
self._nodata = kwargs["_xarray_backend"]
if mode in {'a','r','w'}:
mode = mode+'b'
if 'w' in mode: mode += '+'
if 'a' in mode: mode += '+'

# All write modes are read/write.
# All modes are binary.
if mode in ("a", "x", "w"):
mode += "+"
mode = mode + "b"

# Some GRIB2 files are gzipped, so check for that here, but
# raise error when using xarray backend.
Expand Down Expand Up @@ -534,7 +539,7 @@ class Grib2Message:
inherits from `_Grib2Message` and grid, product, data representation
template classes according to the template numbers for the respective
sections. If `section3`, `section4`, or `section5` are omitted, then
the appropriate keyword arguments for the template number `gdtn=`,
the appropriate keyword arguments for the template number `gdtn=`,
`pdtn=`, or `drtn=` must be provided.
Parameters
Expand All @@ -551,12 +556,12 @@ class Grib2Message:
GRIB2 section 4 array.
section5
GRIB2 section 5 array.
Returns
-------
Msg
A dynamically-create Grib2Message object that inherits from
_Grib2Message, a grid definition template class, product
_Grib2Message, a grid definition template class, product
definition template class, and a data representation template
class.
"""
Expand Down Expand Up @@ -1493,7 +1498,7 @@ def set_auto_nans(value: bool):
raise TypeError(f"Argument must be bool")


def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out,
def interpolate(a, method: Union[int, str], grid_def_in, grid_def_out,
method_options=None, num_threads=1):
"""
This is the module-level interpolation function.
Expand Down
39 changes: 31 additions & 8 deletions src/grib2io/xarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass, field, astuple
import itertools
import logging
from pathlib import Path
import typing

import numpy as np
Expand Down Expand Up @@ -718,23 +719,36 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None, num
ds = da.to_dataset(dim='variable')
return ds


def to_grib2(self, filename):
def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
"""
Write a DataSet to a grib2 file.
Parameters
----------
filename
Name of the grib2 file to write to.
mode: {"x", "w", "a"}, optional, default="x"
Persistence mode
+------+-----------------------------------+
| mode | Description |
+======+===================================+
| x | create (fail if exists) |
+------+-----------------------------------+
| w | create (overwrite if exists) |
+------+-----------------------------------+
| a | append (create if does not exist) |
+------+-----------------------------------+
"""
ds = self._obj

for shortName in sorted(ds):
# make a DataArray from the "Data Variables" in the DataSet
da = ds[shortName]

da.grib2io.to_grib2(filename, mode="a")
da.grib2io.to_grib2(filename, mode=mode)
mode = "a"


@xr.register_dataarray_accessor("grib2io")
Expand Down Expand Up @@ -897,18 +911,27 @@ def interp_to_stations(self, method, calls, lats, lons, method_options=None, num
new_da.name = da.name
return new_da


def to_grib2(self, filename, mode="w"):
def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
"""
Write a DataArray to a grib2 file.
Parameters
----------
filename
Name of the grib2 file to write to.
mode
Mode to open the file in. Can be 'w' for write or 'a' for append.
Default is 'w'.
mode: {"x", "w", "a"}, optional, default="x"
Persistence mode
+------+-----------------------------------+
| mode | Description |
+======+===================================+
| x | create (fail if exists) |
+------+-----------------------------------+
| w | create (overwrite if exists) |
+------+-----------------------------------+
| a | append (create if does not exist) |
+------+-----------------------------------+
"""
da = self._obj.copy(deep=True)

Expand Down
24 changes: 22 additions & 2 deletions tests/test_to_grib2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import itertools
from pathlib import Path

import grib2io
import pytest
import xarray as xr
from numpy.testing import assert_allclose, assert_array_equal
Expand Down Expand Up @@ -47,7 +49,12 @@ def test_da_write(tmp_path, request):
filters=filters,
)

ds1["TMP"].grib2io.to_grib2(target_file)
Path(target_file).touch()

with pytest.raises(FileExistsError):
ds1["TMP"].grib2io.to_grib2(target_file)

ds1["TMP"].grib2io.to_grib2(target_file, mode="w")

ds2 = xr.open_dataset(target_file, engine="grib2io")

Expand All @@ -74,13 +81,26 @@ def test_ds_write(tmp_path, request):
filters=filters,
)

ds1.grib2io.to_grib2(target_file)
Path(target_file).touch()

with pytest.raises(FileExistsError):
ds1.grib2io.to_grib2(target_file)

ds1.grib2io.to_grib2(target_file, mode="w")

ds2 = xr.open_dataset(target_file, engine="grib2io")

ds2_msgs = grib2io.open(target_file)

for var in ["APTMP", "DPT", "RH", "SPFH", "TMP"]:
_test_any_differences(ds1[var], ds2[var])

ds1.grib2io.to_grib2(target_file, mode="a")

ds3_msgs = grib2io.open(target_file)

assert len(ds3_msgs) == 2 * len(ds2_msgs)


def test_ds_write_levels(tmp_path, request):
"""Test writing a Dataset with multiple levels to multiple grib2 messages."""
Expand Down

0 comments on commit 68e2103

Please sign in to comment.