Skip to content

Commit

Permalink
Add rhonorm1_defined_in_timerhoinput method to TimeVaryingArray.
Browse files Browse the repository at this point in the history
This will replace the more complicated implementation in `torax/interpolated_params.py` in a future CL.

PiperOrigin-RevId: 725346293
  • Loading branch information
sbodenstein authored and Torax team committed Feb 26, 2025
1 parent 3b3845d commit e23df1a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
9 changes: 9 additions & 0 deletions torax/torax_pydantic/interpolated_param_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ class TimeVaryingArray(interpolated_param_common.TimeVaryingBase):
)
rho_norm_grid: model_base.NumpyArray | None = None

@functools.cached_property
def right_boundary_conditions_defined(self) -> bool:
"""Checks if the boundary condition at rho=1.0 is always defined."""

for rho_norm, _ in self.value.values():
if 1.0 not in rho_norm:
return False
return True

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(
Expand Down
22 changes: 22 additions & 0 deletions torax/torax_pydantic/tests/interpolated_param_2d_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,28 @@ def test_mutation_behavior(self):
out2 = interpolated.get_value(x=0.0)
self.assertEqual(out2.tolist(), [v2, v2, v2])

def test_right_boundary_conditions_defined(self):
"""Tests that right_boundary_conditions_defined works correctly."""

with self.subTest('float_input'):
# A single float is interpreted as defined at rho=0.
self.assertFalse(
interpolated_param_2d.TimeVaryingArray.model_validate(
1.0
).right_boundary_conditions_defined
)

with self.subTest('xarray'):
value = xr.DataArray(
data=np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
coords={'time': [0.0, 1.0], 'rho_norm': [0.25, 0.5, 1.0]},
)
self.assertTrue(
interpolated_param_2d.TimeVaryingArray.model_validate(
value
).right_boundary_conditions_defined
)


if __name__ == '__main__':
absltest.main()

0 comments on commit e23df1a

Please sign in to comment.