Skip to content

Commit

Permalink
Make it easier to run torax in a loop programatically
Browse files Browse the repository at this point in the history
Exposes useful fields so they can be read etc

PiperOrigin-RevId: 694502985
  • Loading branch information
tamaranorman authored and Torax team committed Nov 8, 2024
1 parent acd36e5 commit 1c4eadd
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 5 deletions.
6 changes: 6 additions & 0 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ def __init__(
self._stepper = stepper
self._construct_providers()

@property
def runtime_params_provider(
self,
) -> general_runtime_params_lib.GeneralRuntimeParamsProvider:
return self._runtime_params_provider

def _construct_providers(self):
self._runtime_params_provider = (
self._runtime_params.make_provider(
Expand Down
16 changes: 11 additions & 5 deletions torax/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def __init__(
"""
xs, ys = value
self._is_bool_param = is_bool_param
self._interpolation_mode = interpolation_mode
match interpolation_mode:
case InterpolationMode.PIECEWISE_LINEAR:
self._param = PiecewiseLinearInterpolatedParam(xs=xs, ys=ys)
Expand All @@ -319,6 +320,16 @@ def __init__(
case _:
raise ValueError('Unknown interpolation mode.')

@property
def is_bool_param(self) -> bool:
"""Returns whether this param represents a bool."""
return self._is_bool_param

@property
def interpolation_mode(self) -> InterpolationMode:
"""Returns the interpolation mode used by this param."""
return self._interpolation_mode

def get_value(
self,
x: chex.Numeric,
Expand All @@ -334,11 +345,6 @@ def param(self) -> InterpolatedParamBase:
"""Returns the JAX-friendly interpolated param used under the hood."""
return self._param

@property
def is_bool_param(self) -> bool:
"""Returns whether this param represents a bool."""
return self._is_bool_param


class InterpolatedVarTimeRho(InterpolatedParamBase):
"""Interpolates on a grid (time, rho).
Expand Down
21 changes: 21 additions & 0 deletions torax/tests/interpolated_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,27 @@ def test_interpolated_param_get_value_is_jittable(

jax.jit(interpolated_var.get_value)(x=0.5)

@parameterized.product(
is_bool=[True, False],
interpolation_mode=[
interpolated_param.InterpolationMode.PIECEWISE_LINEAR,
interpolated_param.InterpolationMode.STEP,
],
)
def test_interpolated_var_properties(
self,
is_bool: bool,
interpolation_mode: interpolated_param.InterpolationMode,
):
"""Check the properties of the interpolated var are set correctly."""
var = interpolated_param.InterpolatedVarSingleAxis(
value=(np.array([0.0, 1.0]), np.array([0.0, 1.0])),
is_bool_param=is_bool,
interpolation_mode=interpolation_mode,
)
self.assertEqual(var.is_bool_param, is_bool)
self.assertEqual(var.interpolation_mode, interpolation_mode)


if __name__ == '__main__':
absltest.main()

0 comments on commit 1c4eadd

Please sign in to comment.