Skip to content

Commit

Permalink
Merge pull request #10 from PlasmaFAIR/preprocessor
Browse files Browse the repository at this point in the history
Add jobid checking to preprocess
  • Loading branch information
ZedThree authored Aug 19, 2024
2 parents 5a71efc + 0d26176 commit 248369b
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 3 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ xarray:
### Single file loading
```python
import xarray as xr
from sdf_xarray import SDFPreprocess

df = xr.open_dataset("0010.sdf")

Expand All @@ -50,7 +51,7 @@ ds = xr.open_mfdataset(
data_vars='minimal',
coords='minimal',
compat='override',
preprocess=lambda ds: ds.expand_dims(time=[ds.attrs["time"]])
preprocess=SDFPreprocess()
)

print(ds)
Expand Down
18 changes: 18 additions & 0 deletions src/sdf_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,21 @@ def guess_can_open(self, filename_or_obj):
description = "Use .sdf files in Xarray"

url = "https://epochpic.github.io/documentation/visualising_output/python.html"


class SDFPreprocess:
    """Callable that checks job ids and adds a ``time`` dimension to a dataset.

    Intended to be passed as the ``preprocess`` argument of
    ``xarray.open_mfdataset``. The instance remembers the ``jobid1``
    attribute of the first dataset it is called with and rejects any later
    dataset whose ``jobid1`` differs. Because that remembered id is never
    reset, use a fresh instance for each set of files.
    """

    def __init__(self):
        # Job id of the first dataset seen; ``None`` until the first call.
        self.job_id: int | None = None

    def __call__(self, ds: xr.Dataset) -> xr.Dataset:
        """Validate the job id of ``ds`` and return it with a ``time`` dimension.

        Parameters
        ----------
        ds
            Dataset whose ``attrs`` contain ``"jobid1"`` and ``"time"``.

        Returns
        -------
        The result of expanding ``ds`` with a ``time`` dimension whose single
        value is ``ds.attrs["time"]``.

        Raises
        ------
        ValueError
            If ``ds.attrs["jobid1"]`` differs from the job id recorded on the
            first call.
        """
        incoming_id = ds.attrs["jobid1"]

        # The first dataset defines the reference id for all later ones.
        if self.job_id is None:
            self.job_id = incoming_id

        if self.job_id != incoming_id:
            raise ValueError(
                f"Mismatching job ids (got {incoming_id}, expected {self.job_id})"
            )

        time_value = ds.attrs["time"]
        return ds.expand_dims(time=[time_value])
Binary file added tests/example_mismatched_files/0000.sdf
Binary file not shown.
Binary file added tests/example_mismatched_files/0001.sdf
Binary file not shown.
Binary file added tests/example_mismatched_files/0002.sdf
Binary file not shown.
20 changes: 18 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import pathlib

import xarray as xr
from sdf_xarray import open_mfdataset

from sdf_xarray import open_mfdataset, SDFPreprocess
import pytest

# Example-data directories that live alongside this test module.
_TESTS_DIR = pathlib.Path(__file__).parent
EXAMPLE_FILES_DIR = _TESTS_DIR / "example_files"
EXAMPLE_MISMATCHED_FILES_DIR = _TESTS_DIR / "example_mismatched_files"


def test_basic():
Expand Down Expand Up @@ -42,3 +45,16 @@ def test_multiple_files_multiple_time_dims():
assert list(df["Electric Field/Ex"].coords) != list(df["Electric Field/Ez"].coords)
assert df["Electric Field/Ex"].shape == (11, 16)
assert df["Electric Field/Ez"].shape == (1, 16)


def test_erroring_on_mismatched_jobid_files():
    """Opening the mismatched example files together must fail the job id check.

    Every ``*.sdf`` file under ``EXAMPLE_MISMATCHED_FILES_DIR`` is passed
    through ``SDFPreprocess``, which is expected to raise ``ValueError`` once
    it meets a file whose ``jobid1`` differs from the first one it saw.
    """
    # Match on the message so that an unrelated ValueError raised while
    # opening or combining the files cannot make this test pass by accident.
    with pytest.raises(ValueError, match="Mismatching job ids"):
        xr.open_mfdataset(
            EXAMPLE_MISMATCHED_FILES_DIR.glob("*.sdf"),
            concat_dim="time",
            combine="nested",
            data_vars="minimal",
            coords="minimal",
            compat="override",
            preprocess=SDFPreprocess(),
        )

0 comments on commit 248369b

Please sign in to comment.