Commit 4fcdc55
Merge pull request #11 from PlasmaFAIR/particle-data
Read particle data
ZedThree authored Aug 23, 2024
2 parents 248369b + 1b663db commit 4fcdc55
Showing 3 changed files with 98 additions and 25 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -46,7 +46,6 @@ print(df["Electric Field/Ex"])
```python
ds = xr.open_mfdataset(
"*.sdf",
concat_dim="time",
combine="nested",
data_vars='minimal',
coords='minimal',
@@ -63,3 +62,13 @@ print(ds)
# Indexes: (9)
# Attributes: (22)
```

### Reading particle data

By default, particle data isn't kept. Pass `keep_particles=True` as a
keyword argument to `open_dataset` (for single files) or
`open_mfdataset` (for multiple files):

```python
df = xr.open_dataset("0010.sdf", keep_particles=True)
```
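
The flag is passed through to the backend, so it also works when opening
many files at once. A minimal sketch, assuming your simulation output lives
in `data/` and defines a `proton` species (adjust the paths and variable
names to your own EPOCH input deck):

```python
import pathlib

from sdf_xarray import open_mfdataset

# Combine all SDF files onto a single time dimension, keeping particle data
ds = open_mfdataset(pathlib.Path("data").glob("*.sdf"), keep_particles=True)
print(ds["Particles/Px/proton"])
```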
80 changes: 59 additions & 21 deletions src/sdf_xarray/__init__.py
@@ -10,17 +10,17 @@
from . import sdf


def combine_datasets(path_glob: Iterable | str) -> xr.Dataset:
def combine_datasets(path_glob: Iterable | str, **kwargs) -> xr.Dataset:
"""Combine all datasets using a single time dimension"""

return xr.open_mfdataset(
path_glob, preprocess=lambda ds: ds.expand_dims(time=[ds.attrs["time"]])
)
return xr.open_mfdataset(path_glob, preprocess=SDFPreprocess(), **kwargs)


def open_mfdataset(
path_glob: Iterable | str | pathlib.Path | pathlib.Path.glob,
*,
separate_times: bool = False,
keep_particles: bool = False,
) -> xr.Dataset:
"""Open a set of EPOCH SDF files as one `xarray.Dataset`
@@ -48,7 +48,8 @@ def open_mfdataset(
separate_times :
If ``True``, create separate time dimensions for variables defined at
different output frequencies
keep_particles :
If ``True``, also load particle data (this may use a lot of memory!)
"""

# TODO: This is not very robust, look at how xarray.open_mfdataset does it
@@ -59,16 +60,24 @@ def open_mfdataset(
path_glob = sorted(list(path_glob))

if not separate_times:
return combine_datasets(path_glob)
return combine_datasets(path_glob, keep_particles=keep_particles)

time_dims, var_times_map = make_time_dims(path_glob)
all_dfs = [xr.open_dataset(f) for f in path_glob]
all_dfs = [xr.open_dataset(f, keep_particles=keep_particles) for f in path_glob]

for df in all_dfs:
for da in df:
df[da] = df[da].expand_dims(
dim={var_times_map[str(da)]: [df.attrs["time"]]}
)
for coord in df.coords:
if "Particles" in coord:
# We need to undo our renaming of the coordinates
base_name = coord.split("_", maxsplit=1)[-1]
sdf_coord_name = f"Grid/{base_name}"
df.coords[coord] = df.coords[coord].expand_dims(
dim={var_times_map[sdf_coord_name]: [df.attrs["time"]]}
)

return xr.merge(all_dfs)

@@ -109,7 +118,7 @@ def make_time_dims(path_glob):
return time_dims, var_times_map


def open_sdf_dataset(filename_or_obj, *, drop_variables=None):
def open_sdf_dataset(filename_or_obj, *, drop_variables=None, keep_particles=False):
if isinstance(filename_or_obj, pathlib.Path):
# sdf library takes a filename only
# TODO: work out if we need to deal with file handles
@@ -131,28 +140,41 @@ def open_sdf_dataset(filename_or_obj, *, drop_variables=None):
data_vars = {}
coords = {}

def _norm_grid_name(grid_name: str) -> str:
"""There may be multiple grids all with the same coordinate names, so
drop the "Grid/" from the start, and append the rest to the
dimension name. This lets us disambiguate them all. Probably"""
return grid_name.split("/", maxsplit=1)[-1]

def _grid_species_name(grid_name: str) -> str:
return grid_name.split("/")[-1]

# Read and convert SDF variables and meshes to xarray DataArrays and Coordinates
for key, value in data.items():
if "CPU" in key:
# Had some problems with these variables, so just ignore them for now
continue

if not keep_particles and "particles" in key.lower():
continue

if isinstance(value, sdf.BlockConstant):
# This might have consequences when reading in multiple files?
attrs[key] = value.data

elif isinstance(value, sdf.BlockPlainMesh):
elif isinstance(value, (sdf.BlockPlainMesh, sdf.BlockPointMesh)):
# These are Coordinates

# There may be multiple grids all with the same coordinate names, so
# drop the "Grid/" from the start, and append the rest to the
# dimension name. This lets us disambiguate them all. Probably
base_name = key.split("/", maxsplit=1)[-1]
is_point_mesh = isinstance(value, sdf.BlockPointMesh)
base_name = _norm_grid_name(key)

for label, coord, unit in zip(value.labels, value.data, value.units):
full_name = f"{label}_{base_name}"
dim_name = (
f"ID_{_grid_species_name(key)}" if is_point_mesh else full_name
)
coords[full_name] = (
full_name,
dim_name,
coord,
{"long_name": label, "units": unit},
)
@@ -170,13 +192,11 @@ def open_sdf_dataset(filename_or_obj, *, drop_variables=None):
# Then we can look up the dimension label and size to get *our* name
# for the corresponding coordinate
dim_size_lookup = defaultdict(dict)

# TODO: remove duplication with coords branch
grid_base_name = value.grid.name.split("/", maxsplit=1)[-1]
grid_base_name = _norm_grid_name(value.grid.name)
for dim_size, dim_name in zip(value.grid.dims, value.grid.labels):
dim_size_lookup[dim_name][dim_size] = f"{dim_name}_{grid_base_name}"

grid_mid_base_name = value.grid_mid.name.split("/", maxsplit=1)[-1]
grid_mid_base_name = _norm_grid_name(value.grid_mid.name)
for dim_size, dim_name in zip(value.grid_mid.dims, value.grid_mid.labels):
dim_size_lookup[dim_name][dim_size] = f"{dim_name}_{grid_mid_base_name}"

@@ -188,6 +208,12 @@ def open_sdf_dataset(filename_or_obj, *, drop_variables=None):
data_attrs = {"units": value.units}
data_vars[key] = (var_coords, value.data, data_attrs)

elif isinstance(value, sdf.BlockPointVariable):
# Point (particle) variables are 1D
var_coords = (f"ID_{_grid_species_name(key)}",)
data_attrs = {"units": value.units}
data_vars[key] = (var_coords, value.data, data_attrs)

# TODO: might need to decode if mult is set?

# # see also conventions.decode_cf_variables
@@ -209,10 +235,15 @@ def open_dataset(
filename_or_obj,
*,
drop_variables=None,
keep_particles=False,
):
return open_sdf_dataset(filename_or_obj, drop_variables=drop_variables)
return open_sdf_dataset(
filename_or_obj,
drop_variables=drop_variables,
keep_particles=keep_particles,
)

open_dataset_parameters = ["filename_or_obj", "drop_variables"]
open_dataset_parameters = ["filename_or_obj", "drop_variables", "keep_particles"]

def guess_can_open(self, filename_or_obj):
magic_number = try_read_magic_number_from_path(filename_or_obj)
@@ -245,4 +276,11 @@ def __call__(self, ds: xr.Dataset) -> xr.Dataset:
f"Mismatching job ids (got {ds.attrs['jobid1']}, expected {self.job_id})"
)

return ds.expand_dims(time=[ds.attrs["time"]])
ds = ds.expand_dims(time=[ds.attrs["time"]])

# Particles' spatial coordinates also evolve in time
for coord, value in ds.coords.items():
if "Particles" in coord:
ds.coords[coord] = value.expand_dims(time=[ds.attrs["time"]])

return ds
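
To illustrate the naming scheme introduced above: point variables become 1D
over a per-species `ID_<species>` dimension, and their grid coordinates keep
the `<label>_Particles/<species>` form. A sketch, assuming a file like the
`0010.sdf` used in the tests below:

```python
import xarray as xr

ds = xr.open_dataset("0010.sdf", keep_particles=True)

# "Particles/Px/proton" is 1D over the particle-ID dimension "ID_proton",
# with particle positions attached as the coordinate "X_Particles/proton"
px = ds["Particles/Px/proton"]
print(px.dims)    # ("ID_proton",)
print(px.coords)  # includes "X_Particles/proton"
```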
32 changes: 29 additions & 3 deletions tests/test_basic.py
@@ -28,8 +28,23 @@ def test_coords():
assert df[x_coord].attrs["long_name"] == "Px"


def test_particles():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf", keep_particles=True) as df:
px_protons = "Particles/Px/proton"
assert px_protons in df
x_coord = "X_Particles/proton"
assert x_coord in df[px_protons].coords
assert df[x_coord].attrs["long_name"] == "X"


def test_no_particles():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf", keep_particles=False) as df:
px_protons = "Particles/Px/proton"
assert px_protons not in df


def test_multiple_files_one_time_dim():
df = open_mfdataset(EXAMPLE_FILES_DIR.glob("*.sdf"))
df = open_mfdataset(EXAMPLE_FILES_DIR.glob("*.sdf"), keep_particles=True)
ex_field = df["Electric Field/Ex"]
assert sorted(ex_field.coords) == sorted(("X_Grid_mid", "time"))
assert ex_field.shape == (11, 16)
@@ -38,20 +53,31 @@ def test_multiple_files_one_time_dim():
assert sorted(ez_field.coords) == sorted(("X_Grid_mid", "time"))
assert ez_field.shape == (11, 16)

px_protons = df["Particles/Px/proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
assert px_protons.shape == (11, 1920)

px_protons = df["Particles/Weight/proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
assert px_protons.shape == (11, 1920)


def test_multiple_files_multiple_time_dims():
df = open_mfdataset(EXAMPLE_FILES_DIR.glob("*.sdf"), separate_times=True)
df = open_mfdataset(
EXAMPLE_FILES_DIR.glob("*.sdf"), separate_times=True, keep_particles=True
)

assert list(df["Electric Field/Ex"].coords) != list(df["Electric Field/Ez"].coords)
assert df["Electric Field/Ex"].shape == (11, 16)
assert df["Electric Field/Ez"].shape == (1, 16)
assert df["Particles/Px/proton"].shape == (1, 1920)
assert df["Particles/Weight/proton"].shape == (2, 1920)


def test_erroring_on_mismatched_jobid_files():
with pytest.raises(ValueError):
xr.open_mfdataset(
EXAMPLE_MISMATCHED_FILES_DIR.glob("*.sdf"),
concat_dim="time",
combine="nested",
data_vars="minimal",
coords="minimal",
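
For completeness, the preprocessing path exercised by that last test can be
spelled out directly. A sketch mirroring the test's argument set (the
`*.sdf` glob and the file contents are assumptions):

```python
import xarray as xr

from sdf_xarray import SDFPreprocess

# SDFPreprocess checks that all files share a job ID, then expands each
# dataset (including per-species particle coordinates) along time
ds = xr.open_mfdataset(
    "*.sdf",
    concat_dim="time",
    combine="nested",
    data_vars="minimal",
    coords="minimal",
    preprocess=SDFPreprocess(),
)
```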
