Skip to content

Commit

Permalink
Store if variable is particles or not in attributes
Browse files Browse the repository at this point in the history
Also reduce some duplication and use lazy array backend for particle
data too
  • Loading branch information
ZedThree committed Aug 23, 2024
1 parent 26f83cf commit 7f03e41
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
33 changes: 17 additions & 16 deletions src/sdf_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def open_mfdataset(
dim={var_times_map[str(da)]: [df.attrs["time"]]}
)
for coord in df.coords:
if "Particles" in coord:
if df.coords[coord].attrs.get("point_data", False):
# We need to undo our renaming of the coordinates
base_name = coord.split("_", maxsplit=1)[-1]
sdf_coord_name = f"Grid/{base_name}"
Expand Down Expand Up @@ -241,7 +241,7 @@ def _grid_species_name(grid_name: str) -> str:
# Had some problems with these variables, so just ignore them for now
continue

if not self.keep_particles and "particles" in value.name.lower():
if not self.keep_particles and value.is_point_data:
continue

base_name = _norm_grid_name(value.name)
Expand All @@ -256,7 +256,11 @@ def _grid_species_name(grid_name: str) -> str:
coords[full_name] = (
dim_name,
coord,
{"long_name": label, "units": unit},
{
"long_name": label,
"units": unit,
"point_data": value.is_point_data,
},
)

# Read and convert SDF variables and meshes to xarray DataArrays and Coordinates
Expand All @@ -268,19 +272,15 @@ def _grid_species_name(grid_name: str) -> str:
if not self.keep_particles and "particles" in key.lower():
continue

if isinstance(value, Constant):
# This might have consequences when reading in multiple files?
attrs[key] = value.data

elif value.grid is None:
if isinstance(value, Constant) or value.grid is None:
# No grid, so not physical data, just stick it in as an attribute
# This might have consequences when reading in multiple files?
attrs[key] = value.data
continue

elif value.is_point_data:
if value.is_point_data:
# Point (particle) variables are 1D
var_coords = (f"ID_{_grid_species_name(key)}",)
data_attrs = {"units": value.units}
data_vars[key] = Variable(var_coords, value.data, data_attrs)
else:
# These are DataArrays

Expand Down Expand Up @@ -311,10 +311,11 @@ def _grid_species_name(grid_name: str) -> str:
dim_size_lookup[dim_name][dim_size]
for dim_name, dim_size in zip(grid.labels, value.shape)
]
# TODO: error handling here? other attributes?
data_attrs = {"units": value.units}
lazy_data = indexing.LazilyIndexedArray(SDFBackendArray(key, self))
data_vars[key] = Variable(var_coords, lazy_data, data_attrs)

# TODO: error handling here? other attributes?
data_attrs = {"units": value.units, "point_data": value.is_point_data}
lazy_data = indexing.LazilyIndexedArray(SDFBackendArray(key, self))
data_vars[key] = Variable(var_coords, lazy_data, data_attrs)

# TODO: might need to decode if mult is set?

Expand Down Expand Up @@ -390,7 +391,7 @@ def __call__(self, ds: xr.Dataset) -> xr.Dataset:

# Particles' spatial coordinates also evolve in time
for coord, value in ds.coords.items():
if "Particles" in coord:
if value.attrs.get("point_data", False):
ds.coords[coord] = value.expand_dims(time=[ds.attrs["time"]])

return ds
5 changes: 5 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ def test_basic():
assert x_coord in df[ex_field].coords
assert df[x_coord].attrs["long_name"] == "X"

px_protons = "Particles/Px/proton"
assert px_protons not in df
x_coord = "X_Particles/proton"
assert x_coord not in df.coords


def test_coords():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf") as df:
Expand Down

0 comments on commit 7f03e41

Please sign in to comment.