From c5888185a18f278f2c55e485f84bfacb8205f708 Mon Sep 17 00:00:00 2001 From: Sebastian Bodenstein Date: Wed, 26 Feb 2025 05:10:35 -0800 Subject: [PATCH] Remove `build_circular_geometry` defaults. There should be a single source of truth for these defaults, which is the `CircularConfig` Pydantic model. Removing defaults breaks many tests, so a general switch is made to using the Pydantic model to build circular geometries instead of `build_circular_geometry`. Includes some fixes in tests for Rmax < Rmin which failed pydantic validation. PiperOrigin-RevId: 731272969 --- run_simulation_main.py | 13 +- torax/config/build_sim.py | 152 +------- torax/config/tests/build_sim.py | 113 +----- torax/config/tests/numerics.py | 6 +- torax/config/tests/plasma_composition.py | 10 +- torax/config/tests/profile_conditions.py | 14 +- torax/config/tests/runtime_params.py | 6 +- torax/config/tests/runtime_params_slice.py | 14 +- torax/fvm/tests/calc_coeffs_test.py | 6 +- torax/fvm/tests/fvm_test.py | 18 +- torax/geometry/circular_geometry.py | 12 +- torax/geometry/pydantic_model.py | 362 ++++++++++++++++++ torax/geometry/standard_geometry.py | 30 +- .../geometry/tests/circular_geometry_test.py | 15 +- .../geometry/tests/geometry_provider_test.py | 20 +- torax/geometry/tests/geometry_test.py | 30 +- torax/geometry/tests/pydantic_model_test.py | 194 ++++++++++ .../geometry/tests/standard_geometry_test.py | 14 +- .../tests/set_pped_tpedratio_nped.py | 8 +- torax/pedestal_model/tests/set_tped_nped.py | 8 +- .../tests/bootstrap_current_source_test.py | 15 +- ...ction_impurity_radiation_heat_sink_test.py | 18 +- .../tests/electron_cyclotron_source_test.py | 4 +- .../tests/ion_cyclotron_source_test.py | 4 +- ...avrin_impurity_radiation_heat_sink_test.py | 5 +- torax/sources/tests/qei_source_test.py | 4 +- torax/sources/tests/source_operations_test.py | 5 +- .../tests/source_profile_builders_test.py | 4 +- .../tests/source_runtime_params_test.py | 4 +- torax/sources/tests/source_test.py | 15 +- torax/sources/tests/test_lib.py | 10 +- torax/tests/boundary_conditions.py | 4 +- torax/tests/core_profile_setters_test.py | 4 +- torax/tests/math_utils.py | 18 +- torax/tests/output.py | 4 +- torax/tests/physics.py | 14 +- torax/tests/post_processing.py | 6 +- torax/tests/sim.py | 15 +- torax/tests/sim_custom_sources.py | 38 +- torax/tests/sim_output_source_profiles.py | 37 +- torax/tests/sim_time_dependence.py | 4 +- torax/tests/state.py | 31 +- torax/tests/test_data/test_explicit.py | 4 +- torax/tests/test_lib/torax_refs.py | 47 +-- torax/torax_pydantic/model_base.py | 10 +- torax/torax_pydantic/model_config.py | 4 +- torax/transport_model/tests/bohm_gyrobohm.py | 4 +- torax/transport_model/tests/constant.py | 4 +- .../tests/critical_gradient.py | 4 +- .../tests/qlknn_transport_model.py | 6 +- .../tests/qualikiz_based_transport_model.py | 4 +- .../tests/qualikiz_transport_model.py | 7 +- .../tests/quasilinear_transport_model.py | 6 +- .../transport_model/tests/transport_model.py | 6 +- .../tests/transport_model_runtime_params.py | 4 +- 55 files changed, 884 insertions(+), 534 deletions(-) create mode 100644 torax/geometry/pydantic_model.py create mode 100644 torax/geometry/tests/pydantic_model_test.py diff --git a/run_simulation_main.py b/run_simulation_main.py index d73838c2..4f66802e 100644 --- a/run_simulation_main.py +++ b/run_simulation_main.py @@ -34,9 +34,9 @@ from torax.config import build_sim from torax.config import config_loader from torax.config import runtime_params +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.plotting import plotruns_lib - # String used when prompting the user to make a choice of command CHOICE_PROMPT = 'Your choice: ' # String used when prompting the user to make a yes / no choice @@ -246,9 +246,9 @@ def change_config( new_runtime_params = build_sim.build_runtime_params_from_config( sim_config['runtime_params'] ) - new_geo_provider = build_sim.build_geometry_provider_from_config( - sim_config['geometry'], - ) + new_geo_provider = geometry_pydantic_model.Geometry.from_dict( + sim_config['geometry'] + ).build_provider() new_transport_model_builder = ( build_sim.build_transport_model_builder_from_config( sim_config['transport'] @@ -299,7 +299,7 @@ def change_config( def change_sim_obj( - config_module_str: str + config_module_str: str, ) -> tuple[sim_lib.Sim, runtime_params.GeneralRuntimeParams, str]: """Builds a new Sim from the config module. @@ -554,7 +554,8 @@ def main(_): try: start_time = time.time() sim_and_runtime_params_or_none = change_config( - sim, config_module_str) + sim, config_module_str + ) if sim_and_runtime_params_or_none is not None: sim, new_runtime_params = sim_and_runtime_params_or_none config_change_time = time.time() - start_time diff --git a/torax/config/build_sim.py b/torax/config/build_sim.py index a90c632f..7202d319 100644 --- a/torax/config/build_sim.py +++ b/torax/config/build_sim.py @@ -21,9 +21,8 @@ from torax import sim as sim_lib from torax.config import config_args from torax.config import runtime_params as runtime_params_lib -from torax.geometry import circular_geometry -from torax.geometry import geometry_provider -from torax.geometry import standard_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model + from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_pped_tpedratio_nped from torax.pedestal_model import set_tped_nped @@ -53,149 +52,6 @@ # pylint: disable=invalid-name -def _build_standard_geometry_provider( - geometry_type: str, - **kwargs, -) -> geometry_provider.GeometryProvider: - """Constructs a geometry provider for a standard geometry.""" - global_params = {'Ip_from_parameters', 'n_rho', 'geometry_dir'} - if geometry_type == 'chease': - intermediate_builder = ( - standard_geometry.StandardGeometryIntermediates.from_chease - ) - elif geometry_type == 'fbt': - # Check if parameters indicate a bundled FBT file and input validity. - if 'LY_bundle_object' in kwargs: - if 'geometry_configs' in kwargs: - raise ValueError( - "Cannot use 'geometry_configs' together with a bundled FBT file" - ) - if 'LY_object' in kwargs: - raise ValueError( - "Cannot use 'LY_object' together with a bundled FBT file" - ) - # Build and return the GeometryProvider for the bundled case. - intermediates = ( - standard_geometry.StandardGeometryIntermediates.from_fbt_bundle( - **kwargs, - ) - ) - geometries = { - t: standard_geometry.build_standard_geometry(intermediates[t]) - for t in intermediates - } - return standard_geometry.StandardGeometryProvider.create_provider( - geometries - ) - else: - intermediate_builder = ( - standard_geometry.StandardGeometryIntermediates.from_fbt_single_slice - ) - elif geometry_type == 'eqdsk': - intermediate_builder = ( - standard_geometry.StandardGeometryIntermediates.from_eqdsk - ) - else: - raise ValueError(f'Unknown geometry type: {geometry_type}') - if 'geometry_configs' in kwargs: - # geometry config has sequence of standalone geometry files. - if not isinstance(kwargs['geometry_configs'], dict): - raise ValueError('geometry_configs must be a dict.') - geometries = {} - global_kwargs = {key: kwargs[key] for key in global_params if key in kwargs} - for time, config in kwargs['geometry_configs'].items(): - if x := global_params.intersection(config): - raise ValueError( - 'The following parameters cannot be set per geometry_config:' - f' {", ".join(x)}' - ) - config.update(global_kwargs) - geometries[time] = standard_geometry.build_standard_geometry( - intermediate_builder( - **config, - ) - ) - return standard_geometry.StandardGeometryProvider.create_provider( - geometries - ) - return geometry_provider.ConstantGeometryProvider( - standard_geometry.build_standard_geometry( - intermediate_builder( - **kwargs, - ) - ) - ) - - -def _build_circular_geometry_provider( - **kwargs, -) -> geometry_provider.GeometryProvider: - """Builds a `GeometryProvider` from the input config.""" - if 'geometry_configs' in kwargs: - if not isinstance(kwargs['geometry_configs'], dict): - raise ValueError('geometry_configs must be a dict.') - if 'n_rho' not in kwargs: - raise ValueError('n_rho must be set in the input config.') - geometries = {} - for time, c in kwargs['geometry_configs'].items(): - geometries[time] = circular_geometry.build_circular_geometry( - n_rho=kwargs['n_rho'], **c - ) - return geometry_provider.TimeDependentGeometryProvider.create_provider( - geometries - ) - return geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry(**kwargs) - ) - - -def build_geometry_provider_from_config( - geometry_config: Mapping[str, Any], -) -> geometry_provider.GeometryProvider: - """Builds a `Geometry` from the input config. - - The input config has one required key: `geometry_type`. Its value must be one - of: - - - "circular" - - "chease" - - "fbt" - - Depending on the `geometry_type` given, there are different keys/values - expected in the rest of the config. See the following functions to get a full - list of the arguments exposed: - - - `circular_geometry.build_circular_geometry()` - - `geometry.StandardGeometryIntermediates.from_chease()` - - `geometry.StandardGeometryIntermediates.from_fbt()` - - For time dependent geometries, the input config should have a key - `geometry_configs` which maps times to a dict of geometry config args. - - Args: - geometry_config: Python dictionary containing keys/values that map onto a - `geometry` module function that builds a `Geometry` object. - - Returns: - A `GeometryProvider` based on the input config. - """ - if 'geometry_type' not in geometry_config: - raise ValueError('geometry_type must be set in the input config.') - # Do a shallow copy to keep references to the original objects while not - # modifying the original config dict with the pop-statement below. - kwargs = dict(geometry_config) - geometry_type = kwargs.pop('geometry_type').lower() # Remove from kwargs. - if geometry_type == 'circular': - return _build_circular_geometry_provider(**kwargs) - # elif geometry_type == 'chease' or geometry_type == 'fbt': - elif geometry_type in ['chease', 'fbt', 'eqdsk']: - return _build_standard_geometry_provider( - geometry_type=geometry_type, **kwargs - ) - - raise ValueError(f'Unknown geometry type: {geometry_type}') - - def build_sim_from_config( config: Mapping[str, Any], ) -> sim_lib.Sim: @@ -275,7 +131,9 @@ def build_sim_from_config( ' for more info.' ) runtime_params = build_runtime_params_from_config(config['runtime_params']) - geo_provider = build_geometry_provider_from_config(config['geometry']) + geo_provider = geometry_pydantic_model.Geometry.from_dict( + config['geometry'] + ).build_provider() if 'restart' in config: file_restart = runtime_params_lib.FileRestart(**config['restart']) diff --git a/torax/config/tests/build_sim.py b/torax/config/tests/build_sim.py index dcbe9d57..f461cc4c 100644 --- a/torax/config/tests/build_sim.py +++ b/torax/config/tests/build_sim.py @@ -20,20 +20,16 @@ from torax.config import build_sim from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice -from torax.geometry import circular_geometry -from torax.geometry import geometry_provider -from torax.geometry import standard_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params_lib from torax.stepper import linear_theta_method from torax.stepper import nonlinear_theta_method -from torax.stepper import runtime_params as stepper_params from torax.time_step_calculator import chi_time_step_calculator from torax.time_step_calculator import fixed_time_step_calculator from torax.transport_model import constant as constant_transport from torax.transport_model import critical_gradient as critical_gradient_transport from torax.transport_model import qlknn_transport_model -from torax.transport_model import runtime_params as transport_model_params class BuildSimTest(parameterized.TestCase): @@ -176,7 +172,7 @@ def test_general_runtime_params_with_time_dependent_args(self): self.assertEqual(runtime_params.plasma_composition.main_ion, 'D') self.assertEqual(runtime_params.profile_conditions.ne_is_fGW, False) self.assertEqual(runtime_params.output_dir, '/tmp/this/is/a/test') - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -195,110 +191,6 @@ def test_general_runtime_params_with_time_dependent_args(self): dynamic_runtime_params_slice.numerics.resistivity_mult, 0.6 ) - def test_missing_geometry_type_raises_error(self): - with self.assertRaises(ValueError): - build_sim.build_geometry_provider_from_config({}) - - def test_build_circular_geometry(self): - geo_provider = build_sim.build_geometry_provider_from_config({ - 'geometry_type': 'circular', - 'n_rho': 5, # override a default. - }) - self.assertIsInstance( - geo_provider, geometry_provider.ConstantGeometryProvider - ) - geo = geo_provider(t=0) - np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5) - np.testing.assert_array_equal(geo.B0, 5.3) # test a default. - - def test_build_geometry_from_chease(self): - geo_provider = build_sim.build_geometry_provider_from_config( - { - 'geometry_type': 'chease', - 'n_rho': 5, # override a default. - }, - ) - self.assertIsInstance( - geo_provider, geometry_provider.ConstantGeometryProvider - ) - self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry) - np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5) - - def test_build_time_dependent_geometry_from_chease(self): - """Tests correctness of config constraints with time-dependent geometry.""" - - base_config = { - 'geometry_type': 'chease', - 'Ip_from_parameters': True, - 'n_rho': 10, # overrides the default - 'geometry_configs': { - 0.0: { - 'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', - 'Rmaj': 6.2, - 'Rmin': 2.0, - 'B0': 5.3, - }, - 1.0: { - 'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', - 'Rmaj': 6.2, - 'Rmin': 2.0, - 'B0': 5.3, - }, - }, - } - - # Test valid config - geo_provider = build_sim.build_geometry_provider_from_config(base_config) - self.assertIsInstance( - geo_provider, standard_geometry.StandardGeometryProvider - ) - self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry) - np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 10) - - # Test invalid configs: - for param, value in zip( - ['n_rho', 'Ip_from_parameters', 'geometry_dir'], [5, True, '.'] - ): - for time_key in [0.0, 1.0]: - invalid_config = base_config.copy() - invalid_config['geometry_configs'][time_key][param] = value - with self.assertRaises(ValueError): - build_sim.build_geometry_provider_from_config(invalid_config) - - # pylint: disable=invalid-name - def test_chease_geometry_updates_Ip(self): - """Tests that the Ip is updated when using chease geometry.""" - runtime_params = runtime_params_lib.GeneralRuntimeParams() - original_Ip_tot = runtime_params.profile_conditions.Ip_tot - geo_provider = build_sim.build_geometry_provider_from_config({ - 'geometry_type': 'chease', - 'Ip_from_parameters': ( - False - ), # this will force update runtime_params.Ip_tot - }) - runtime_params_provider = ( - runtime_params_slice.DynamicRuntimeParamsSliceProvider( - runtime_params=runtime_params, - transport=transport_model_params.RuntimeParams(), - sources={}, - stepper=stepper_params.RuntimeParams(), - torax_mesh=geo_provider.torax_mesh, - ) - ) - dynamic_slice, geo = ( - runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( - t=0, - dynamic_runtime_params_slice_provider=runtime_params_provider, - geometry_provider=geo_provider, - ) - ) - self.assertIsInstance(geo, standard_geometry.StandardGeometry) - self.assertIsNotNone(dynamic_slice) - self.assertNotEqual( - dynamic_slice.profile_conditions.Ip_tot, original_Ip_tot - ) - # pylint: enable=invalid-name - def test_empty_source_config_only_has_defaults_turned_off(self): """Tests that an empty source config has all sources turned off.""" source_models_builder = build_sim.build_sources_builder_from_config({}) @@ -504,5 +396,6 @@ def test_build_time_step_calculator_from_config( ) self.assertIsInstance(time_stepper, expected_type) + if __name__ == '__main__': absltest.main() diff --git a/torax/config/tests/numerics.py b/torax/config/tests/numerics.py index 4f5f7b72..b7b2a19e 100644 --- a/torax/config/tests/numerics.py +++ b/torax/config/tests/numerics.py @@ -18,7 +18,7 @@ from absl.testing import parameterized from torax import interpolated_param from torax.config import numerics -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model class NumericsTest(parameterized.TestCase): @@ -26,7 +26,7 @@ class NumericsTest(parameterized.TestCase): def test_numerics_make_provider(self): nums = numerics.Numerics() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = nums.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -35,7 +35,7 @@ def test_interpolated_vars_are_only_constructed_once( ): """Tests that interpolated vars are only constructed once.""" nums = numerics.Numerics() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = nums.make_provider(geo.torax_mesh) interpolated_params = {} for field in provider: diff --git a/torax/config/tests/plasma_composition.py b/torax/config/tests/plasma_composition.py index 20a81264..bdf3fbba 100644 --- a/torax/config/tests/plasma_composition.py +++ b/torax/config/tests/plasma_composition.py @@ -20,7 +20,7 @@ from torax import charge_states from torax import interpolated_param from torax.config import plasma_composition -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model class PlasmaCompositionTest(parameterized.TestCase): @@ -29,7 +29,7 @@ class PlasmaCompositionTest(parameterized.TestCase): def test_plasma_composition_make_provider(self): """Checks provider construction with no issues.""" pc = plasma_composition.PlasmaComposition() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -40,7 +40,7 @@ def test_plasma_composition_make_provider(self): ) def test_zeff_accepts_float_inputs(self, zeff: float): """Tests that zeff accepts a single float input.""" - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() pc = plasma_composition.PlasmaComposition(Zeff=zeff) provider = pc.make_provider(geo.torax_mesh) dynamic_pc = provider.build_dynamic_params(t=0.0) @@ -63,7 +63,7 @@ def test_zeff_and_zeff_face_match_expected(self): 1.0: {0.0: 1.8, 0.5: 2.1, 1.0: 2.4}, } - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() pc = plasma_composition.PlasmaComposition(Zeff=zeff_profile) provider = pc.make_provider(geo.torax_mesh) @@ -102,7 +102,7 @@ def test_interpolated_vars_are_only_constructed_once( ): """Tests that interpolated vars are only constructed once.""" pc = plasma_composition.PlasmaComposition() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) interpolated_params = {} for field in provider: diff --git a/torax/config/tests/profile_conditions.py b/torax/config/tests/profile_conditions.py index cb2d9998..7f73a3ad 100644 --- a/torax/config/tests/profile_conditions.py +++ b/torax/config/tests/profile_conditions.py @@ -20,7 +20,7 @@ from torax import interpolated_param from torax.config import config_args from torax.config import profile_conditions -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model import xarray as xr @@ -30,7 +30,7 @@ class ProfileConditionsTest(parameterized.TestCase): def test_profile_conditions_make_provider(self): pc = profile_conditions.ProfileConditions() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -46,7 +46,7 @@ def test_profile_conditions_sets_Te_bound_right_correctly( Te={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}}, Te_bound_right=Te_bound_right, ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) dcs = provider.build_dynamic_params(t=0.0) self.assertEqual(dcs.Te_bound_right, expected_initial_value) @@ -65,7 +65,7 @@ def test_profile_conditions_sets_Ti_bound_right_correctly( Ti={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}}, Ti_bound_right=Ti_bound_right, ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) dcs = provider.build_dynamic_params(t=0.0) self.assertEqual(dcs.Ti_bound_right, expected_initial_value) @@ -84,7 +84,7 @@ def test_profile_conditions_sets_ne_bound_right_correctly( ne={0: {0: 1.0, 1: 2.0}, 1.5: {0: 100.0, 1: 200.0}}, ne_bound_right=ne_bound_right, ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) dcs = provider.build_dynamic_params(t=0.0) self.assertEqual(dcs.ne_bound_right, expected_initial_value) @@ -126,7 +126,7 @@ def test_profile_conditions_sets_psi_correctly( self, psi, expected_initial_value, expected_second_value ): """Tests that psi is set correctly.""" - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() pc = profile_conditions.ProfileConditions( psi=psi, ) @@ -147,7 +147,7 @@ def test_interpolated_vars_are_only_constructed_once( ): """Tests that interpolated vars are only constructed once.""" pc = profile_conditions.ProfileConditions() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = pc.make_provider(geo.torax_mesh) interpolated_params = {} for field in provider: diff --git a/torax/config/tests/runtime_params.py b/torax/config/tests/runtime_params.py index f5cfd34c..2b68089d 100644 --- a/torax/config/tests/runtime_params.py +++ b/torax/config/tests/runtime_params.py @@ -21,7 +21,7 @@ from torax.config import config_args from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model # pylint: disable=invalid-name @@ -137,7 +137,9 @@ def test_runtime_params_make_provider(self): runtime_params = general_runtime_params.GeneralRuntimeParams( profile_conditions=profile_conditions_lib.ProfileConditions() ) - torax_mesh = circular_geometry.build_circular_geometry().torax_mesh + torax_mesh = ( + geometry_pydantic_model.CircularConfig().build_geometry().torax_mesh + ) runtime_params_provider = runtime_params.make_provider(torax_mesh) runtime_params_provider.build_dynamic_params(0.0) diff --git a/torax/config/tests/runtime_params_slice.py b/torax/config/tests/runtime_params_slice.py index 8e9ce64d..45352d6e 100644 --- a/torax/config/tests/runtime_params_slice.py +++ b/torax/config/tests/runtime_params_slice.py @@ -23,7 +23,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources from torax.sources import generic_current_source @@ -37,7 +37,7 @@ class RuntimeParamsSliceTest(parameterized.TestCase): def setUp(self): super().setUp() - self._geo = circular_geometry.build_circular_geometry() + self._geo = geometry_pydantic_model.CircularConfig().build_geometry() def test_dynamic_slice_can_be_input_to_jitted_function(self): """Tests that the slice can be input to a jitted function.""" @@ -351,7 +351,7 @@ def test_profile_conditions_set_electron_temperature_and_boundary_condition( runtime_params = general_runtime_params.GeneralRuntimeParams( profile_conditions=profile_conditions, ) - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, torax_mesh=geo.torax_mesh, @@ -394,7 +394,7 @@ def test_profile_conditions_set_electron_density_and_boundary_condition( ne_is_fGW=ne_is_fGW, ), ) - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() dcs = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, @@ -427,7 +427,7 @@ def test_update_dynamic_slice_provider_updates_runtime_params( Ti_bound_right={0.0: 1.0, 1.0: 2.0}, ), ) - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, torax_mesh=geo.torax_mesh, @@ -463,7 +463,7 @@ def test_update_dynamic_slice_provider_updates_sources( source_models_builder.runtime_params[ generic_current_source.GenericCurrentSource.SOURCE_NAME ].Iext = 1.0 - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, sources=source_models_builder.runtime_params, @@ -519,7 +519,7 @@ def test_update_dynamic_slice_provider_updates_transport( """Tests that the dynamic slice provider can be updated.""" runtime_params = general_runtime_params.GeneralRuntimeParams() transport = transport_params_lib.RuntimeParams(De_inner=1.0) - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() provider = runtime_params_slice_lib.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, torax_mesh=geo.torax_mesh, diff --git a/torax/fvm/tests/calc_coeffs_test.py b/torax/fvm/tests/calc_coeffs_test.py index 07f65b52..296c3548 100644 --- a/torax/fvm/tests/calc_coeffs_test.py +++ b/torax/fvm/tests/calc_coeffs_test.py @@ -20,7 +20,7 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib from torax.fvm import calc_coeffs -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params from torax.sources import source_profile_builders @@ -52,7 +52,9 @@ def test_calc_coeffs_smoke_test( predictor_corrector=False, theta_imp=theta_imp, ) - geo = circular_geometry.build_circular_geometry(n_rho=num_cells) + geo = geometry_pydantic_model.CircularConfig( + n_rho=num_cells + ).build_geometry() transport_model_builder = ( constant_transport_model.ConstantTransportModelBuilder( runtime_params=constant_transport_model.RuntimeParams( diff --git a/torax/fvm/tests/fvm_test.py b/torax/fvm/tests/fvm_test.py index eafcc8f2..e2cf766c 100644 --- a/torax/fvm/tests/fvm_test.py +++ b/torax/fvm/tests/fvm_test.py @@ -29,7 +29,7 @@ from torax.fvm import cell_variable from torax.fvm import implicit_solve_block from torax.fvm import residual_and_loss -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params from torax.sources import source_profile_builders @@ -224,7 +224,9 @@ def test_nonlinear_solve_block_loss_minimum( predictor_corrector=False, theta_imp=theta_imp, ) - geo = circular_geometry.build_circular_geometry(n_rho=num_cells) + geo = geometry_pydantic_model.CircularConfig( + n_rho=num_cells + ).build_geometry() transport_model_builder = ( constant_transport_model.ConstantTransportModelBuilder( runtime_params=constant_transport_model.RuntimeParams( @@ -392,7 +394,9 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): pedestal_model_builder = ( set_tped_nped.SetTemperatureDensityPedestalModelBuilder() ) - geo = circular_geometry.build_circular_geometry(n_rho=num_cells) + geo = geometry_pydantic_model.CircularConfig( + n_rho=num_cells + ).build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -413,7 +417,9 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self): stepper=stepper_params, ) ) - geo = circular_geometry.build_circular_geometry(n_rho=num_cells) + geo = geometry_pydantic_model.CircularConfig( + n_rho=num_cells + ).build_geometry() source_models = source_models_builder() initial_core_profiles = core_profile_setters.initial_core_profiles( static_runtime_params_slice, @@ -515,7 +521,9 @@ def test_theta_residual_uses_updated_boundary_conditions(self): predictor_corrector=False, theta_imp=0.0, ) - geo = circular_geometry.build_circular_geometry(n_rho=num_cells) + geo = geometry_pydantic_model.CircularConfig( + n_rho=num_cells + ).build_geometry() transport_model_builder = ( constant_transport_model.ConstantTransportModelBuilder( runtime_params=constant_transport_model.RuntimeParams( diff --git a/torax/geometry/circular_geometry.py b/torax/geometry/circular_geometry.py index 31aa47ba..7833a935 100644 --- a/torax/geometry/circular_geometry.py +++ b/torax/geometry/circular_geometry.py @@ -25,12 +25,12 @@ # external physics implementations # pylint: disable=invalid-name def build_circular_geometry( - n_rho: int = 25, - elongation_LCFS: float = 1.72, - Rmaj: float = 6.2, - Rmin: float = 2.0, - B0: float = 5.3, - hires_fac: int = 4, + n_rho: int, + elongation_LCFS: float, + Rmaj: float, + Rmin: float, + B0: float, + hires_fac: int, ) -> geometry.Geometry: """Constructs a circular Geometry instance used for testing only. diff --git a/torax/geometry/pydantic_model.py b/torax/geometry/pydantic_model.py new file mode 100644 index 00000000..edd75b27 --- /dev/null +++ b/torax/geometry/pydantic_model.py @@ -0,0 +1,362 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pydantic model for geometry.""" +from collections.abc import Callable, Mapping +import inspect +from typing import Annotated, Any, Literal, TypeAlias, TypeVar + +import pydantic +from torax.geometry import circular_geometry +from torax.geometry import geometry +from torax.geometry import geometry_provider +from torax.geometry import standard_geometry +from torax.torax_pydantic import torax_pydantic +import typing_extensions +# Using invalid-name because we are using the same naming convention as the +# external physics implementations +# pylint: disable=invalid-name +T = TypeVar('T') + +LY_OBJECT_TYPE: TypeAlias = ( + str | Mapping[str, torax_pydantic.NumpyArray | float] +) + +TIME_INVARIANT = torax_pydantic.TIME_INVARIANT + + +class CircularConfig(torax_pydantic.BaseModelFrozen): + """Pydantic model for the circular geometry config. + + Attributes: + geometry_type: Always set to 'circular'. + n_rho: Number of radial grid points. + hires_fac: Only used when the initial condition ``psi`` is from plasma + current. Sets up a higher resolution mesh with ``nrho_hires = nrho * + hi_res_fac``, used for ``j`` to ``psi`` conversions. + Rmaj: Major radius (R) in meters. + Rmin: Minor radius (a) in meters. + B0: Vacuum toroidal magnetic field on axis [T]. + elongation_LCFS: Sets the plasma elongation used for volume, area and + q-profile corrections. + """ + + geometry_type: Annotated[Literal['circular'], TIME_INVARIANT] = 'circular' + n_rho: Annotated[pydantic.PositiveInt, TIME_INVARIANT] = 25 + hires_fac: pydantic.PositiveInt = 4 + Rmaj: torax_pydantic.Meter = 6.2 + Rmin: torax_pydantic.Meter = 2.0 + B0: torax_pydantic.Tesla = 5.3 + elongation_LCFS: pydantic.PositiveFloat = 1.72 + + @pydantic.model_validator(mode='after') + def _check_fields(self) -> typing_extensions.Self: + if not self.Rmaj >= self.Rmin: + raise ValueError('Rmin must be less than or equal to Rmaj.') + return self + + def build_geometry(self) -> geometry.Geometry: + return circular_geometry.build_circular_geometry( + n_rho=self.n_rho, + elongation_LCFS=self.elongation_LCFS, + Rmaj=self.Rmaj, + Rmin=self.Rmin, + B0=self.B0, + hires_fac=self.hires_fac, + ) + + +class CheaseConfig(torax_pydantic.BaseModelFrozen): + """Pydantic model for the CHEASE geometry. + + Attributes: + geometry_type: Always set to 'chease'. + n_rho: Number of radial grid points. + hires_fac: Only used when the initial condition ``psi`` is from plasma + current. Sets up a higher resolution mesh with ``nrho_hires = nrho * + hi_res_fac``, used for ``j`` to ``psi`` conversions. + geometry_dir: Optionally overrides the `TORAX_GEOMETRY_DIR` environment + variable. + Ip_from_parameters: Toggles whether total plasma current is read from the + configuration file, or from the geometry file. If True, then the `psi` + calculated from the geometry file is scaled to match the desired `I_p`. + Rmaj: Major radius (R) in meters. + Rmin: Minor radius (a) in meters. + B0: Vacuum toroidal magnetic field on axis [T]. + """ + + geometry_type: Annotated[Literal['chease'], TIME_INVARIANT] = 'chease' + n_rho: Annotated[pydantic.PositiveInt, TIME_INVARIANT] = 25 + hires_fac: pydantic.PositiveInt = 4 + geometry_dir: Annotated[str | None, TIME_INVARIANT] = None + Ip_from_parameters: Annotated[bool, TIME_INVARIANT] = True + geometry_file: str = 'ITER_hybrid_citrin_equil_cheasedata.mat2cols' + Rmaj: torax_pydantic.Meter = 6.2 + Rmin: torax_pydantic.Meter = 2.0 + B0: torax_pydantic.Tesla = 5.3 + + @pydantic.model_validator(mode='after') + def _check_fields(self) -> typing_extensions.Self: + if not self.Rmaj >= self.Rmin: + raise ValueError('Rmin must be less than or equal to Rmaj.') + return self + + def build_geometry(self) -> standard_geometry.StandardGeometry: + + return standard_geometry.build_standard_geometry( + _apply_relevant_kwargs( + standard_geometry.StandardGeometryIntermediates.from_chease, + self.__dict__, + ) + ) + + +class FBTConfig(torax_pydantic.BaseModelFrozen): + """Pydantic model for the FBT geometry. + + Attributes: + geometry_type: Always set to 'fbt'. + n_rho: Number of radial grid points. + hires_fac: Only used when the initial condition ``psi`` is from plasma + current. Sets up a higher resolution mesh with ``nrho_hires = nrho * + hi_res_fac``, used for ``j`` to ``psi`` conversions. + geometry_dir: Optionally overrides the `TORAX_GEOMETRY_DIR` environment + variable. + Ip_from_parameters: Toggles whether total plasma current is read from the + configuration file, or from the geometry file. If True, then the `psi` + calculated from the geometry file is scaled to match the desired `I_p`. + hires_fac: Sets up a higher resolution mesh with ``nrho_hires = nrho * + hi_res_fac``, used for ``j`` to ``psi`` conversions. + LY_object: Sets a single-slice FBT LY geometry file to be loaded, or + alternatively a dict directly containing a single time slice of LY data. + LY_bundle_object: Sets the FBT LY bundle file to be loaded, corresponding to + multiple time-slices, or alternatively a dict directly containing all + time-slices of LY data. + LY_to_torax_times: Sets the TORAX simulation times corresponding to the + individual slices in the FBT LY bundle file. If not provided, then the + times are taken from the LY_bundle_file itself. The length of the array + must match the number of slices in the bundle. + L_object: Sets the FBT L geometry file loaded, or alternatively a dict + directly containing the L data. + """ + + geometry_type: Annotated[Literal['fbt'], TIME_INVARIANT] = 'fbt' + n_rho: Annotated[pydantic.PositiveInt, TIME_INVARIANT] = 25 + hires_fac: pydantic.PositiveInt = 4 + geometry_dir: Annotated[str | None, TIME_INVARIANT] = None + Ip_from_parameters: Annotated[bool, TIME_INVARIANT] = True + LY_object: LY_OBJECT_TYPE | None = None + LY_bundle_object: LY_OBJECT_TYPE | None = None + LY_to_torax_times: torax_pydantic.NumpyArray | None = None + L_object: LY_OBJECT_TYPE | None = None + + @pydantic.model_validator(mode='before') + @classmethod + def _conform_data(cls, data: dict[str, Any]) -> dict[str, Any]: + # Remove unused fields from the data dict that come from file loading. + for obj in ('L_object', 'LY_object'): + if obj in data and isinstance(data[obj], dict): + for k in ('__header__', '__version__', '__globals__', 'shot'): + data[obj].pop(k, None) + return data + + @pydantic.model_validator(mode='after') + def _validate_model(self) -> typing_extensions.Self: + if self.LY_bundle_object is not None and self.LY_object is not None: + raise ValueError( + "Cannot use 'LY_object' together with a bundled FBT file" + ) + if self.LY_to_torax_times is not None and self.LY_bundle_object is None: + raise ValueError( + 'LY_bundle_object must be set when using LY_to_torax_times.' + ) + return self + + def build_geometry(self) -> standard_geometry.StandardGeometry: + + return standard_geometry.build_standard_geometry( + _apply_relevant_kwargs( + standard_geometry.StandardGeometryIntermediates.from_fbt_single_slice, + self.__dict__, + ) + ) + + # TODO(b/398191165): Remove this branch once the FBT bundle logic is + # redesigned. + def build_fbt_geometry_provider_from_bundle( + self, + ) -> geometry_provider.GeometryProvider: + """Builds a `GeometryProvider` from the input config.""" + intermediates = _apply_relevant_kwargs( + standard_geometry.StandardGeometryIntermediates.from_fbt_bundle, + self.__dict__, + ) + geometries = { + t: standard_geometry.build_standard_geometry(intermediates[t]) + for t in intermediates + } + return standard_geometry.StandardGeometryProvider.create_provider( + geometries + ) + + +class EQDSKConfig(torax_pydantic.BaseModelFrozen): + """Pydantic model for the EQDSK geometry. + + Attributes: + geometry_type: Always set to 'eqdsk'. + n_rho: Number of radial grid points. + hires_fac: Only used when the initial condition ``psi`` is from plasma + current. Sets up a higher resolution mesh with ``nrho_hires = nrho * + hi_res_fac``, used for ``j`` to ``psi`` conversions. + geometry_dir: Optionally overrides the `TORAX_GEOMETRY_DIR` environment + variable. + Ip_from_parameters: Toggles whether total plasma current is read from the + configuration file, or from the geometry file. If True, then the `psi` + calculated from the geometry file is scaled to match the desired `I_p`. + n_surfaces: Number of surfaces for which flux surface averages are + calculated. + last_surface_factor: Multiplication factor of the boundary poloidal flux, + used for the contour defining geometry terms at the LCFS on the TORAX + grid. Needed to avoid divergent integrations in diverted geometries. + """ + + geometry_type: Annotated[Literal['eqdsk'], TIME_INVARIANT] = 'eqdsk' + n_rho: Annotated[pydantic.PositiveInt, TIME_INVARIANT] = 25 + hires_fac: pydantic.PositiveInt = 4 + geometry_dir: Annotated[str | None, TIME_INVARIANT] = None + Ip_from_parameters: Annotated[bool, TIME_INVARIANT] = True + geometry_file: str = 'EQDSK_ITERhybrid_COCOS02.eqdsk' + n_surfaces: pydantic.PositiveInt = 100 + last_surface_factor: torax_pydantic.OpenUnitInterval = 0.99 + + def build_geometry(self) -> standard_geometry.StandardGeometry: + return standard_geometry.build_standard_geometry( + _apply_relevant_kwargs( + standard_geometry.StandardGeometryIntermediates.from_eqdsk, + self.__dict__, + ) + ) + + +class GeometryConfig(torax_pydantic.BaseModelFrozen): + """Pydantic model for a single geometry config.""" + + config: CircularConfig | CheaseConfig | FBTConfig | EQDSKConfig = ( + pydantic.Field(discriminator='geometry_type') + ) + + +class Geometry(torax_pydantic.BaseModelFrozen): + """Pydantic model for a geometry. + + This object can be constructed via `Geometry.from_dict(config)`, where + `config` is a dict described in + https://torax.readthedocs.io/en/latest/configuration.html#geometry. + + Attributes: + geometry_type: A `geometry.GeometryType` enum. + geometry_configs: Either a single `GeometryConfig` or a dict of + `GeometryConfig` objects, where the keys are times in seconds. + """ + + geometry_type: geometry.GeometryType + geometry_configs: GeometryConfig | dict[torax_pydantic.Second, GeometryConfig] + + @pydantic.model_validator(mode='before') + @classmethod + def _conform_data(cls, data: dict[str, Any]) -> dict[str, Any]: + + if 'geometry_type' not in data: + raise ValueError('geometry_type must be set in the input config.') + + geometry_type = data['geometry_type'] + # The geometry type can be an int if loading from JSON. + if isinstance(geometry_type, geometry.GeometryType | int): + return data + # Parse the user config dict. + elif isinstance(geometry_type, str): + return _conform_user_data(data) + else: + raise ValueError(f'Invalid value for geometry: {geometry_type}') + + def build_provider(self) -> geometry_provider.GeometryProvider: + # TODO(b/398191165): Remove this branch once the FBT bundle logic is + # redesigned. + if self.geometry_type == geometry.GeometryType.FBT: + if not isinstance(self.geometry_configs, dict): + assert isinstance(self.geometry_configs.config, FBTConfig) + if self.geometry_configs.config.LY_bundle_object is not None: + return ( + self.geometry_configs.config.build_fbt_geometry_provider_from_bundle() + ) + + if isinstance(self.geometry_configs, dict): + geometries = { + time: config.config.build_geometry() + for time, config in self.geometry_configs.items() + } + provider = ( + geometry_provider.TimeDependentGeometryProvider.create_provider + if self.geometry_type == geometry.GeometryType.CIRCULAR + else standard_geometry.StandardGeometryProvider.create_provider + ) + else: + geometries = self.geometry_configs.config.build_geometry() + provider = geometry_provider.ConstantGeometryProvider + + return provider(geometries) # pytype: disable=attribute-error + + +def _conform_user_data(data: dict[str, Any]) -> dict[str, Any]: + """Conform the user geometry dict to the pydantic model.""" + + if 'LY_bundle_object' in data and 'geometry_configs' in data: + raise ValueError( + 'Cannot use both `LY_bundle_object` and `geometry_configs` together.' + ) + + data_copy = data.copy() + # Useful to avoid failing if users mistakenly give the wrong case. + data_copy['geometry_type'] = data['geometry_type'].lower() + geometry_type = getattr(geometry.GeometryType, data['geometry_type'].upper()) + constructor_args = {'geometry_type': geometry_type} + configs_time_dependent = data_copy.pop('geometry_configs', None) + + if configs_time_dependent: + # geometry config has sequence of standalone geometry files. + if not isinstance(data['geometry_configs'], dict): + raise ValueError('geometry_configs must be a dict.') + constructor_args['geometry_configs'] = {} + for time, c_time_dependent in configs_time_dependent.items(): + gc = GeometryConfig.from_dict({'config': c_time_dependent | data_copy}) + constructor_args['geometry_configs'][time] = gc + if x := set(gc.config.time_invariant_fields()).intersection( + c_time_dependent.keys() + ): + raise ValueError( + 'The following parameters cannot be set per geometry_config:' + f' {", ".join(x)}' + ) + else: + constructor_args['geometry_configs'] = {'config': data_copy} + + return constructor_args + + +def _apply_relevant_kwargs(f: Callable[..., T], kwargs: Mapping[str, Any]) -> T: + """Apply only the kwargs actually used by the function.""" + relevant_kwargs = [i.name for i in inspect.signature(f).parameters.values()] + kwargs = {k: kwargs[k] for k in relevant_kwargs} + return f(**kwargs) diff --git a/torax/geometry/standard_geometry.py b/torax/geometry/standard_geometry.py index 5eb9b602..31cbf07e 100644 --- a/torax/geometry/standard_geometry.py +++ b/torax/geometry/standard_geometry.py @@ -234,14 +234,14 @@ def __post_init__(self): @classmethod def from_chease( cls, - geometry_dir: str | None = None, - geometry_file: str = 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', - Ip_from_parameters: bool = True, - n_rho: int = 25, - Rmaj: float = 6.2, - Rmin: float = 2.0, - B0: float = 5.3, - hires_fac: int = 4, + geometry_dir: str | None, + geometry_file: str, + Ip_from_parameters: bool, + n_rho: int, + Rmaj: float, + Rmin: float, + B0: float, + hires_fac: int, ) -> StandardGeometryIntermediates: """Constructs a StandardGeometryIntermediates from a CHEASE file. @@ -575,13 +575,13 @@ def _from_fbt( @classmethod def from_eqdsk( cls, - geometry_dir: str | None = None, - geometry_file: str = 'EQDSK_ITERhybrid_COCOS02.eqdsk', - hires_fac: int = 4, - Ip_from_parameters: bool = True, - n_rho: int = 25, - n_surfaces: int = 100, - last_surface_factor: float = 0.99, + geometry_dir: str | None, + geometry_file: str, + hires_fac: int, + Ip_from_parameters: bool, + n_rho: int, + n_surfaces: int, + last_surface_factor: float, ) -> StandardGeometryIntermediates: """Constructs a StandardGeometryIntermediates from EQDSK. diff --git a/torax/geometry/tests/circular_geometry_test.py b/torax/geometry/tests/circular_geometry_test.py index 6b751886..a53c8365 100644 --- a/torax/geometry/tests/circular_geometry_test.py +++ b/torax/geometry/tests/circular_geometry_test.py @@ -40,10 +40,8 @@ def test_build_geometry_provider_from_circular(self): B0=5.3, hires_fac=4, ) - provider = ( - geometry_provider.TimeDependentGeometryProvider.create_provider( - {0.0: geo_0, 10.0: geo_1} - ) + provider = geometry_provider.TimeDependentGeometryProvider.create_provider( + {0.0: geo_0, 10.0: geo_1} ) geo = provider(5.0) np.testing.assert_allclose(geo.Rmaj, 6.7) @@ -55,7 +53,14 @@ def test_circular_geometry_can_be_input_to_jitted_function(self): def foo(geo: geometry.Geometry): return geo.Rmaj - geo = circular_geometry.build_circular_geometry() + geo = circular_geometry.build_circular_geometry( + n_rho=25, + elongation_LCFS=1.72, + Rmaj=6.2, + Rmin=2.0, + B0=5.3, + hires_fac=4, + ) # Make sure you can call the function with geo as an arg. foo(geo) diff --git a/torax/geometry/tests/geometry_provider_test.py b/torax/geometry/tests/geometry_provider_test.py index b3f6128d..6dbfaad4 100644 --- a/torax/geometry/tests/geometry_provider_test.py +++ b/torax/geometry/tests/geometry_provider_test.py @@ -16,27 +16,27 @@ from absl.testing import absltest import numpy as np -from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider +from torax.geometry import pydantic_model as geometry_pydantic_model class GeometryProviderTest(absltest.TestCase): def test_constant_geometry_return_same_value(self): - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = geometry_provider.ConstantGeometryProvider(geo) self.assertEqual(provider(0.0), geo) self.assertEqual(provider(1.0), geo) self.assertEqual(provider(2.0), geo) def test_time_dependent_geometry_return_different_values(self): - geo_0 = circular_geometry.build_circular_geometry( + geo_0 = geometry_pydantic_model.CircularConfig( Rmaj=6.2, Rmin=2.0, B0=5.3 - ) - geo_1 = circular_geometry.build_circular_geometry( + ).build_geometry() + geo_1 = geometry_pydantic_model.CircularConfig( Rmaj=7.4, Rmin=1.0, B0=6.5 - ) + ).build_geometry() provider = geometry_provider.TimeDependentGeometryProvider.create_provider( {0.0: geo_0, 10.0: geo_1} ) @@ -46,7 +46,7 @@ def test_time_dependent_geometry_return_different_values(self): np.testing.assert_allclose(geo.B0, 5.9) def test_time_dependent_different_types(self): - geo_0 = circular_geometry.build_circular_geometry() + geo_0 = geometry_pydantic_model.CircularConfig().build_geometry() geo_1 = dataclasses.replace(geo_0, geometry_type=geometry.GeometryType.FBT) with self.assertRaisesRegex( ValueError, "All geometries must have the same geometry type." @@ -56,8 +56,8 @@ def test_time_dependent_different_types(self): ) def test_time_dependent_different_meshes(self): - geo_0 = circular_geometry.build_circular_geometry(n_rho=25) - geo_1 = circular_geometry.build_circular_geometry(n_rho=50) + geo_0 = geometry_pydantic_model.CircularConfig(n_rho=25).build_geometry() + geo_1 = geometry_pydantic_model.CircularConfig(n_rho=50).build_geometry() with self.assertRaisesRegex( ValueError, "All geometries must have the same mesh." ): @@ -66,7 +66,7 @@ def test_time_dependent_different_meshes(self): ) def test_none_z_magnetic_axis_stays_none_time_dependent(self): - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() geo = dataclasses.replace(geo, _z_magnetic_axis=None) provider = geometry_provider.TimeDependentGeometryProvider.create_provider( {0.0: geo, 10.0: geo} diff --git a/torax/geometry/tests/geometry_test.py b/torax/geometry/tests/geometry_test.py index 30b9c09d..bcfcab4b 100644 --- a/torax/geometry/tests/geometry_test.py +++ b/torax/geometry/tests/geometry_test.py @@ -19,8 +19,8 @@ import jax from jax import numpy as jnp import numpy as np -from torax.geometry import circular_geometry from torax.geometry import geometry +from torax.geometry import pydantic_model as geometry_pydantic_model class GeometryTest(parameterized.TestCase): @@ -49,7 +49,7 @@ def test_face_to_cell(self, n_rho, seed): np.testing.assert_allclose(cell_jax, cell_np) def test_none_z_magnetic_axis_raises_an_error(self): - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() geo = dataclasses.replace(geo, _z_magnetic_axis=None) with self.subTest('non_jitted_function'): @@ -68,9 +68,15 @@ def test_none_z_magnetic_axis_raises_an_error(self): def test_stack_geometries_circular_geometries(self): """Test stack_geometries for circular geometries.""" # Create a few different geometries - geo0 = circular_geometry.build_circular_geometry(Rmaj=1.0, B0=2.0, n_rho=10) - geo1 = circular_geometry.build_circular_geometry(Rmaj=1.5, B0=2.5, n_rho=10) - geo2 = circular_geometry.build_circular_geometry(Rmaj=2.0, B0=3.0, n_rho=10) + geo0 = geometry_pydantic_model.CircularConfig( + Rmin=0.5, Rmaj=1.0, B0=2.0, n_rho=10 + ).build_geometry() + geo1 = geometry_pydantic_model.CircularConfig( + Rmin=0.5, Rmaj=1.5, B0=2.5, n_rho=10 + ).build_geometry() + geo2 = geometry_pydantic_model.CircularConfig( + Rmin=0.5, Rmaj=2.0, B0=3.0, n_rho=10 + ).build_geometry() # Stack them stacked_geo = geometry.stack_geometries([geo0, geo1, geo2]) @@ -118,10 +124,12 @@ def test_stack_geometries_error_handling_empty_list(self): def test_stack_geometries_error_handling_different_mesh_sizes(self): """Test error handling for stack_geometries with different mesh sizes.""" - geo0 = circular_geometry.build_circular_geometry(Rmaj=1.0, B0=2.0, n_rho=10) - geo_diff_mesh = circular_geometry.build_circular_geometry( - Rmaj=1.0, B0=2.0, n_rho=20 - ) # Different n_rho + geo0 = geometry_pydantic_model.CircularConfig( + Rmin=0.5, Rmaj=1.0, B0=2.0, n_rho=10 + ).build_geometry() + geo_diff_mesh = geometry_pydantic_model.CircularConfig( + Rmin=0.5, Rmaj=1.0, B0=2.0, n_rho=20 + ).build_geometry() # Different n_rho with self.assertRaisesRegex( ValueError, 'All geometries must have the same mesh.' ): @@ -129,7 +137,9 @@ def test_stack_geometries_error_handling_different_mesh_sizes(self): def test_stack_geometries_error_handling_different_geometry_types(self): """Test different geometry type error handling for stack_geometries.""" - geo0 = circular_geometry.build_circular_geometry(Rmaj=1.0, B0=2.0, n_rho=10) + geo0 = geometry_pydantic_model.CircularConfig( + Rmin=0.5, Rmaj=1.0, B0=2.0, n_rho=10 + ).build_geometry() geo_diff_geometry_type = dataclasses.replace( geo0, geometry_type=geometry.GeometryType(3) ) diff --git a/torax/geometry/tests/pydantic_model_test.py b/torax/geometry/tests/pydantic_model_test.py new file mode 100644 index 00000000..eda2b5e9 --- /dev/null +++ b/torax/geometry/tests/pydantic_model_test.py @@ -0,0 +1,194 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +from torax.config import runtime_params as runtime_params_lib +from torax.config import runtime_params_slice +from torax.geometry import geometry_provider +from torax.geometry import pydantic_model +from torax.geometry import standard_geometry +from torax.stepper import runtime_params as stepper_params +from torax.transport_model import runtime_params as transport_model_params + + +class PydanticModelTest(parameterized.TestCase): + + def test_missing_geometry_type_raises_error(self): + with self.assertRaisesRegex( + ValueError, 'geometry_type must be set in the input config' + ): + pydantic_model.Geometry.from_dict({}) + + def test_build_circular_geometry(self): + geo_provider = pydantic_model.Geometry.from_dict({ + 'geometry_type': 'circular', + 'n_rho': 5, # override a default. + }).build_provider() + + self.assertIsInstance( + geo_provider, geometry_provider.ConstantGeometryProvider + ) + geo = geo_provider(t=0) + np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5) + np.testing.assert_array_equal(geo.B0, 5.3) # test a default. + + def test_build_geometry_from_chease(self): + geo_provider = pydantic_model.Geometry.from_dict( + { + 'geometry_type': 'chease', + 'n_rho': 5, # override a default. + }, + ).build_provider() + self.assertIsInstance( + geo_provider, geometry_provider.ConstantGeometryProvider + ) + self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry) + np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5) + + def test_build_time_dependent_geometry_from_chease(self): + """Tests correctness of config constraints with time-dependent geometry.""" + + base_config = { + 'geometry_type': 'chease', + 'Ip_from_parameters': True, + 'n_rho': 10, # overrides the default + 'geometry_configs': { + 0.0: { + 'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', + 'Rmaj': 6.2, + 'Rmin': 2.0, + 'B0': 5.3, + }, + 1.0: { + 'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', + 'Rmaj': 6.2, + 'Rmin': 2.0, + 'B0': 5.3, + }, + }, + } + + # Test valid config + geo_provider = pydantic_model.Geometry.from_dict( + base_config + ).build_provider() + self.assertIsInstance( + geo_provider, standard_geometry.StandardGeometryProvider + ) + self.assertIsInstance(geo_provider(t=0), standard_geometry.StandardGeometry) + np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 10) + + @parameterized.parameters([ + dict(param='n_rho', value=5), + dict(param='Ip_from_parameters', value=True), + dict(param='geometry_dir', value='.'), + ]) + def test_build_time_dependent_geometry_from_chease_failure( + self, param, value + ): + + base_config = { + 'geometry_type': 'chease', + 'Ip_from_parameters': True, + 'n_rho': 10, # overrides the default + 'geometry_configs': { + 0.0: { + 'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', + 'Rmaj': 6.2, + 'Rmin': 2.0, + 'B0': 5.3, + }, + 1.0: { + 'geometry_file': 'ITER_hybrid_citrin_equil_cheasedata.mat2cols', + 'Rmaj': 6.2, + 'Rmin': 2.0, + 'B0': 5.3, + }, + }, + } + + # Test invalid configs: + for time_key in [0.0, 1.0]: + invalid_config = base_config.copy() + invalid_config['geometry_configs'][time_key][param] = value + with self.assertRaisesRegex( + ValueError, 'following parameters cannot be set per geometry_config' + ): + pydantic_model.Geometry.from_dict(invalid_config) + + # pylint: disable=invalid-name + def test_chease_geometry_updates_Ip(self): + """Tests that the Ip is updated when using chease geometry.""" + runtime_params = runtime_params_lib.GeneralRuntimeParams() + original_Ip_tot = runtime_params.profile_conditions.Ip_tot + geo_provider = pydantic_model.Geometry.from_dict({ + 'geometry_type': 'chease', + 'Ip_from_parameters': ( + False + ), # this will force update runtime_params.Ip_tot + }).build_provider() + runtime_params_provider = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport=transport_model_params.RuntimeParams(), + sources={}, + stepper=stepper_params.RuntimeParams(), + torax_mesh=geo_provider.torax_mesh, + ) + ) + dynamic_slice, geo = ( + runtime_params_slice.get_consistent_dynamic_runtime_params_slice_and_geometry( + t=0, + dynamic_runtime_params_slice_provider=runtime_params_provider, + geometry_provider=geo_provider, + ) + ) + self.assertIsInstance(geo, standard_geometry.StandardGeometry) + self.assertIsNotNone(dynamic_slice) + self.assertNotEqual( + dynamic_slice.profile_conditions.Ip_tot, original_Ip_tot + ) + # pylint: enable=invalid-name + + @parameterized.parameters([ + dict(config=pydantic_model.CheaseConfig), + dict(config=pydantic_model.CircularConfig), + ]) + def test_rmin_rmax_ordering(self, config): + + with self.subTest('rmin_greater_than_rmaj'): + with self.assertRaisesRegex( + ValueError, 'Rmin must be less than or equal to Rmaj' + ): + config(Rmaj=1.0, Rmin=2.0) + + with self.subTest('negative_values'): + with self.assertRaises(ValueError): + config(Rmaj=-1.0, Rmin=-2.0) + + def test_failed_test(self): + config = { + 'geometry_type': 'eqdsk', + 'geometry_file': 'EQDSK_ITERhybrid_COCOS02.eqdsk', + 'Ip_from_parameters': True, + 'last_surface_factor': 0.99, + 'n_surfaces': 100, + } + pydantic_model.Geometry.from_dict(config) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/geometry/tests/standard_geometry_test.py b/torax/geometry/tests/standard_geometry_test.py index 302c4af9..5bd309e7 100644 --- a/torax/geometry/tests/standard_geometry_test.py +++ b/torax/geometry/tests/standard_geometry_test.py @@ -18,9 +18,9 @@ from absl.testing import parameterized import jax import numpy as np -from torax.config import build_sim from torax.geometry import geometry from torax.geometry import geometry_loader +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.geometry import standard_geometry # Internal import. @@ -67,8 +67,7 @@ def foo(geo: geometry.Geometry): foo(geo) def test_build_geometry_from_chease(self): - intermediate = standard_geometry.StandardGeometryIntermediates.from_chease() - standard_geometry.build_standard_geometry(intermediate) + geometry_pydantic_model.CheaseConfig().build_geometry() @parameterized.parameters([ dict(geometry_file='eqdsk_cocos02.eqdsk'), @@ -76,15 +75,12 @@ def test_build_geometry_from_chease(self): ]) def test_build_geometry_from_eqdsk(self, geometry_file): """Test that EQDSK geometries can be built.""" - intermediate = standard_geometry.StandardGeometryIntermediates.from_eqdsk( - geometry_file=geometry_file - ) - standard_geometry.build_standard_geometry(intermediate) + config = geometry_pydantic_model.EQDSKConfig(geometry_file=geometry_file) + config.build_geometry() def test_access_z_magnetic_axis_raises_error_for_chease_geometry(self): """Test that accessing z_magnetic_axis raises error for CHEASE geometry.""" - intermediate = standard_geometry.StandardGeometryIntermediates.from_chease() - geo = standard_geometry.build_standard_geometry(intermediate) + geo = geometry_pydantic_model.CheaseConfig().build_geometry() with self.assertRaisesRegex(ValueError, 'does not have a z magnetic axis'): geo.z_magnetic_axis() diff --git a/torax/pedestal_model/tests/set_pped_tpedratio_nped.py b/torax/pedestal_model/tests/set_pped_tpedratio_nped.py index 33180b3f..1870f9db 100644 --- a/torax/pedestal_model/tests/set_pped_tpedratio_nped.py +++ b/torax/pedestal_model/tests/set_pped_tpedratio_nped.py @@ -19,7 +19,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_pped_tpedratio_nped from torax.sources import source_models as source_models_lib @@ -31,7 +31,7 @@ class SetPressureTemperatureRatioAndDensityPedestalModelTest( def test_runtime_params_builds_dynamic_params(self): runtime_params = set_pped_tpedratio_nped.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -61,14 +61,14 @@ def test_build_and_call_pedestal_model( runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, sources=source_models_builder.runtime_params, torax_mesh=geo.torax_mesh, pedestal=pedestal_runtime_params, ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() builder = set_pped_tpedratio_nped.SetPressureTemperatureRatioAndDensityPedestalModelBuilder( runtime_params=pedestal_runtime_params ) diff --git a/torax/pedestal_model/tests/set_tped_nped.py b/torax/pedestal_model/tests/set_tped_nped.py index 51dd823c..b09432de 100644 --- a/torax/pedestal_model/tests/set_tped_nped.py +++ b/torax/pedestal_model/tests/set_tped_nped.py @@ -18,7 +18,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -28,7 +28,7 @@ class SetTemperatureDensityPedestalModelTest(parameterized.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = set_tped_nped.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -62,14 +62,14 @@ def test_build_and_call_pedestal_model( runtime_params = general_runtime_params.GeneralRuntimeParams() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, sources=source_models_builder.runtime_params, torax_mesh=geo.torax_mesh, pedestal=pedestal_runtime_params, ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() builder = set_tped_nped.SetTemperatureDensityPedestalModelBuilder( runtime_params=pedestal_runtime_params ) diff --git a/torax/sources/tests/bootstrap_current_source_test.py b/torax/sources/tests/bootstrap_current_source_test.py index 0dbe0059..5b60eb8b 100644 --- a/torax/sources/tests/bootstrap_current_source_test.py +++ b/torax/sources/tests/bootstrap_current_source_test.py @@ -21,7 +21,7 @@ from torax.config import plasma_composition from torax.config import runtime_params_slice from torax.fvm import cell_variable -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import bootstrap_current_source from torax.sources import runtime_params from torax.sources import source_profiles @@ -35,7 +35,9 @@ def setUp(self): self.source_name = ( bootstrap_current_source.BootstrapCurrentSource.SOURCE_NAME ) - self.geo = circular_geometry.build_circular_geometry(n_rho=n_rho) + self.geo = geometry_pydantic_model.CircularConfig( + n_rho=n_rho + ).build_geometry() dynamic_bootstap_params = bootstrap_current_source.DynamicRuntimeParams( prescribed_values=mock.ANY, bootstrap_mult=1.0, @@ -98,7 +100,8 @@ def test_get_bootstrap(self): ) self.assertEqual(bootstrap_profile.sigma.shape, self.geo.rho.shape) self.assertEqual( - bootstrap_profile.sigma_face.shape, self.geo.rho_face.shape) + bootstrap_profile.sigma_face.shape, self.geo.rho_face.shape + ) self.assertEqual(bootstrap_profile.j_bootstrap.shape, self.geo.rho.shape) self.assertEqual( bootstrap_profile.j_bootstrap_face.shape, self.geo.rho_face.shape @@ -133,7 +136,8 @@ def test_get_bootstrap_with_zero_mode(self): ) self.assertEqual(bootstrap_profile.sigma.shape, self.geo.rho.shape) self.assertEqual( - bootstrap_profile.sigma_face.shape, self.geo.rho_face.shape) + bootstrap_profile.sigma_face.shape, self.geo.rho_face.shape + ) self.assertEqual(bootstrap_profile.j_bootstrap.shape, self.geo.rho.shape) self.assertEqual( bootstrap_profile.j_bootstrap_face.shape, self.geo.rho_face.shape @@ -158,8 +162,7 @@ def test_prescribed_mode_not_supported(self): runtime_params_slice.StaticRuntimeParamsSlice, sources={source.SOURCE_NAME: static_bootstap_params}, ) - with self.assertRaisesRegex( - NotImplementedError, 'Prescribed mode'): + with self.assertRaisesRegex(NotImplementedError, 'Prescribed mode'): source.get_bootstrap( self.dynamic_params, static_params, diff --git a/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py b/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py index c9817259..54e8a696 100644 --- a/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py +++ b/torax/sources/tests/constant_fraction_impurity_radiation_heat_sink_test.py @@ -17,7 +17,7 @@ import chex from torax import math_utils from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import generic_ion_el_heat_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source_profiles @@ -29,7 +29,8 @@ class ImpurityRadiationConstantFractionTest( - test_lib.SingleProfileSourceTestCase): + test_lib.SingleProfileSourceTestCase +): @classmethod def setUpClass(cls): @@ -71,19 +72,22 @@ def test_source_value(self): dynamic_slice = mock.create_autospec( runtime_params_slice.DynamicRuntimeParamsSlice, - sources={heat_name: heat_dynamic, - impurity_name: impurity_radiation_dynamic}) + sources={ + heat_name: heat_dynamic, + impurity_name: impurity_radiation_dynamic, + }, + ) static_slice = mock.create_autospec( runtime_params_slice.StaticRuntimeParamsSlice, - sources={heat_name: static, impurity_name: static} + sources={heat_name: static, impurity_name: static}, ) heat_source = generic_ion_el_heat_source.GenericIonElectronHeatSource( model_func=generic_ion_el_heat_source.default_formula, ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() el, ion = heat_source.get_value( static_slice, dynamic_slice, @@ -107,7 +111,7 @@ def test_source_value(self): qei=mock.ANY, temp_el={'foo': el}, temp_ion={'foo_source': ion}, - ) + ), ) ) diff --git a/torax/sources/tests/electron_cyclotron_source_test.py b/torax/sources/tests/electron_cyclotron_source_test.py index b765e09d..54c5deb1 100644 --- a/torax/sources/tests/electron_cyclotron_source_test.py +++ b/torax/sources/tests/electron_cyclotron_source_test.py @@ -16,7 +16,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import electron_cyclotron_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -50,7 +50,7 @@ def test_source_value(self): source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, diff --git a/torax/sources/tests/ion_cyclotron_source_test.py b/torax/sources/tests/ion_cyclotron_source_test.py index a1547d5c..65ad8026 100644 --- a/torax/sources/tests/ion_cyclotron_source_test.py +++ b/torax/sources/tests/ion_cyclotron_source_test.py @@ -24,7 +24,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import ion_cyclotron_source from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -134,7 +134,7 @@ def test_source_value(self, mock_path): source_builder = self._source_class_builder() # pytype: disable=missing-parameter # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( {ion_cyclotron_source.IonCyclotronSource.SOURCE_NAME: source_builder}, ) diff --git a/torax/sources/tests/mavrin_impurity_radiation_heat_sink_test.py b/torax/sources/tests/mavrin_impurity_radiation_heat_sink_test.py index 04073ca9..cd4f8e8b 100644 --- a/torax/sources/tests/mavrin_impurity_radiation_heat_sink_test.py +++ b/torax/sources/tests/mavrin_impurity_radiation_heat_sink_test.py @@ -1,4 +1,3 @@ - # Copyright 2024 DeepMind Technologies Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,7 +17,7 @@ from torax.config import plasma_composition from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import generic_ion_el_heat_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib @@ -82,7 +81,7 @@ def test_source_value(self): self.assertIsInstance(impurity_radiation_sink, source_lib.Source) # Geometry, profiles, and dynamic runtime params - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( diff --git a/torax/sources/tests/qei_source_test.py b/torax/sources/tests/qei_source_test.py index 8355d00f..2fd34ac7 100644 --- a/torax/sources/tests/qei_source_test.py +++ b/torax/sources/tests/qei_source_test.py @@ -15,7 +15,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import qei_source from torax.sources import source_models as source_models_lib from torax.sources.tests import test_lib @@ -42,7 +42,7 @@ def test_source_value(self): source_models = source_models_builder() source = source_models.sources['qei_source'] runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() static_slice = runtime_params_slice.build_static_runtime_params_slice( runtime_params=runtime_params, source_runtime_params=source_models_builder.runtime_params, diff --git a/torax/sources/tests/source_operations_test.py b/torax/sources/tests/source_operations_test.py index 885dff2b..252395fe 100644 --- a/torax/sources/tests/source_operations_test.py +++ b/torax/sources/tests/source_operations_test.py @@ -18,7 +18,7 @@ import jax import jax.numpy as jnp import numpy as np -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import source as source_lib from torax.sources import source_operations from torax.sources import source_profiles as source_profiles_lib @@ -45,7 +45,7 @@ def affected_core_profiles( class SourceOperationsTest(parameterized.TestCase): def test_summed_temp_ion_profiles_dont_change_when_jitting(self): - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() # Make some dummy source profiles that could have come from these sources. ones = jnp.ones_like(geo.rho) @@ -85,5 +85,6 @@ def test_summed_temp_ion_profiles_dont_change_when_jitting(self): jitted_temp_el = sum_temp_el(geo, profiles) np.testing.assert_allclose(jitted_temp_el, ones * 10 * geo.vpr) + if __name__ == '__main__': absltest.main() diff --git a/torax/sources/tests/source_profile_builders_test.py b/torax/sources/tests/source_profile_builders_test.py index 98a5686b..907ea5dc 100644 --- a/torax/sources/tests/source_profile_builders_test.py +++ b/torax/sources/tests/source_profile_builders_test.py @@ -21,7 +21,7 @@ from torax import core_profile_setters from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import runtime_params as source_runtime_params from torax.sources import source from torax.sources import source_models as source_models_lib @@ -34,7 +34,7 @@ class SourceModelsTest(parameterized.TestCase): def setUp(self): super().setUp() - self.geo = circular_geometry.build_circular_geometry() + self.geo = geometry_pydantic_model.CircularConfig().build_geometry() def test_computing_source_profiles_works_with_all_defaults(self): """Tests that you can compute source profiles with all defaults.""" diff --git a/torax/sources/tests/source_runtime_params_test.py b/torax/sources/tests/source_runtime_params_test.py index 955178f9..026e09a2 100644 --- a/torax/sources/tests/source_runtime_params_test.py +++ b/torax/sources/tests/source_runtime_params_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from absl.testing import absltest -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import runtime_params as runtime_params_lib @@ -20,7 +20,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = runtime_params_lib.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) dynamic_params = provider.build_dynamic_params(t=0.0) self.assertIsInstance( diff --git a/torax/sources/tests/source_test.py b/torax/sources/tests/source_test.py index 72211716..95de5d3b 100644 --- a/torax/sources/tests/source_test.py +++ b/torax/sources/tests/source_test.py @@ -19,7 +19,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -158,7 +158,7 @@ def test_zero_profile_works_by_default(self): source_models = source_models_builder() source = source_models.sources['foo'] runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -187,8 +187,7 @@ def test_zero_profile_works_by_default(self): calculated_source_profiles=None, ) np.testing.assert_allclose( - profile[0], - np.zeros_like(geo.torax_mesh.cell_centers) + profile[0], np.zeros_like(geo.torax_mesh.cell_centers) ) @parameterized.parameters( @@ -208,7 +207,7 @@ def test_correct_mode_called(self, mode, expected_profile): source = source_models.sources['foo'] source_runtime_params = source_models_builder.runtime_params runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry(n_rho=4) + geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() source_runtime_params['foo'] = dataclasses.replace( source_models_builder.runtime_params['foo'], mode=mode, @@ -254,7 +253,7 @@ def test_defaults_output_zeros(self): source_models = source_models_builder() source = source_models.sources['foo'] runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params, @@ -319,7 +318,7 @@ def test_defaults_output_zeros(self): def test_overriding_model(self): """The user-specified model should override the default model.""" - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() expected_output = (jnp.ones_like(geo.rho),) source_builder = source_lib.make_source_builder( IonElTestSource, @@ -363,7 +362,7 @@ def test_overriding_model(self): def test_overriding_prescribed_values(self): """Providing prescribed values results in the correct profile.""" - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() # Define the expected output expected_output = (jnp.ones_like(geo.rho),) # Create the source diff --git a/torax/sources/tests/test_lib.py b/torax/sources/tests/test_lib.py index 6b9df6e0..0ca5e721 100644 --- a/torax/sources/tests/test_lib.py +++ b/torax/sources/tests/test_lib.py @@ -23,7 +23,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import runtime_params as runtime_params_lib from torax.sources import source as source_lib from torax.sources import source_models as source_models_lib @@ -91,7 +91,7 @@ def setUpClass( def test_runtime_params_builds_dynamic_params(self): runtime_params = self._runtime_params_class() self.assertIsInstance(runtime_params, runtime_params_lib.RuntimeParams) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) dynamic_params = provider.build_dynamic_params(t=0.0) self.assertIsInstance( @@ -140,7 +140,7 @@ def test_source_value_on_the_cell_grid(self): source = source_models.sources[self._source_name] source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED self.assertIsInstance(source, source_lib.Source) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dynamic_runtime_params_slice = ( runtime_params_slice.DynamicRuntimeParamsSliceProvider( runtime_params=runtime_params, @@ -168,7 +168,7 @@ def test_source_value_on_the_cell_grid(self): temp_el={'foo_source': jnp.full(geo.rho.shape, 17.0)}, temp_ion={'foo_sink': jnp.full(geo.rho.shape, 19.0)}, ne={}, - qei=source_profiles.QeiInfo.zeros(geo) + qei=source_profiles.QeiInfo.zeros(geo), ) else: calculated_source_profiles = None @@ -192,7 +192,7 @@ def test_source_values_on_the_cell_grid(self): source_builder = self._source_class_builder() # pylint: enable=missing-kwoa runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder( {self._source_name: source_builder}, ) diff --git a/torax/tests/boundary_conditions.py b/torax/tests/boundary_conditions.py index 84258540..d4f01c11 100644 --- a/torax/tests/boundary_conditions.py +++ b/torax/tests/boundary_conditions.py @@ -24,7 +24,7 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import source_models as source_models_lib @@ -69,7 +69,7 @@ def test_setting_boundary_conditions( ), ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() initial_dynamic_runtime_params_slice = ( diff --git a/torax/tests/core_profile_setters_test.py b/torax/tests/core_profile_setters_test.py index 5cae444d..874c4d4a 100644 --- a/torax/tests/core_profile_setters_test.py +++ b/torax/tests/core_profile_setters_test.py @@ -25,7 +25,7 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib from torax.fvm import cell_variable -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import source_models as source_models_lib from torax.stepper import runtime_params as stepper_params_lib from torax.transport_model import runtime_params as transport_params_lib @@ -41,7 +41,7 @@ class CoreProfileSettersTest(parameterized.TestCase): def setUp(self): super().setUp() jax_utils.enable_errors(True) - self.geo = circular_geometry.build_circular_geometry(n_rho=4) + self.geo = geometry_pydantic_model.CircularConfig(n_rho=4).build_geometry() def test_updated_ion_temperature(self): bound = np.array(42.0) diff --git a/torax/tests/math_utils.py b/torax/tests/math_utils.py index fc202a44..d7cec22d 100644 --- a/torax/tests/math_utils.py +++ b/torax/tests/math_utils.py @@ -21,8 +21,8 @@ import numpy as np import scipy.integrate from torax import math_utils -from torax.geometry import circular_geometry from torax.geometry import geometry +from torax.geometry import pydantic_model as geometry_pydantic_model jax.config.update('jax_enable_x64', True) @@ -31,7 +31,7 @@ class MathUtilsTest(parameterized.TestCase): """Unit tests for the `torax.math_utils` module.""" @parameterized.product( - initial=(None, 0.), + initial=(None, 0.0), axis=(-1, 1, -1), array_x=(False, True), dtype=(jnp.float32, jnp.float64), @@ -73,7 +73,9 @@ def test_cell_integration(self, num_cell_grid_points: int): x = jax.random.uniform( jax.random.PRNGKey(0), shape=(num_cell_grid_points + 1,) ) - geo = circular_geometry.build_circular_geometry(n_rho=num_cell_grid_points) + geo = geometry_pydantic_model.CircularConfig( + n_rho=num_cell_grid_points + ).build_geometry() np.testing.assert_allclose( math_utils.cell_integration(geometry.face_to_cell(x), geo), @@ -144,7 +146,9 @@ def test_cell_to_face( preserved_quantity: math_utils.IntegralPreservationQuantity, ): """Test that the cell_to_face method works as expected.""" - geo = circular_geometry.build_circular_geometry(n_rho=len(cell_values)) + geo = geometry_pydantic_model.CircularConfig( + n_rho=len(cell_values) + ).build_geometry() cell_values = jnp.array(cell_values, dtype=jnp.float32) face_values = math_utils.cell_to_face(cell_values, geo, preserved_quantity) @@ -175,9 +179,11 @@ def test_cell_to_face( ), ) - def test_cell_to_face_raises_when_too_few_values(self,): + def test_cell_to_face_raises_when_too_few_values( + self, + ): """Test that the cell_to_face method raises when too few values are provided.""" - geo = circular_geometry.build_circular_geometry(n_rho=1) + geo = geometry_pydantic_model.CircularConfig(n_rho=1).build_geometry() with self.assertRaises(ValueError): math_utils.cell_to_face(jnp.array([1.0], dtype=np.float32), geo) diff --git a/torax/tests/output.py b/torax/tests/output.py index 75d7cebb..54dd8cb6 100644 --- a/torax/tests/output.py +++ b/torax/tests/output.py @@ -28,8 +28,8 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice as runtime_params_slice_lib -from torax.geometry import circular_geometry from torax.geometry import geometry_provider +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import source_profiles as source_profiles_lib from torax.tests.test_lib import default_sources from torax.tests.test_lib import torax_refs @@ -56,7 +56,7 @@ def setUp(self): source_models_builder = default_sources.get_default_sources_builder() source_models = source_models_builder() # Make some dummy source profiles that could have come from these sources. - self.geo = circular_geometry.build_circular_geometry() + self.geo = geometry_pydantic_model.CircularConfig().build_geometry() ones = jnp.ones_like(self.geo.rho) geo_provider = geometry_provider.ConstantGeometryProvider(self.geo) dynamic_runtime_params_slice, geo = ( diff --git a/torax/tests/physics.py b/torax/tests/physics.py index bb7e24f6..7e28b400 100644 --- a/torax/tests/physics.py +++ b/torax/tests/physics.py @@ -28,7 +28,7 @@ from torax import physics from torax import state from torax.fvm import cell_variable -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.geometry import standard_geometry from torax.sources import generic_current_source from torax.sources import source_profiles @@ -310,14 +310,14 @@ def test_get_main_ion_dilution_factor(self, Zi, Zimp, Zeff, expected): def test_calculate_plh_scaling_factor(self): """Compare `calculate_plh_scaling_factor` to a reference value.""" - geo = circular_geometry.build_circular_geometry( + geo = geometry_pydantic_model.CircularConfig( n_rho=25, elongation_LCFS=1.0, hires_fac=4, Rmaj=6.0, Rmin=2.0, B0=5.0, - ) + ).build_geometry() core_profiles = state.CoreProfiles( ne=cell_variable.CellVariable( value=jnp.ones_like(geo.rho_norm) * 2, @@ -411,14 +411,14 @@ def test_calculate_plh_scaling_factor(self): # pylint: disable=invalid-name def test_calculate_scaling_law_confinement_time(self, elongation_LCFS): """Compare `calculate_scaling_law_confinement_time` to reference values.""" - geo = circular_geometry.build_circular_geometry( + geo = geometry_pydantic_model.CircularConfig( n_rho=25, elongation_LCFS=elongation_LCFS, hires_fac=4, Rmaj=6.0, Rmin=2.0, B0=5.0, - ) + ).build_geometry() core_profiles = state.CoreProfiles( ne=cell_variable.CellVariable( value=jnp.ones_like(geo.rho_norm) * 2, @@ -563,13 +563,13 @@ def test_calc_Wpol(self): # Small inverse aspect ratio limit of circular geometry, such that we # approximate the simplest form of circular geometry where the analytical # Bpol formula is applicable. - geo = circular_geometry.build_circular_geometry( + geo = geometry_pydantic_model.CircularConfig( n_rho=25, elongation_LCFS=1.0, Rmaj=100.0, Rmin=1.0, B0=5.0, - ) + ).build_geometry() Ip_tot = 15 # calculate high resolution jtot consistent with total current profile jtot_profile = (1 - geo.rho_hires_norm**2) ** 2 diff --git a/torax/tests/post_processing.py b/torax/tests/post_processing.py index 0f61be42..c16dc992 100644 --- a/torax/tests/post_processing.py +++ b/torax/tests/post_processing.py @@ -29,9 +29,9 @@ from torax.config import runtime_params as runtime_params_lib from torax.config import runtime_params_slice from torax.fvm import cell_variable -from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import source_profiles as source_profiles_lib from torax.tests.test_lib import default_sources from torax.tests.test_lib import sim_test_case @@ -44,7 +44,7 @@ class PostProcessingTest(parameterized.TestCase): def setUp(self): super().setUp() runtime_params = runtime_params_lib.GeneralRuntimeParams() - self.geo = circular_geometry.build_circular_geometry() + self.geo = geometry_pydantic_model.CircularConfig().build_geometry() geo_provider = geometry_provider.ConstantGeometryProvider(self.geo) source_models_builder = default_sources.get_default_sources_builder() source_models = source_models_builder() @@ -170,7 +170,7 @@ def _make_constant_core_profile( def test_compute_stored_thermal_energy(self): """Test that stored thermal energy is computed correctly.""" - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() p_el = np.ones_like(geo.rho_face) p_ion = 2 * np.ones_like(geo.rho_face) p_tot = p_el + p_ion diff --git a/torax/tests/sim.py b/torax/tests/sim.py index efcae4a7..90199f7c 100644 --- a/torax/tests/sim.py +++ b/torax/tests/sim.py @@ -30,8 +30,8 @@ from torax.config import build_sim as build_sim_lib from torax.config import numerics as numerics_lib from torax.config import runtime_params as runtime_params_lib -from torax.geometry import circular_geometry from torax.geometry import geometry_provider +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib from torax.stepper import linear_theta_method @@ -40,7 +40,6 @@ from torax.transport_model import constant as constant_transport_model import xarray as xr - _ALL_PROFILES = ('temp_ion', 'temp_el', 'psi', 'q_face', 's_face', 'ne') @@ -523,7 +522,7 @@ def test_no_op(self): time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() geo_provider = geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry() + geometry_pydantic_model.CircularConfig().build_geometry() ) sim = sim_lib.Sim.create( @@ -738,9 +737,9 @@ def test_update(self): 'test_iterhybrid_predictor_corrector_eqdsk.py' ).CONFIG sim.update_base_components( - geometry_provider=build_sim_lib.build_geometry_provider_from_config( + geometry_provider=geometry_pydantic_model.Geometry.from_dict( new_config['geometry'] - ) + ).build_provider() ) sim_outputs = sim.run() @@ -763,9 +762,9 @@ def test_update_new_mesh(self): sim = self._get_sim('test_iterhybrid_rampup.py') with self.assertRaisesRegex(ValueError, 'different mesh'): sim.update_base_components( - geometry_provider=geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry(n_rho=10) - ) + geometry_provider=geometry_pydantic_model.Geometry.from_dict( + {'geometry_type': 'circular', 'n_rho': 10} + ).build_provider() ) diff --git a/torax/tests/sim_custom_sources.py b/torax/tests/sim_custom_sources.py index 80596825..dd28cc4c 100644 --- a/torax/tests/sim_custom_sources.py +++ b/torax/tests/sim_custom_sources.py @@ -30,9 +30,9 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import electron_density_sources from torax.sources import runtime_params as runtime_params_lib @@ -116,19 +116,21 @@ def custom_source_formula( unused_calculated_source_profiles=unused_calculated_source_profiles, ) return ( - electron_density_sources.calc_puff_source( - source_name=electron_density_sources.GasPuffSource.SOURCE_NAME, - **kwargs - )[0] - + electron_density_sources.calc_generic_particle_source( - source_name=electron_density_sources.GenericParticleSource.SOURCE_NAME, - **kwargs - )[0] - + electron_density_sources.calc_pellet_source( - source_name=electron_density_sources.PelletSource.SOURCE_NAME, - **kwargs - )[0] - ), + ( + electron_density_sources.calc_puff_source( + source_name=electron_density_sources.GasPuffSource.SOURCE_NAME, + **kwargs, + )[0] + + electron_density_sources.calc_generic_particle_source( + source_name=electron_density_sources.GenericParticleSource.SOURCE_NAME, + **kwargs, + )[0] + + electron_density_sources.calc_pellet_source( + source_name=electron_density_sources.PelletSource.SOURCE_NAME, + **kwargs, + )[0] + ), + ) # First instantiate the same default sources that test_particle_sources # constant starts with. @@ -179,10 +181,8 @@ def custom_source_formula( pellet_deposition_location=pellet_params.pellet_deposition_location, S_pellet_tot=pellet_params.S_pellet_tot, ) - source_models_builder.source_builders[custom_source_name] = ( - source_builder( - runtime_params=runtime_params, - ) + source_models_builder.source_builders[custom_source_name] = source_builder( + runtime_params=runtime_params, ) # Load reference profiles @@ -190,7 +190,7 @@ def custom_source_formula( 'test_particle_sources_constant.nc', _ALL_PROFILES ) geo_provider = geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry() + geometry_pydantic_model.CircularConfig().build_geometry() ) sim = sim_lib.Sim.create( runtime_params=self.test_particle_sources_constant_runtime_params, diff --git a/torax/tests/sim_output_source_profiles.py b/torax/tests/sim_output_source_profiles.py index e19ce65b..7a7d5628 100644 --- a/torax/tests/sim_output_source_profiles.py +++ b/torax/tests/sim_output_source_profiles.py @@ -28,9 +28,9 @@ from torax import sim as sim_lib from torax import state as state_module from torax.config import runtime_params as general_runtime_params -from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.orchestration import step_function from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as runtime_params_lib @@ -157,10 +157,11 @@ def custom_source_formula( }) source_models = source_models_builder() runtime_params = general_runtime_params.GeneralRuntimeParams() - runtime_params.numerics.t_final = 2. - runtime_params.numerics.fixed_dt = 1. - geo = circular_geometry.build_circular_geometry() + runtime_params.numerics.t_final = 2.0 + runtime_params.numerics.fixed_dt = 1.0 + geo = geometry_pydantic_model.CircularConfig().build_geometry() time_stepper = fixed_time_step_calculator.FixedTimeStepCalculator() + def mock_step_fn( _, static_runtime_params_slice, @@ -170,18 +171,22 @@ def mock_step_fn( ): dt = 1.0 new_t = input_state.t + dt - return dataclasses.replace( - input_state, - t=new_t, - dt=dt, - time_step_calculator_state=(), - core_sources=source_profile_builders.get_initial_source_profiles( - static_runtime_params_slice, - dynamic_runtime_params_slice_provider(new_t), - geometry_provider(new_t), - core_profiles=input_state.core_profiles, - source_models=source_models), - ), state_module.SimError.NO_ERROR + return ( + dataclasses.replace( + input_state, + t=new_t, + dt=dt, + time_step_calculator_state=(), + core_sources=source_profile_builders.get_initial_source_profiles( + static_runtime_params_slice, + dynamic_runtime_params_slice_provider(new_t), + geometry_provider(new_t), + core_profiles=input_state.core_profiles, + source_models=source_models, + ), + ), + state_module.SimError.NO_ERROR, + ) sim = sim_lib.Sim.create( runtime_params=runtime_params, diff --git a/torax/tests/sim_time_dependence.py b/torax/tests/sim_time_dependence.py index 6045bbad4..a5382599 100644 --- a/torax/tests/sim_time_dependence.py +++ b/torax/tests/sim_time_dependence.py @@ -29,9 +29,9 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -69,7 +69,7 @@ def test_time_dependent_params_update_in_adaptive_dt( dt_reduction_factor=1.5, ), ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() geometry_provider = geometry_provider_lib.ConstantGeometryProvider(geo) transport_builder = FakeTransportModelBuilder() source_models_builder = source_models_lib.SourceModelsBuilder() diff --git a/torax/tests/state.py b/torax/tests/state.py index 8898e235..f159ca36 100644 --- a/torax/tests/state.py +++ b/torax/tests/state.py @@ -30,10 +30,8 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry -from torax.geometry import geometry from torax.geometry import geometry_provider -from torax.geometry import standard_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import generic_current_source from torax.sources import runtime_params as runtime_params_lib from torax.sources import source_models as source_models_lib @@ -174,7 +172,7 @@ def test_initial_boundary_condition_from_time_dependent_params(self): source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() geo_provider = geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry() + geometry_pydantic_model.CircularConfig().build_geometry() ) dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( @@ -208,7 +206,7 @@ def test_core_profiles_quasineutrality_check(self): source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() geo_provider = geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry() + geometry_pydantic_model.CircularConfig().build_geometry() ) dynamic_runtime_params_slice, geo = ( torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry( @@ -236,16 +234,12 @@ def test_core_profiles_quasineutrality_check(self): assert not core_profiles.quasineutrality_satisfied() @parameterized.parameters([ - dict(geo_builder=circular_geometry.build_circular_geometry), - dict( - geo_builder=lambda: standard_geometry.build_standard_geometry( - standard_geometry.StandardGeometryIntermediates.from_chease() - ) - ), + dict(geometry_name='circular'), + dict(geometry_name='chease'), ]) def test_initial_psi_from_j( self, - geo_builder: Callable[[], geometry.Geometry], + geometry_name: str, ): """Tests expected behaviour of initial psi and current options.""" config1 = general_runtime_params.GeneralRuntimeParams( @@ -281,7 +275,9 @@ def test_initial_psi_from_j( ne_bound_right=0.5, ), ) - geo_provider = geometry_provider.ConstantGeometryProvider(geo_builder()) + geo_provider = geometry_pydantic_model.Geometry.from_dict( + {'geometry_type': geometry_name} + ).build_provider() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() source_models_builder.runtime_params['j_bootstrap'].bootstrap_mult = 0.0 @@ -373,7 +369,8 @@ def test_initial_psi_from_j( ctot = config1.profile_conditions.Ip_tot * 1e6 / denom jtot_formula = jformula * ctot johm_formula = jtot_formula * ( - 1 - dcs1.sources[ + 1 + - dcs1.sources[ generic_current_source.GenericCurrentSource.SOURCE_NAME ].fext # pytype: disable=attribute-error ) @@ -445,7 +442,7 @@ def test_initial_psi_from_geo_noop_circular(self): ne_bound_right=0.5, ), ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() dcs1 = runtime_params_slice.DynamicRuntimeParamsSliceProvider( config1, sources=source_models_builder.runtime_params, @@ -474,7 +471,7 @@ def test_initial_psi_from_geo_noop_circular(self): core_profiles1 = core_profile_setters.initial_core_profiles( dynamic_runtime_params_slice=dcs1, static_runtime_params_slice=static_slice, - geo=circular_geometry.build_circular_geometry(), + geo=geometry_pydantic_model.CircularConfig().build_geometry(), source_models=source_models, ) static_slice = runtime_params_slice.build_static_runtime_params_slice( @@ -485,7 +482,7 @@ def test_initial_psi_from_geo_noop_circular(self): core_profiles2 = core_profile_setters.initial_core_profiles( dynamic_runtime_params_slice=dcs2, static_runtime_params_slice=static_slice, - geo=circular_geometry.build_circular_geometry(), + geo=geometry_pydantic_model.CircularConfig().build_geometry(), source_models=source_models, ) np.testing.assert_allclose( diff --git a/torax/tests/test_data/test_explicit.py b/torax/tests/test_data/test_explicit.py index a022df2b..08e2d750 100644 --- a/torax/tests/test_data/test_explicit.py +++ b/torax/tests/test_data/test_explicit.py @@ -19,8 +19,8 @@ from torax.config import numerics as numerics_lib from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params -from torax.geometry import circular_geometry from torax.geometry import geometry_provider +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped from torax.sources import runtime_params as source_runtime_params @@ -49,7 +49,7 @@ def get_runtime_params() -> general_runtime_params.GeneralRuntimeParams: def get_geometry_provider() -> geometry_provider.ConstantGeometryProvider: return geometry_provider.ConstantGeometryProvider( - circular_geometry.build_circular_geometry() + geometry_pydantic_model.CircularConfig().build_geometry() ) diff --git a/torax/tests/test_lib/torax_refs.py b/torax/tests/test_lib/torax_refs.py index e34ffef0..7e0f89cd 100644 --- a/torax/tests/test_lib/torax_refs.py +++ b/torax/tests/test_lib/torax_refs.py @@ -26,10 +26,9 @@ from torax.config import config_args from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry from torax.geometry import geometry from torax.geometry import geometry_provider as geometry_provider_lib -from torax.geometry import standard_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.sources import runtime_params as sources_params from torax.stepper import runtime_params as stepper_params from torax.transport_model import runtime_params as transport_model_params @@ -88,14 +87,14 @@ def circular_references() -> References: }, }, ) - geo = circular_geometry.build_circular_geometry( + geo = geometry_pydantic_model.CircularConfig( n_rho=25, elongation_LCFS=1.72, hires_fac=4, Rmaj=6.2, Rmin=2.0, B0=5.3, - ) + ).build_geometry() # ground truth values copied from example executions using # array.astype(str),which allows fully lossless reloading psi = fvm.cell_variable.CellVariable( @@ -237,17 +236,15 @@ def chease_references_Ip_from_chease() -> References: # pylint: disable=invalid }, }, ) - geo = standard_geometry.build_standard_geometry( - standard_geometry.StandardGeometryIntermediates.from_chease( - geometry_dir=_GEO_DIRECTORY, - geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', - n_rho=25, - Ip_from_parameters=False, - Rmaj=6.2, - Rmin=2.0, - B0=5.3, - ) - ) + geo = geometry_pydantic_model.CheaseConfig( + geometry_dir=_GEO_DIRECTORY, + geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', + n_rho=25, + Ip_from_parameters=False, + Rmaj=6.2, + Rmin=2.0, + B0=5.3, + ).build_geometry() # ground truth values copied from an example PINT execution using # array.astype(str),which allows fully lossless reloading psi = fvm.cell_variable.CellVariable( @@ -389,17 +386,15 @@ def chease_references_Ip_from_runtime_params() -> References: # pylint: disable }, }, ) - geo = standard_geometry.build_standard_geometry( - standard_geometry.StandardGeometryIntermediates.from_chease( - geometry_dir=_GEO_DIRECTORY, - geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', - n_rho=25, - Ip_from_parameters=True, - Rmaj=6.2, - Rmin=2.0, - B0=5.3, - ) - ) + geo = geometry_pydantic_model.CheaseConfig( + geometry_dir=_GEO_DIRECTORY, + geometry_file='ITER_hybrid_citrin_equil_cheasedata.mat2cols', + n_rho=25, + Ip_from_parameters=True, + Rmaj=6.2, + Rmin=2.0, + B0=5.3, + ).build_geometry() # ground truth values copied from an example executions using # array.astype(str),which allows fully lossless reloading psi = fvm.cell_variable.CellVariable( diff --git a/torax/torax_pydantic/model_base.py b/torax/torax_pydantic/model_base.py index 014e9848..fc1c13d3 100644 --- a/torax/torax_pydantic/model_base.py +++ b/torax/torax_pydantic/model_base.py @@ -40,12 +40,20 @@ def _numpy_array_before_validator( x: np.ndarray | NumpySerialized, ) -> np.ndarray: + """Validates and converts a serialized NumPy array.""" if isinstance(x, np.ndarray): return x - else: + # This can be either a tuple or a list. The list case is if this is coming + # from JSON, which doesn't have a tuple type. + elif isinstance(x, tuple) or isinstance(x, list) and len(x) == 2: dtype, data = x return np.array(data, dtype=np.dtype(dtype)) + else: + raise ValueError( + 'Expected NumPy or a tuple representing a serialized NumPy array, but' + f' got a {type(x)}' + ) def _numpy_array_serializer(x: np.ndarray) -> NumpySerialized: diff --git a/torax/torax_pydantic/model_config.py b/torax/torax_pydantic/model_config.py index 2adebf42..3ef14608 100644 --- a/torax/torax_pydantic/model_config.py +++ b/torax/torax_pydantic/model_config.py @@ -14,6 +14,7 @@ """Pydantic config for Torax.""" +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pydantic_model as pedestal_model_config from torax.time_step_calculator import config as time_step_calculator_config from torax.torax_pydantic import model_base @@ -27,5 +28,6 @@ class ToraxConfig(model_base.BaseModelMutable): pedestal: Config for the pedestal model. """ - time_step_calculator: time_step_calculator_config.TimeStepCalculator + geometry: geometry_pydantic_model.Geometry pedestal: pedestal_model_config.PedestalModel + time_step_calculator: time_step_calculator_config.TimeStepCalculator diff --git a/torax/transport_model/tests/bohm_gyrobohm.py b/torax/transport_model/tests/bohm_gyrobohm.py index e1a4b99a..edcff32e 100644 --- a/torax/transport_model/tests/bohm_gyrobohm.py +++ b/torax/transport_model/tests/bohm_gyrobohm.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for bohm_gyrobohm.""" from absl.testing import absltest -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.transport_model import bohm_gyrobohm @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = bohm_gyrobohm.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/constant.py b/torax/transport_model/tests/constant.py index 023ee024..dce5c0c8 100644 --- a/torax/transport_model/tests/constant.py +++ b/torax/transport_model/tests/constant.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for constant transport model.""" from absl.testing import absltest -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.transport_model import constant @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = constant.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/critical_gradient.py b/torax/transport_model/tests/critical_gradient.py index cb4876de..b3f05941 100644 --- a/torax/transport_model/tests/critical_gradient.py +++ b/torax/transport_model/tests/critical_gradient.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for critical gradient transport model.""" from absl.testing import absltest -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.transport_model import critical_gradient @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = critical_gradient.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/qlknn_transport_model.py b/torax/transport_model/tests/qlknn_transport_model.py index 23b06a53..1a5a0b46 100644 --- a/torax/transport_model/tests/qlknn_transport_model.py +++ b/torax/transport_model/tests/qlknn_transport_model.py @@ -21,7 +21,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib from torax.transport_model import qlknn_transport_model @@ -38,7 +38,7 @@ def test_qlknn_transport_model_cache_works(self): qlknn_transport_model.get_default_model_path() ) runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() pedestal_model_builder = ( @@ -171,7 +171,7 @@ def test_clip_inputs(self): def test_runtime_params_builds_dynamic_params(self): runtime_params = qlknn_transport_model.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) diff --git a/torax/transport_model/tests/qualikiz_based_transport_model.py b/torax/transport_model/tests/qualikiz_based_transport_model.py index 09978318..ccda8fc9 100644 --- a/torax/transport_model/tests/qualikiz_based_transport_model.py +++ b/torax/transport_model/tests/qualikiz_based_transport_model.py @@ -21,8 +21,8 @@ from torax import state from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry from torax.geometry import geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -34,7 +34,7 @@ def _get_model_inputs(transport: qualikiz_based_transport_model.RuntimeParams): """Returns the model inputs for testing.""" runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() pedestal_model_builder = ( diff --git a/torax/transport_model/tests/qualikiz_transport_model.py b/torax/transport_model/tests/qualikiz_transport_model.py index aa070c5a..ed917fa1 100644 --- a/torax/transport_model/tests/qualikiz_transport_model.py +++ b/torax/transport_model/tests/qualikiz_transport_model.py @@ -22,7 +22,7 @@ from torax import core_profile_setters from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pedestal_model from torax.sources import source_models as source_models_lib from torax.stepper import runtime_params as stepper_runtime_params @@ -31,6 +31,7 @@ # pylint: disable=g-import-not-at-top try: from torax.transport_model import qualikiz_transport_model + _QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = True except ImportError: _QUALIKIZ_TRANSPORT_MODEL_AVAILABLE = False @@ -43,7 +44,7 @@ def test_runtime_params_builds_dynamic_params(self): if not _QUALIKIZ_TRANSPORT_MODEL_AVAILABLE: self.skipTest('Qualikiz transport model is not available.') runtime_params = qualikiz_transport_model.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0) @@ -59,7 +60,7 @@ def test_call(self): os.environ['TORAX_COMPILATION_ENABLED'] = '0' # Building the model inputs. - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() runtime_params = general_runtime_params.GeneralRuntimeParams() diff --git a/torax/transport_model/tests/quasilinear_transport_model.py b/torax/transport_model/tests/quasilinear_transport_model.py index a8021e64..2bb08e84 100644 --- a/torax/transport_model/tests/quasilinear_transport_model.py +++ b/torax/transport_model/tests/quasilinear_transport_model.py @@ -25,8 +25,8 @@ from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice from torax.fvm import cell_variable -from torax.geometry import circular_geometry from torax.geometry import geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -40,7 +40,7 @@ def _get_model_inputs(transport: quasilinear_transport_model.RuntimeParams): """Returns the model inputs for testing.""" runtime_params = general_runtime_params.GeneralRuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() pedestal_model_builder = ( @@ -266,7 +266,7 @@ def _call_implementation( def _get_dummy_core_profiles(value, right_face_constraint): """Returns dummy core profiles for testing.""" - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() currents = state.Currents.zeros(geo) dummy_cell_variable = cell_variable.CellVariable( value=value, diff --git a/torax/transport_model/tests/transport_model.py b/torax/transport_model/tests/transport_model.py index 2551298d..59022a9a 100644 --- a/torax/transport_model/tests/transport_model.py +++ b/torax/transport_model/tests/transport_model.py @@ -25,8 +25,8 @@ from torax.config import profile_conditions as profile_conditions_lib from torax.config import runtime_params as general_runtime_params from torax.config import runtime_params_slice -from torax.geometry import circular_geometry from torax.geometry import geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.pedestal_model import pedestal_model as pedestal_model_lib from torax.pedestal_model import set_tped_nped from torax.sources import source_models as source_models_lib @@ -46,7 +46,7 @@ def test_smoothing(self): ne_bound_right=0.5, ), ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() transport_model_builder = FakeTransportModelBuilder( @@ -198,7 +198,7 @@ def test_smoothing_everywhere(self): ne_bound_right=0.5, ), ) - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() source_models_builder = source_models_lib.SourceModelsBuilder() source_models = source_models_builder() transport_model_builder = FakeTransportModelBuilder( diff --git a/torax/transport_model/tests/transport_model_runtime_params.py b/torax/transport_model/tests/transport_model_runtime_params.py index befb7861..d0880ff3 100644 --- a/torax/transport_model/tests/transport_model_runtime_params.py +++ b/torax/transport_model/tests/transport_model_runtime_params.py @@ -13,7 +13,7 @@ # limitations under the License. """Tests for runtime params for transport model.""" from absl.testing import absltest -from torax.geometry import circular_geometry +from torax.geometry import pydantic_model as geometry_pydantic_model from torax.transport_model import runtime_params as runtime_params_lib @@ -21,7 +21,7 @@ class RuntimeParamsTest(absltest.TestCase): def test_runtime_params_builds_dynamic_params(self): runtime_params = runtime_params_lib.RuntimeParams() - geo = circular_geometry.build_circular_geometry() + geo = geometry_pydantic_model.CircularConfig().build_geometry() provider = runtime_params.make_provider(geo.torax_mesh) provider.build_dynamic_params(t=0.0)