Skip to content

Commit

Permalink
Remove build_circular_geometry defaults. There should be a single s…
Browse files Browse the repository at this point in the history
…ource of truth for these defaults, which is the `CircularConfig` Pydantic model.

Removing defaults breaks many tests, so a general switch is made to using the Pydantic model to build circular geometries instead of `build_circular_geometry`.

Includes some fixes in tests for Rmax < Rmin which failed pydantic validation.

PiperOrigin-RevId: 731272969
  • Loading branch information
sbodenstein authored and Torax team committed Feb 26, 2025
1 parent 0aa9d33 commit c588818
Show file tree
Hide file tree
Showing 55 changed files with 884 additions and 534 deletions.
13 changes: 7 additions & 6 deletions run_simulation_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@
from torax.config import build_sim
from torax.config import config_loader
from torax.config import runtime_params
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.plotting import plotruns_lib


# String used when prompting the user to make a choice of command
CHOICE_PROMPT = 'Your choice: '
# String used when prompting the user to make a yes / no choice
Expand Down Expand Up @@ -246,9 +246,9 @@ def change_config(
new_runtime_params = build_sim.build_runtime_params_from_config(
sim_config['runtime_params']
)
new_geo_provider = build_sim.build_geometry_provider_from_config(
sim_config['geometry'],
)
new_geo_provider = geometry_pydantic_model.Geometry.from_dict(
sim_config['geometry']
).build_provider()
new_transport_model_builder = (
build_sim.build_transport_model_builder_from_config(
sim_config['transport']
Expand Down Expand Up @@ -299,7 +299,7 @@ def change_config(


def change_sim_obj(
config_module_str: str
config_module_str: str,
) -> tuple[sim_lib.Sim, runtime_params.GeneralRuntimeParams, str]:
"""Builds a new Sim from the config module.
Expand Down Expand Up @@ -554,7 +554,8 @@ def main(_):
try:
start_time = time.time()
sim_and_runtime_params_or_none = change_config(
sim, config_module_str)
sim, config_module_str
)
if sim_and_runtime_params_or_none is not None:
sim, new_runtime_params = sim_and_runtime_params_or_none
config_change_time = time.time() - start_time
Expand Down
152 changes: 5 additions & 147 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
from torax import sim as sim_lib
from torax.config import config_args
from torax.config import runtime_params as runtime_params_lib
from torax.geometry import circular_geometry
from torax.geometry import geometry_provider
from torax.geometry import standard_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model

from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.pedestal_model import set_pped_tpedratio_nped
from torax.pedestal_model import set_tped_nped
Expand Down Expand Up @@ -53,149 +52,6 @@
# pylint: disable=invalid-name


def _build_standard_geometry_provider(
geometry_type: str,
**kwargs,
) -> geometry_provider.GeometryProvider:
"""Constructs a geometry provider for a standard geometry."""
global_params = {'Ip_from_parameters', 'n_rho', 'geometry_dir'}
if geometry_type == 'chease':
intermediate_builder = (
standard_geometry.StandardGeometryIntermediates.from_chease
)
elif geometry_type == 'fbt':
# Check if parameters indicate a bundled FBT file and input validity.
if 'LY_bundle_object' in kwargs:
if 'geometry_configs' in kwargs:
raise ValueError(
"Cannot use 'geometry_configs' together with a bundled FBT file"
)
if 'LY_object' in kwargs:
raise ValueError(
"Cannot use 'LY_object' together with a bundled FBT file"
)
# Build and return the GeometryProvider for the bundled case.
intermediates = (
standard_geometry.StandardGeometryIntermediates.from_fbt_bundle(
**kwargs,
)
)
geometries = {
t: standard_geometry.build_standard_geometry(intermediates[t])
for t in intermediates
}
return standard_geometry.StandardGeometryProvider.create_provider(
geometries
)
else:
intermediate_builder = (
standard_geometry.StandardGeometryIntermediates.from_fbt_single_slice
)
elif geometry_type == 'eqdsk':
intermediate_builder = (
standard_geometry.StandardGeometryIntermediates.from_eqdsk
)
else:
raise ValueError(f'Unknown geometry type: {geometry_type}')
if 'geometry_configs' in kwargs:
# geometry config has sequence of standalone geometry files.
if not isinstance(kwargs['geometry_configs'], dict):
raise ValueError('geometry_configs must be a dict.')
geometries = {}
global_kwargs = {key: kwargs[key] for key in global_params if key in kwargs}
for time, config in kwargs['geometry_configs'].items():
if x := global_params.intersection(config):
raise ValueError(
'The following parameters cannot be set per geometry_config:'
f' {", ".join(x)}'
)
config.update(global_kwargs)
geometries[time] = standard_geometry.build_standard_geometry(
intermediate_builder(
**config,
)
)
return standard_geometry.StandardGeometryProvider.create_provider(
geometries
)
return geometry_provider.ConstantGeometryProvider(
standard_geometry.build_standard_geometry(
intermediate_builder(
**kwargs,
)
)
)


def _build_circular_geometry_provider(
**kwargs,
) -> geometry_provider.GeometryProvider:
"""Builds a `GeometryProvider` from the input config."""
if 'geometry_configs' in kwargs:
if not isinstance(kwargs['geometry_configs'], dict):
raise ValueError('geometry_configs must be a dict.')
if 'n_rho' not in kwargs:
raise ValueError('n_rho must be set in the input config.')
geometries = {}
for time, c in kwargs['geometry_configs'].items():
geometries[time] = circular_geometry.build_circular_geometry(
n_rho=kwargs['n_rho'], **c
)
return geometry_provider.TimeDependentGeometryProvider.create_provider(
geometries
)
return geometry_provider.ConstantGeometryProvider(
circular_geometry.build_circular_geometry(**kwargs)
)


def build_geometry_provider_from_config(
geometry_config: Mapping[str, Any],
) -> geometry_provider.GeometryProvider:
"""Builds a `Geometry` from the input config.
The input config has one required key: `geometry_type`. Its value must be one
of:
- "circular"
- "chease"
- "fbt"
Depending on the `geometry_type` given, there are different keys/values
expected in the rest of the config. See the following functions to get a full
list of the arguments exposed:
- `circular_geometry.build_circular_geometry()`
- `geometry.StandardGeometryIntermediates.from_chease()`
- `geometry.StandardGeometryIntermediates.from_fbt()`
For time dependent geometries, the input config should have a key
`geometry_configs` which maps times to a dict of geometry config args.
Args:
geometry_config: Python dictionary containing keys/values that map onto a
`geometry` module function that builds a `Geometry` object.
Returns:
A `GeometryProvider` based on the input config.
"""
if 'geometry_type' not in geometry_config:
raise ValueError('geometry_type must be set in the input config.')
# Do a shallow copy to keep references to the original objects while not
# modifying the original config dict with the pop-statement below.
kwargs = dict(geometry_config)
geometry_type = kwargs.pop('geometry_type').lower() # Remove from kwargs.
if geometry_type == 'circular':
return _build_circular_geometry_provider(**kwargs)
# elif geometry_type == 'chease' or geometry_type == 'fbt':
elif geometry_type in ['chease', 'fbt', 'eqdsk']:
return _build_standard_geometry_provider(
geometry_type=geometry_type, **kwargs
)

raise ValueError(f'Unknown geometry type: {geometry_type}')


def build_sim_from_config(
config: Mapping[str, Any],
) -> sim_lib.Sim:
Expand Down Expand Up @@ -275,7 +131,9 @@ def build_sim_from_config(
' for more info.'
)
runtime_params = build_runtime_params_from_config(config['runtime_params'])
geo_provider = build_geometry_provider_from_config(config['geometry'])
geo_provider = geometry_pydantic_model.Geometry.from_dict(
config['geometry']
).build_provider()

if 'restart' in config:
file_restart = runtime_params_lib.FileRestart(**config['restart'])
Expand Down
113 changes: 3 additions & 110 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,16 @@
from torax.config import build_sim
from torax.config import runtime_params as runtime_params_lib
from torax.config import runtime_params_slice
from torax.geometry import circular_geometry
from torax.geometry import geometry_provider
from torax.geometry import standard_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import set_tped_nped
from torax.sources import runtime_params as source_runtime_params_lib
from torax.stepper import linear_theta_method
from torax.stepper import nonlinear_theta_method
from torax.stepper import runtime_params as stepper_params
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import fixed_time_step_calculator
from torax.transport_model import constant as constant_transport
from torax.transport_model import critical_gradient as critical_gradient_transport
from torax.transport_model import qlknn_transport_model
from torax.transport_model import runtime_params as transport_model_params


class BuildSimTest(parameterized.TestCase):
Expand Down Expand Up @@ -176,7 +172,7 @@ def test_general_runtime_params_with_time_dependent_args(self):
self.assertEqual(runtime_params.plasma_composition.main_ion, 'D')
self.assertEqual(runtime_params.profile_conditions.ne_is_fGW, False)
self.assertEqual(runtime_params.output_dir, '/tmp/this/is/a/test')
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand All @@ -195,110 +191,6 @@ def test_general_runtime_params_with_time_dependent_args(self):
dynamic_runtime_params_slice.numerics.resistivity_mult, 0.6
)

def test_missing_geometry_type_raises_error(self):
with self.assertRaises(ValueError):
build_sim.build_geometry_provider_from_config({})

def test_build_circular_geometry(self):
geo_provider = build_sim.build_geometry_provider_from_config({
'geometry_type': 'circular',
'n_rho': 5, # override a default.
})
self.assertIsInstance(
geo_provider, geometry_provider.ConstantGeometryProvider
)
geo = geo_provider(t=0)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5)
np.testing.assert_array_equal(geo.B0, 5.3) # test a default.

def test_build_geometry_from_chease(self):
geo_provider = build_sim.build_geometry_provider_from_config(
{
'geometry_type': 'chease',
'n_rho': 5, # override a default.
},
)
self.assertIsInstance(
geo_provider, geometry_provider.ConstantGeometryProvider
)
self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5)

def test_build_time_dependent_geometry_from_chease(self):
"""Tests correctness of config constraints with time-dependent geometry."""

base_config = {
'geometry_type': 'chease',
'Ip_from_parameters': True,
'n_rho': 10, # overrides the default
'geometry_configs': {
0.0: {
'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols',
'Rmaj': 6.2,
'Rmin': 2.0,
'B0': 5.3,
},
1.0: {
'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols',
'Rmaj': 6.2,
'Rmin': 2.0,
'B0': 5.3,
},
},
}

# Test valid config
geo_provider = build_sim.build_geometry_provider_from_config(base_config)
self.assertIsInstance(
geo_provider, standard_geometry.StandardGeometryProvider
)
self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 10)

# Test invalid configs:
for param, value in zip(
['n_rho', 'Ip_from_parameters', 'geometry_dir'], [5, True, '.']
):
for time_key in [0.0, 1.0]:
invalid_config = base_config.copy()
invalid_config['geometry_configs'][time_key][param] = value
with self.assertRaises(ValueError):
build_sim.build_geometry_provider_from_config(invalid_config)

# pylint: disable=invalid-name
def test_chease_geometry_updates_Ip(self):
"""Tests that the Ip is updated when using chease geometry."""
runtime_params = runtime_params_lib.GeneralRuntimeParams()
original_Ip_tot = runtime_params.profile_conditions.Ip_tot
geo_provider = build_sim.build_geometry_provider_from_config({
'geometry_type': 'chease',
'Ip_from_parameters': (
False
), # this will force update runtime_params.Ip_tot
})
runtime_params_provider = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=transport_model_params.RuntimeParams(),
sources={},
stepper=stepper_params.RuntimeParams(),
torax_mesh=geo_provider.torax_mesh,
)
)
dynamic_slice, geo = (
runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry(
t=0,
dynamic_runtime_params_slice_provider=runtime_params_provider,
geometry_provider=geo_provider,
)
)
self.assertIsInstance(geo, standard_geometry.StandardGeometry)
self.assertIsNotNone(dynamic_slice)
self.assertNotEqual(
dynamic_slice.profile_conditions.Ip_tot, original_Ip_tot
)
# pylint: enable=invalid-name

def test_empty_source_config_only_has_defaults_turned_off(self):
"""Tests that an empty source config has all sources turned off."""
source_models_builder = build_sim.build_sources_builder_from_config({})
Expand Down Expand Up @@ -504,5 +396,6 @@ def test_build_time_step_calculator_from_config(
)
self.assertIsInstance(time_stepper, expected_type)


if __name__ == '__main__':
absltest.main()
6 changes: 3 additions & 3 deletions torax/config/tests/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from absl.testing import parameterized
from torax import interpolated_param
from torax.config import numerics
from torax.geometry import circular_geometry
from torax.geometry import pydantic_model as geometry_pydantic_model


class NumericsTest(parameterized.TestCase):
"""Unit tests for the `torax.config.numerics` module."""

def test_numerics_make_provider(self):
nums = numerics.Numerics()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = nums.make_provider(geo.torax_mesh)
provider.build_dynamic_params(t=0.0)

Expand All @@ -35,7 +35,7 @@ def test_interpolated_vars_are_only_constructed_once(
):
"""Tests that interpolated vars are only constructed once."""
nums = numerics.Numerics()
geo = circular_geometry.build_circular_geometry()
geo = geometry_pydantic_model.CircularConfig().build_geometry()
provider = nums.make_provider(geo.torax_mesh)
interpolated_params = {}
for field in provider:
Expand Down
Loading

0 comments on commit c588818

Please sign in to comment.