Skip to content

Commit

Permalink
Merge pull request #24 from PlasmaFAIR/migrate-to-underscores
Browse files Browse the repository at this point in the history
Migrate from forward slashes, spaces and dashes to underscores
  • Loading branch information
JoelLucaAdams authored Oct 25, 2024
2 parents 2cca296 + 3b73f73 commit c898267
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 42 deletions.
14 changes: 10 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
read SDF files as created by the [EPOCH](https://epochpic.github.io)
plasma PIC code.

> [!IMPORTANT]
> All variable names now use snake_case to align with Epoch’s `sdf_helper`
> conventions. For example, `Electric Field/Ex` has been updated to
> `Electric_Field_Ex`.

## Installation

Until this is on PyPI, please install directly from this repo:
Expand Down Expand Up @@ -33,14 +38,15 @@ from sdf_xarray import SDFPreprocess

df = xr.open_dataset("0010.sdf")

print(df["Electric Field/Ex"])
print(df["Electric_Field_Ex"])

# <xarray.DataArray 'Electric Field/Ex' (X_x_px_deltaf/electron_beam: 16)> Size: 128B
# <xarray.DataArray 'Electric_Field_Ex' (X_x_px_deltaf_electron_beam: 16)> Size: 128B
# [16 values with dtype=float64]
# Coordinates:
# * X_x_px_deltaf/electron_beam (X_x_px_deltaf/electron_beam) float64 128B 1...
# * X_x_px_deltaf_electron_beam (X_x_px_deltaf_electron_beam) float64 128B 1...
# Attributes:
# units: V/m
# full_name: "Electric Field/Ex"
```

### Multi file loading
Expand All @@ -57,7 +63,7 @@ ds = xr.open_mfdataset(
print(ds)

# Dimensions:
# time: 301, X_Grid_mid: 128, Y_Grid_mid: 128, Px_px_py/Photon: 200, Py_px_py/Photon: 200, X_Grid: 129, Y_Grid: 129, Px_px_py/Photon_mid: 199, Py_px_py/Photon_mid: 199
# time: 301, X_Grid_mid: 128, Y_Grid_mid: 128, Px_px_py_Photon: 200, Py_px_py_Photon: 200, X_Grid: 129, Y_Grid: 129, Px_px_py_Photon_mid: 199, Py_px_py_Photon_mid: 199
# Coordinates: (9)
# Data variables: (18)
# Indexes: (9)
Expand Down
44 changes: 33 additions & 11 deletions src/sdf_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from .sdf_interface import Constant, SDFFile


def _rename_with_underscore(name: str) -> str:
"""A lot of the variable names have spaces, forward slashes and dashes in them, which
are not valid in netCDF names so we replace them with underscores."""
return name.replace("/", "_").replace(" ", "_").replace("-", "_")


def combine_datasets(path_glob: Iterable | str, **kwargs) -> xr.Dataset:
"""Combine all datasets using a single time dimension"""

Expand Down Expand Up @@ -86,7 +92,7 @@ def open_mfdataset(
if df.coords[coord].attrs.get("point_data", False):
# We need to undo our renaming of the coordinates
base_name = coord.split("_", maxsplit=1)[-1]
sdf_coord_name = f"Grid/{base_name}"
sdf_coord_name = f"Grid_{base_name}"
df.coords[coord] = df.coords[coord].expand_dims(
dim={var_times_map[sdf_coord_name]: [df.attrs["time"]]}
)
Expand All @@ -106,9 +112,11 @@ def make_time_dims(path_glob):
for f in path_glob:
with SDFFile(str(f)) as sdf_file:
for key in sdf_file.variables:
vars_count[key].append(sdf_file.header["time"])
vars_count[_rename_with_underscore(key)].append(sdf_file.header["time"])
for grid in sdf_file.grids.values():
vars_count[grid.name].append(sdf_file.header["time"])
vars_count[_rename_with_underscore(grid.name)].append(
sdf_file.header["time"]
)

# Count the unique set of lists of times
times_count = Counter((tuple(v) for v in vars_count.values()))
Expand Down Expand Up @@ -236,6 +244,12 @@ def _norm_grid_name(grid_name: str) -> str:
def _grid_species_name(grid_name: str) -> str:
return grid_name.split("/")[-1]

def _process_grid_name(grid_name: str, transform_func) -> str:
    """Run *grid_name* through *transform_func*, then replace the
    spaces, forward slashes and dashes in the result with underscores
    so it is a valid netCDF name."""
    return _rename_with_underscore(transform_func(grid_name))

for key, value in self.ds.grids.items():
if "cpu" in key.lower():
# Had some problems with these variables, so just ignore them for now
Expand All @@ -244,12 +258,12 @@ def _grid_species_name(grid_name: str) -> str:
if not self.keep_particles and value.is_point_data:
continue

base_name = _norm_grid_name(value.name)
base_name = _process_grid_name(value.name, _norm_grid_name)

for label, coord, unit in zip(value.labels, value.data, value.units):
full_name = f"{label}_{base_name}"
dim_name = (
f"ID_{_grid_species_name(key)}"
f"ID_{_process_grid_name(key, _grid_species_name)}"
if value.is_point_data
else full_name
)
Expand All @@ -260,6 +274,7 @@ def _grid_species_name(grid_name: str) -> str:
"long_name": label,
"units": unit,
"point_data": value.is_point_data,
"full_name": value.name,
},
)

Expand All @@ -276,6 +291,7 @@ def _grid_species_name(grid_name: str) -> str:

if isinstance(value, Constant) or value.grid is None:
data_attrs = {}
data_attrs["full_name"] = key
if value.units is not None:
data_attrs["units"] = value.units

Expand All @@ -285,13 +301,14 @@ def _grid_species_name(grid_name: str) -> str:
# some (hopefully) unique dimension names
shape = getattr(value.data, "shape", ())
dims = [f"dim_{key}_{n}" for n, _ in enumerate(shape)]
base_name = _rename_with_underscore(key)

data_vars[key] = Variable(dims, value.data, attrs=data_attrs)
data_vars[base_name] = Variable(dims, value.data, attrs=data_attrs)
continue

if value.is_point_data:
# Point (particle) variables are 1D
var_coords = (f"ID_{_grid_species_name(key)}",)
var_coords = (f"ID_{_process_grid_name(key, _grid_species_name)}",)
else:
# These are DataArrays

Expand All @@ -307,12 +324,12 @@ def _grid_species_name(grid_name: str) -> str:
# for the corresponding coordinate
dim_size_lookup = defaultdict(dict)
grid = self.ds.grids[value.grid]
grid_base_name = _norm_grid_name(grid.name)
grid_base_name = _process_grid_name(grid.name, _norm_grid_name)
for dim_size, dim_name in zip(grid.shape, grid.labels):
dim_size_lookup[dim_name][dim_size] = f"{dim_name}_{grid_base_name}"

grid_mid = self.ds.grids[value.grid_mid]
grid_mid_base_name = _norm_grid_name(grid_mid.name)
grid_mid_base_name = _process_grid_name(grid_mid.name, _norm_grid_name)
for dim_size, dim_name in zip(grid_mid.shape, grid_mid.labels):
dim_size_lookup[dim_name][
dim_size
Expand All @@ -324,9 +341,14 @@ def _grid_species_name(grid_name: str) -> str:
]

# TODO: error handling here? other attributes?
data_attrs = {"units": value.units, "point_data": value.is_point_data}
data_attrs = {
"units": value.units,
"point_data": value.is_point_data,
"full_name": key,
}
lazy_data = indexing.LazilyIndexedArray(SDFBackendArray(key, self))
data_vars[key] = Variable(var_coords, lazy_data, data_attrs)
base_name = _rename_with_underscore(key)
data_vars[base_name] = Variable(var_coords, lazy_data, data_attrs)

# TODO: might need to decode if mult is set?

Expand Down
58 changes: 31 additions & 27 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pathlib

import xarray as xr
from sdf_xarray import open_mfdataset, SDFPreprocess
import pytest
import xarray as xr

from sdf_xarray import SDFPreprocess, open_mfdataset

EXAMPLE_FILES_DIR = pathlib.Path(__file__).parent / "example_files"
EXAMPLE_MISMATCHED_FILES_DIR = (
Expand All @@ -13,68 +14,71 @@

def test_basic():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0000.sdf") as df:
ex_field = "Electric Field/Ex"
ex_field = "Electric_Field_Ex"
assert ex_field in df
x_coord = "X_Grid_mid"
assert x_coord in df[ex_field].coords
assert df[x_coord].attrs["long_name"] == "X"

px_protons = "Particles/Px/proton"
px_protons = "Particles_Px_proton"
assert px_protons not in df
x_coord = "X_Particles/proton"
x_coord = "X_Particles_proton"
assert x_coord not in df.coords


def test_constant_name_and_units():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0000.sdf") as df:
name = "Absorption/Total Laser Energy Injected"
name = "Absorption_Total_Laser_Energy_Injected"
full_name = "Absorption/Total Laser Energy Injected"
assert name in df
assert df[name].units == "J"
assert df[name].attrs["full_name"] == full_name


def test_coords():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf") as df:
px_electron = "dist_fn/x_px/electron"
px_electron = "dist_fn_x_px_electron"
assert px_electron in df
x_coord = "Px_x_px/electron"
print(df[px_electron].coords)
x_coord = "Px_x_px_electron"
assert x_coord in df[px_electron].coords
assert df[x_coord].attrs["long_name"] == "Px"
assert df[x_coord].attrs["full_name"] == "Grid/x_px/electron"


def test_particles():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf", keep_particles=True) as df:
px_protons = "Particles/Px/proton"
px_protons = "Particles_Px_proton"
assert px_protons in df
x_coord = "X_Particles/proton"
x_coord = "X_Particles_proton"
assert x_coord in df[px_protons].coords
assert df[x_coord].attrs["long_name"] == "X"


def test_no_particles():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf", keep_particles=False) as df:
px_protons = "Particles/Px/proton"
px_protons = "Particles_Px_proton"
assert px_protons not in df


def test_multiple_files_one_time_dim():
df = open_mfdataset(EXAMPLE_FILES_DIR.glob("*.sdf"), keep_particles=True)
ex_field = df["Electric Field/Ex"]
ex_field = df["Electric_Field_Ex"]
assert sorted(ex_field.coords) == sorted(("X_Grid_mid", "time"))
assert ex_field.shape == (11, 16)

ez_field = df["Electric Field/Ez"]
ez_field = df["Electric_Field_Ez"]
assert sorted(ez_field.coords) == sorted(("X_Grid_mid", "time"))
assert ez_field.shape == (11, 16)

px_protons = df["Particles/Px/proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
px_protons = df["Particles_Px_proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles_proton", "time"))
assert px_protons.shape == (11, 1920)

px_protons = df["Particles/Weight/proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
px_protons = df["Particles_Weight_proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles_proton", "time"))
assert px_protons.shape == (11, 1920)

absorption = df["Absorption/Total Laser Energy Injected"]
absorption = df["Absorption_Total_Laser_Energy_Injected"]
assert tuple(absorption.coords) == ("time",)
assert absorption.shape == (11,)

Expand All @@ -84,12 +88,12 @@ def test_multiple_files_multiple_time_dims():
EXAMPLE_FILES_DIR.glob("*.sdf"), separate_times=True, keep_particles=True
)

assert list(df["Electric Field/Ex"].coords) != list(df["Electric Field/Ez"].coords)
assert df["Electric Field/Ex"].shape == (11, 16)
assert df["Electric Field/Ez"].shape == (1, 16)
assert df["Particles/Px/proton"].shape == (1, 1920)
assert df["Particles/Weight/proton"].shape == (2, 1920)
assert df["Absorption/Total Laser Energy Injected"].shape == (11,)
assert list(df["Electric_Field_Ex"].coords) != list(df["Electric_Field_Ez"].coords)
assert df["Electric_Field_Ex"].shape == (11, 16)
assert df["Electric_Field_Ez"].shape == (1, 16)
assert df["Particles_Px_proton"].shape == (1, 1920)
assert df["Particles_Weight_proton"].shape == (2, 1920)
assert df["Absorption_Total_Laser_Energy_Injected"].shape == (11,)


def test_erroring_on_mismatched_jobid_files():
Expand All @@ -110,7 +114,7 @@ def test_arrays_with_no_grids():
assert laser_phase in df
assert df[laser_phase].shape == (1,)

random_states = "Random States"
random_states = "Random_States"
assert random_states in df
assert df[random_states].shape == (8,)

Expand All @@ -121,6 +125,6 @@ def test_arrays_with_no_grids_multifile():
assert laser_phase in df
assert df[laser_phase].shape == (2, 1)

random_states = "Random States"
random_states = "Random_States"
assert random_states in df
assert df[random_states].shape == (2, 8)

0 comments on commit c898267

Please sign in to comment.