Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelLucaAdams committed Oct 25, 2024
1 parent 7a2e952 commit 6704923
Showing 1 changed file with 31 additions and 27 deletions.
58 changes: 31 additions & 27 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pathlib

import xarray as xr
from sdf_xarray import open_mfdataset, SDFPreprocess
import pytest
import xarray as xr

from sdf_xarray import SDFPreprocess, open_mfdataset

EXAMPLE_FILES_DIR = pathlib.Path(__file__).parent / "example_files"
EXAMPLE_MISMATCHED_FILES_DIR = (
Expand All @@ -13,68 +14,71 @@

def test_basic():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0000.sdf") as df:
ex_field = "Electric Field/Ex"
ex_field = "Electric_Field_Ex"
assert ex_field in df
x_coord = "X_Grid_mid"
assert x_coord in df[ex_field].coords
assert df[x_coord].attrs["long_name"] == "X"

px_protons = "Particles/Px/proton"
px_protons = "Particles_Px_proton"
assert px_protons not in df
x_coord = "X_Particles/proton"
x_coord = "X_Particles_proton"
assert x_coord not in df.coords


def test_constant_name_and_units():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0000.sdf") as df:
name = "Absorption/Total Laser Energy Injected"
name = "Absorption_Total_Laser_Energy_Injected"
full_name = "Absorption/Total Laser Energy Injected"
assert name in df
assert df[name].units == "J"
assert df[name].attrs["full_name"] == full_name


def test_coords():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf") as df:
px_electron = "dist_fn/x_px/electron"
px_electron = "dist_fn_x_px_electron"
assert px_electron in df
x_coord = "Px_x_px/electron"
print(df[px_electron].coords)
x_coord = "Px_x_px_electron"
assert x_coord in df[px_electron].coords
assert df[x_coord].attrs["long_name"] == "Px"
assert df[x_coord].attrs["full_name"] == "Grid/x_px/electron"


def test_particles():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf", keep_particles=True) as df:
px_protons = "Particles/Px/proton"
px_protons = "Particles_Px_proton"
assert px_protons in df
x_coord = "X_Particles/proton"
x_coord = "X_Particles_proton"
assert x_coord in df[px_protons].coords
assert df[x_coord].attrs["long_name"] == "X"


def test_no_particles():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf", keep_particles=False) as df:
px_protons = "Particles/Px/proton"
px_protons = "Particles_Px_proton"
assert px_protons not in df


def test_multiple_files_one_time_dim():
df = open_mfdataset(EXAMPLE_FILES_DIR.glob("*.sdf"), keep_particles=True)
ex_field = df["Electric Field/Ex"]
ex_field = df["Electric_Field_Ex"]
assert sorted(ex_field.coords) == sorted(("X_Grid_mid", "time"))
assert ex_field.shape == (11, 16)

ez_field = df["Electric Field/Ez"]
ez_field = df["Electric_Field_Ez"]
assert sorted(ez_field.coords) == sorted(("X_Grid_mid", "time"))
assert ez_field.shape == (11, 16)

px_protons = df["Particles/Px/proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
px_protons = df["Particles_Px_proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles_proton", "time"))
assert px_protons.shape == (11, 1920)

px_protons = df["Particles/Weight/proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
px_protons = df["Particles_Weight_proton"]
assert sorted(px_protons.coords) == sorted(("X_Particles_proton", "time"))
assert px_protons.shape == (11, 1920)

absorption = df["Absorption/Total Laser Energy Injected"]
absorption = df["Absorption_Total_Laser_Energy_Injected"]
assert tuple(absorption.coords) == ("time",)
assert absorption.shape == (11,)

Expand All @@ -84,12 +88,12 @@ def test_multiple_files_multiple_time_dims():
EXAMPLE_FILES_DIR.glob("*.sdf"), separate_times=True, keep_particles=True
)

assert list(df["Electric Field/Ex"].coords) != list(df["Electric Field/Ez"].coords)
assert df["Electric Field/Ex"].shape == (11, 16)
assert df["Electric Field/Ez"].shape == (1, 16)
assert df["Particles/Px/proton"].shape == (1, 1920)
assert df["Particles/Weight/proton"].shape == (2, 1920)
assert df["Absorption/Total Laser Energy Injected"].shape == (11,)
assert list(df["Electric_Field_Ex"].coords) != list(df["Electric_Field_Ez"].coords)
assert df["Electric_Field_Ex"].shape == (11, 16)
assert df["Electric_Field_Ez"].shape == (1, 16)
assert df["Particles_Px_proton"].shape == (1, 1920)
assert df["Particles_Weight_proton"].shape == (2, 1920)
assert df["Absorption_Total_Laser_Energy_Injected"].shape == (11,)


def test_erroring_on_mismatched_jobid_files():
Expand All @@ -110,7 +114,7 @@ def test_arrays_with_no_grids():
assert laser_phase in df
assert df[laser_phase].shape == (1,)

random_states = "Random States"
random_states = "Random_States"
assert random_states in df
assert df[random_states].shape == (8,)

Expand All @@ -121,6 +125,6 @@ def test_arrays_with_no_grids_multifile():
assert laser_phase in df
assert df[laser_phase].shape == (2, 1)

random_states = "Random States"
random_states = "Random_States"
assert random_states in df
assert df[random_states].shape == (2, 8)

0 comments on commit 6704923

Please sign in to comment.