Skip to content

Commit

Permalink
Add scalars as variables instead of attributes
Browse files Browse the repository at this point in the history
Fixes #17
  • Loading branch information
ZedThree committed Oct 17, 2024
1 parent 96b965d commit 962a59e
Show file tree
Hide file tree
Showing 16 changed files with 38 additions and 7 deletions.
9 changes: 6 additions & 3 deletions src/sdf_xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,12 @@ def _grid_species_name(grid_name: str) -> str:
continue

if isinstance(value, Constant) or value.grid is None:
# No grid, so not physical data, just stick it in as an attribute
# This might have consequences when reading in multiple files?
attrs[key] = value.data
# No grid, so just a scalar
data_attrs = {}
if value.units is not None:
data_attrs["units"] = value.units

data_vars[key] = Variable((), value.data, data_attrs)
continue

if value.is_point_data:
Expand Down
19 changes: 17 additions & 2 deletions src/sdf_xarray/sdf_interface.pyx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cimport csdf

import dataclasses
import re
import time

from libc.string cimport memcpy
Expand Down Expand Up @@ -103,11 +104,14 @@ cdef class Mesh(Block):
)


_CONSTANT_UNITS_RE = re.compile(r"(?P<name>.*) \((?P<units>.*)\)$")

@dataclasses.dataclass
cdef class Constant:
_id: str
name: str
data: int | str | float
units: str | None

@staticmethod
cdef Constant from_block(str name, csdf.sdf_block_t* block):
Expand All @@ -122,8 +126,16 @@ cdef class Constant:
if block.datatype == csdf.SDF_DATATYPE_INTEGER8:
data = (<csdf.int64_t*>block.const_value)[0]

# There's no metadata with e.g. units, but there's a
# convention to put one in brackets at the end of the name,
# if so, strip it off to give the name and units
units = None
if match := _CONSTANT_UNITS_RE.match(name):
name = match["name"]
units = match["units"]

return Constant(
_id=block.id.decode("UTF-8"), name=name, data=data
_id=block.id.decode("UTF-8"), name=name, data=data, units=units
)

@property
Expand Down Expand Up @@ -203,7 +215,10 @@ cdef class SDFFile:
}

elif block.blocktype == csdf.SDF_BLOCKTYPE_CONSTANT:
self.variables[name] = Constant.from_block(name, block)
# We modify the name to remove units, so convert it
# first so we can get the new name
constant = Constant.from_block(name, block)
self.variables[constant.name] = constant

elif block.blocktype in (
csdf.SDF_BLOCKTYPE_PLAIN_MESH,
Expand Down
Binary file modified tests/example_files/0000.sdf
Binary file not shown.
Binary file modified tests/example_files/0001.sdf
Binary file not shown.
Binary file modified tests/example_files/0002.sdf
Binary file not shown.
Binary file modified tests/example_files/0003.sdf
Binary file not shown.
Binary file modified tests/example_files/0004.sdf
Binary file not shown.
Binary file modified tests/example_files/0005.sdf
Binary file not shown.
Binary file modified tests/example_files/0006.sdf
Binary file not shown.
Binary file modified tests/example_files/0007.sdf
Binary file not shown.
Binary file modified tests/example_files/0008.sdf
Binary file not shown.
Binary file modified tests/example_files/0009.sdf
Binary file not shown.
Binary file modified tests/example_files/0010.sdf
Binary file not shown.
1 change: 1 addition & 0 deletions tests/example_files/input.deck
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ begin:output

# Extended IO
distribution_functions = always
absorption = always
end:output


Expand Down
12 changes: 12 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def test_basic():
assert x_coord not in df.coords


def test_constant_name_and_units():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0000.sdf") as df:
name = "Absorption/Total Laser Energy Injected"
assert name in df
assert df[name].units == "J"


def test_coords():
with xr.open_dataset(EXAMPLE_FILES_DIR / "0010.sdf") as df:
px_electron = "dist_fn/x_px/electron"
Expand Down Expand Up @@ -66,6 +73,10 @@ def test_multiple_files_one_time_dim():
assert sorted(px_protons.coords) == sorted(("X_Particles/proton", "time"))
assert px_protons.shape == (11, 1920)

absorption = df["Absorption/Total Laser Energy Injected"]
assert tuple(absorption.coords) == ("time",)
assert absorption.shape == (11,)


def test_multiple_files_multiple_time_dims():
df = open_mfdataset(
Expand All @@ -77,6 +88,7 @@ def test_multiple_files_multiple_time_dims():
assert df["Electric Field/Ez"].shape == (1, 16)
assert df["Particles/Px/proton"].shape == (1, 1920)
assert df["Particles/Weight/proton"].shape == (2, 1920)
assert df["Absorption/Total Laser Energy Injected"].shape == (11,)


def test_erroring_on_mismatched_jobid_files():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_sdffile():

assert f.run_info["version"] == "4.19.3"

assert f.variables["Wall-time"].data == 0.0032005560000000002
assert f.variables["Wall-time"].data == 0.014211990000000008


def test_sdffile_with_more_things():
Expand All @@ -29,7 +29,7 @@ def test_sdffile_with_more_things():

assert f.run_info["version"] == "4.19.3"

assert f.variables["Wall-time"].data == 3.968111756
assert f.variables["Wall-time"].data == 4.068961859


def test_variable_names():
Expand Down

0 comments on commit 962a59e

Please sign in to comment.