Skip to content

Commit

Permalink
Remove build_circular_geometry defaults. There should be a single s…
Browse files Browse the repository at this point in the history
…ource of truth for these defaults, which is the `CircularConfig` Pydantic model.

Removing defaults breaks many tests, so a general switch is made to using the Pydantic model to build circular geometries instead of `build_circular_geometry`.

Includes some fixes in tests for Rmax < Rmin which failed pydantic validation.

PiperOrigin-RevId: 731272969
  • Loading branch information
sbodenstein authored and Torax team committed Feb 27, 2025
1 parent a8dda40 commit e2e7bd4
Show file tree
Hide file tree
Showing 47 changed files with 253 additions and 213 deletions.
4 changes: 2 additions & 2 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax.config import build_sim
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params_lib
from torax.stepper import linear_theta_method
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_general_runtime_params_with_time_dependent_args(self):
self.assertEqual(runtime_params.plasma_composition.main_ion, 'D')
self.assertEqual(runtime_params.profile_conditions.ne_is_fGW, False)
self.assertEqual(runtime_params.output_dir, '/tmp/this/is/a/test')
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand Down
6 changes: 3 additions & 3 deletions torax/config/tests/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from absl.testing import parameterized
from torax import interpolated_param
from torax.config import numerics
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model


class NumericsTest(parameterized.TestCase):
"""Unit tests for the `torax.config.numerics` module."""

def test_numerics_make_provider(self):
nums = numerics.Numerics()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = nums.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -35,7 +35,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
nums = numerics.Numerics()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = nums.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
10 changes: 5 additions & 5 deletions torax/config/tests/plasma_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import charge_states
from torax import interpolated_param
from torax.config import plasma_composition
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model


class PlasmaCompositionTest(parameterized.TestCase):
Expand All @@ -29,7 +29,7 @@ class PlasmaCompositionTest(parameterized.TestCase):
def test_plasma_composition_make_provider(self):
"""Checks provider construction with no issues."""
pc = plasma_composition.PlasmaComposition()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -40,7 +40,7 @@ def test_plasma_composition_make_provider(self):
)
def test_zeff_accepts_float_inputs(self, zeff: float):
"""Tests that zeff accepts a single float input."""
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
pc = plasma_composition.PlasmaComposition(Zeff=zeff)
provider = pc.make_provider(geo.torax_mesh)
dynamic_pc = provider.build_dynamic_params(t=0.0)
Expand All @@ -63,7 +63,7 @@ def test_zeff_and_zeff_face_match_expected(self):
1.0: {0.0: 1.8, 0.5: 2.1, 1.0: 2.4},
}

geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
pc = plasma_composition.PlasmaComposition(Zeff=zeff_profile)
provider = pc.make_provider(geo.torax_mesh)

Expand Down Expand Up @@ -102,7 +102,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
pc = plasma_composition.PlasmaComposition()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
14 changes: 7 additions & 7 deletions torax/config/tests/profile_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax import interpolated_param
from torax.config import config_args
from torax.config import profile_conditions
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model
import xarray as xr


Expand All @@ -30,7 +30,7 @@ class ProfileConditionsTest(parameterized.TestCase):

def test_profile_conditions_make_provider(self):
pc = profile_conditions.ProfileConditions()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -46,7 +46,7 @@ def test_profile_conditions_sets_Te_bound_right_correctly(
Te={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
Te_bound_right=Te_bound_right,
)
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.Te_bound_right, expected_initial_value)
Expand All @@ -65,7 +65,7 @@ def test_profile_conditions_sets_Ti_bound_right_correctly(
Ti={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
Ti_bound_right=Ti_bound_right,
)
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.Ti_bound_right, expected_initial_value)
Expand All @@ -84,7 +84,7 @@ def test_profile_conditions_sets_ne_bound_right_correctly(
ne={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}},
ne_bound_right=ne_bound_right,
)
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
dcs = provider.build_dynamic_params(t=0.0)
self.assertEqual(dcs.ne_bound_right, expected_initial_value)
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_profile_conditions_sets_psi_correctly(
self, psi, expected_initial_value, expected_second_value
):
"""Tests that psi is set correctly."""
geo = circular_geometry.build_circular_geometry(n_rho=4)
geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()
pc = profile_conditions.ProfileConditions(
psi=psi,
)
Expand All @@ -147,7 +147,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
pc = profile_conditions.ProfileConditions()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = pc.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
6 changes: 4 additions & 2 deletions torax/config/tests/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torax.config import config_args
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model


# pylint: disable=invalid-name
Expand Down Expand Up @@ -137,7 +137,9 @@ def test_runtime_params_make_provider(self):
runtime_params = general_runtime_params.GeneralRuntimeParams(
profile_conditions=profile_conditions_lib.ProfileConditions()
)
torax_mesh = circular_geometry.build_circular_geometry().torax_mesh
torax_mesh = (
geometry_pydantic_model.CircularConfig().build_geometry().torax_mesh
)
runtime_params_provider = runtime_params.make_provider(torax_mesh)
runtime_params_provider.build_dynamic_params(0.0)

Expand Down
14 changes: 7 additions & 7 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torax.config import profile_conditions as profile_conditions_lib
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import set_tped_nped
from torax.sources import electron_density_sources
from torax.sources import generic_current_source
Expand All @@ -37,7 +37,7 @@ class RuntimeParamsSliceTest(parameterized.TestCase):

def setUp(self):
super().setUp()
self._geo = circular_geometry.build_circular_geometry()
self._geo = geometry_pydantic_model.CircularConfig().build_geometry()

def test_dynamic_slice_can_be_input_to_jitted_function(self):
"""Tests that the slice can be input to a jitted function."""
Expand Down Expand Up @@ -351,7 +351,7 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition(
runtime_params = general_runtime_params.GeneralRuntimeParams(
profile_conditions=profile_conditions,
)
geo = circular_geometry.build_circular_geometry(n_rho=4)
geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()
dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down Expand Up @@ -394,7 +394,7 @@ def test_profile_conditions_set_electron_density_and_boundary_condition(
ne_is_fGW=ne_is_fGW,
),
)
geo = circular_geometry.build_circular_geometry(n_rho=4)
geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()

dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
Expand Down Expand Up @@ -427,7 +427,7 @@ def test_update_dynamic_slice_provider_updates_runtime_params(
Ti_bound_right={0.0: 1.0, 1.0: 2.0},
),
)
geo = circular_geometry.build_circular_geometry(n_rho=4)
geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down Expand Up @@ -463,7 +463,7 @@ def test_update_dynamic_slice_provider_updates_sources(
source_models_builder.runtime_params[
generic_current_source.GenericCurrentSource.SOURCE_NAME
].Iext = 1.0
geo = circular_geometry.build_circular_geometry(n_rho=4)
geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
sources=source_models_builder.runtime_params,
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_update_dynamic_slice_provider_updates_transport(
"""Tests that the dynamic slice provider can be updated."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
transport = transport_params_lib.RuntimeParams(De_inner=1.0)
geo = circular_geometry.build_circular_geometry(n_rho=4)
geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry()
provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
Expand Down
6 changes: 4 additions & 2 deletions torax/fvm/tests/calc_coeffs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice as runtime_params_slice_lib
from torax.fvm import calc_coeffs
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_profile_builders
Expand Down Expand Up @@ -52,7 +52,9 @@ def test_calc_coeffs_smoke_test(
predictor_corrector=False,
theta_imp=theta_imp,
)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
geo = geometry_pydantic_model.CircularConfig(
n_rho=num_cells
).build_geometry()
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down
18 changes: 13 additions & 5 deletions torax/fvm/tests/fvm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from torax.fvm import cell_variable
from torax.fvm import implicit_solve_block
from torax.fvm import residual_and_loss
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_profile_builders
Expand Down Expand Up @@ -224,7 +224,9 @@ def test_nonlinear_solve_block_loss_minimum(
predictor_corrector=False,
theta_imp=theta_imp,
)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
geo = geometry_pydantic_model.CircularConfig(
n_rho=num_cells
).build_geometry()
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down Expand Up @@ -392,7 +394,9 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
pedestal_model_builder = (
set_tped_nped.SetTemperatureDensityPedestalModelBuilder()
)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
geo = geometry_pydantic_model.CircularConfig(
n_rho=num_cells
).build_geometry()
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand All @@ -413,7 +417,9 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
stepper=stepper_params,
)
)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
geo = geometry_pydantic_model.CircularConfig(
n_rho=num_cells
).build_geometry()
source_models = source_models_builder()
initial_core_profiles = core_profile_setters.initial_core_profiles(
static_runtime_params_slice,
Expand Down Expand Up @@ -515,7 +521,9 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
predictor_corrector=False,
theta_imp=0.0,
)
geo = circular_geometry.build_circular_geometry(n_rho=num_cells)
geo = geometry_pydantic_model.CircularConfig(
n_rho=num_cells
).build_geometry()
transport_model_builder = (
constant_transport_model.ConstantTransportModelBuilder(
runtime_params=constant_transport_model.RuntimeParams(
Expand Down
16 changes: 7 additions & 9 deletions torax/geometry/circular_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,18 @@
# external physics implementations
# pylint: disable=invalid-name
def build_circular_geometry(
n_rho: int = 25,
elongation_LCFS: float = 1.72,
Rmaj: float = 6.2,
Rmin: float = 2.0,
B0: float = 5.3,
hires_fac: int = 4,
n_rho: int,
elongation_LCFS: float,
Rmaj: float,
Rmin: float,
B0: float,
hires_fac: int,
) -> geometry.Geometry:
"""Constructs a circular Geometry instance used for testing only.
Args:
n_rho: Radial grid points (num cells)
elongation_LCFS: Elongation at last closed flux surface. Defaults to 1.72
for the ITER elongation, to approximately correct volume and area integral
Jacobians.
elongation_LCFS: Elongation at last closed flux surface.
Rmaj: major radius (R) in meters
Rmin: minor radius (a) in meters
B0: Toroidal magnetic field on axis [T]
Expand Down
15 changes: 10 additions & 5 deletions torax/geometry/tests/circular_geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ def test_build_geometry_provider_from_circular(self):
B0=5.3,
hires_fac=4,
)
provider = (
geometry_provider.TimeDependentGeometryProvider.create_provider(
{0.0: geo_0, 10.0: geo_1}
)
provider = geometry_provider.TimeDependentGeometryProvider.create_provider(
{0.0: geo_0, 10.0: geo_1}
)
geo = provider(5.0)
np.testing.assert_allclose(geo.Rmaj, 6.7)
Expand All @@ -55,7 +53,14 @@ def test_circular_geometry_can_be_input_to_jitted_function(self):
def foo(geo: geometry.Geometry):
return geo.Rmaj

geo = circular_geometry.build_circular_geometry()
geo = circular_geometry.build_circular_geometry(
n_rho=25,
elongation_LCFS=1.72,
Rmaj=6.2,
Rmin=2.0,
B0=5.3,
hires_fac=4,
)
# Make sure you can call the function with geo as an arg.
foo(geo)

Expand Down
Loading

0 comments on commit e2e7bd4

Please sign in to comment.