Skip to content

Commit

Permalink
Modify the pedestal pydantic config to work with "time varying inputs…
Browse files Browse the repository at this point in the history
…" and expand the config test to work with all the configs tested in the sim tests.

PiperOrigin-RevId: 731677550
  • Loading branch information
Nush395 authored and Torax team committed Feb 28, 2025
1 parent 0554b3c commit 8525d8f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 18 deletions.
37 changes: 27 additions & 10 deletions torax/pedestal_model/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import copy
from typing import Any, Literal

import pydantic
from torax.torax_pydantic import interpolated_param_1d
from torax.torax_pydantic import torax_pydantic


Expand All @@ -35,11 +35,19 @@ class SetPpedTpedRatioNped(torax_pydantic.BaseModelMutable):
rho_norm_ped_top: The location of the pedestal top.
"""
pedestal_model: Literal['set_pped_tpedratio_nped']
Pped: torax_pydantic.Pascal = 1e5
neped: torax_pydantic.Density = 0.7
Pped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(1e5)
)
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.7)
)
neped_is_fGW: bool = False
ion_electron_temperature_ratio: torax_pydantic.OpenUnitInterval = 1.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
ion_electron_temperature_ratio: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(1.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.91)
)


class SetTpedNped(torax_pydantic.BaseModelMutable):
Expand All @@ -53,16 +61,25 @@ class SetTpedNped(torax_pydantic.BaseModelMutable):
Teped: Electron temperature at the pedestal [keV].
rho_norm_ped_top: The location of the pedestal top.
"""

pedestal_model: Literal['set_tped_nped']
neped: torax_pydantic.Density = 0.7
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.7)
)
neped_is_fGW: bool = False
Tiped: torax_pydantic.KiloElectronVolt = 5.0
Teped: torax_pydantic.KiloElectronVolt = 5.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
Tiped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(5.0)
)
Teped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(5.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.91)
)


class PedestalModel(torax_pydantic.BaseModelMutable):
"""Config for a time step calculator."""
"""Config for a pedestal model."""
pedestal_config: SetPpedTpedRatioNped | SetTpedNped = pydantic.Field(
discriminator='pedestal_model', default_factory=SetTpedNped,
)
Expand Down
71 changes: 63 additions & 8 deletions torax/torax_pydantic/tests/model_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,79 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the `torax.config` module."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
from torax.config import config_loader
from torax.torax_pydantic import model_config


class ConfigTest(parameterized.TestCase):
"""Unit tests for the `torax.config` module."""

def test_full_config_construction(self):
@parameterized.parameters(
"test_crank_nicolson",
"test_implicit",
"test_qei",
"test_pedestal",
"test_cgmheat",
"test_bohmgyrobohm_all",
"test_semiimplicit_convection",
"test_qlknnheat",
"test_fixed_dt",
"test_psiequation",
"test_psi_and_heat",
"test_absolute_generic_current_source",
"test_newton_raphson_zeroiter",
"test_bootstrap",
"test_psi_heat_dens",
"test_particle_sources_constant",
"test_particle_sources_cgm",
"test_prescribed_generic_current_source",
"test_fusion_power",
"test_all_transport_fusion_qlknn",
"test_chease",
"test_eqdsk",
"test_ohmic_power",
"test_bremsstrahlung",
"test_bremsstrahlung_time_dependent_Zimp",
"test_qei_chease_highdens",
"test_psichease_ip_parameters",
"test_psichease_ip_chease",
"test_psichease_prescribed_jtot",
"test_psichease_prescribed_johm",
"test_timedependence",
"test_prescribed_timedependent_ne",
"test_ne_qlknn_defromchie",
"test_ne_qlknn_deff_veff",
"test_all_transport_crank_nicolson",
"test_pc_method_ne",
"test_iterbaseline_mockup",
"test_iterhybrid_mockup",
"test_iterhybrid_predictor_corrector",
"test_iterhybrid_predictor_corrector_eqdsk",
"test_iterhybrid_predictor_corrector_clip_inputs",
"test_iterhybrid_predictor_corrector_zeffprofile",
"test_iterhybrid_predictor_corrector_zi2",
"test_iterhybrid_predictor_corrector_timedependent_isotopes",
"test_iterhybrid_predictor_corrector_tungsten",
"test_iterhybrid_predictor_corrector_ec_linliu",
"test_iterhybrid_predictor_corrector_constant_fraction_impurity_radiation",
"test_iterhybrid_predictor_corrector_set_pped_tpedratio_nped",
"test_iterhybrid_predictor_corrector_cyclotron",
"test_iterhybrid_newton",
"test_iterhybrid_rampup",
"test_time_dependent_circular_geo",
"test_changing_config_before",
"test_changing_config_after",
"test_psichease_ip_parameters_vloop",
"test_psichease_ip_chease_vloop",
"test_psichease_prescribed_jtot_vloop",
)
def test_full_config_construction(self, config_name):
"""Test for basic config construction."""

module = config_loader.import_module(
".tests.test_data.test_iterhybrid_newton",
f".tests.test_data.{config_name}",
config_package="torax",
)

Expand All @@ -48,14 +104,13 @@ def test_full_config_construction(self):
if "pedestal_model" in module_config["pedestal"]
else "set_tped_nped",
)

# The full model should always be serializable.
with self.subTest("json_serialization"):
config_json = config_pydantic.model_dump_json()
config_pydantic_roundtrip = model_config.ToraxConfig.model_validate_json(
config_json
)
self.assertEqual(config_pydantic, config_pydantic_roundtrip)
chex.assert_trees_all_equal(config_pydantic, config_pydantic_roundtrip)


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions torax/torax_pydantic/torax_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Pydantic utilities and base classes."""

import functools
from typing import TypeAlias
import pydantic
from torax.torax_pydantic import interpolated_param_1d
Expand Down Expand Up @@ -47,3 +48,5 @@

TimeVaryingScalar = interpolated_param_1d.TimeVaryingScalar
TimeVaryingArray = interpolated_param_2d.TimeVaryingArray

ValidatedDefault = functools.partial(pydantic.Field, validate_default=True)

0 comments on commit 8525d8f

Please sign in to comment.