Skip to content

Commit

Permalink
Modify the pedestal pydantic config to work with all "time varying in…
Browse files Browse the repository at this point in the history
…puts" and expand the config test to work with all the configs tested in the sim tests.

PiperOrigin-RevId: 731677550
  • Loading branch information
Nush395 authored and Torax team committed Feb 27, 2025
1 parent 72bbf2c commit d2c6a88
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 14 deletions.
37 changes: 27 additions & 10 deletions torax/pedestal_model/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import copy
from typing import Any, Literal

import pydantic
from torax.torax_pydantic import interpolated_param_1d
from torax.torax_pydantic import torax_pydantic


Expand All @@ -35,11 +35,19 @@ class SetPpedTpedRatioNped(torax_pydantic.BaseModelMutable):
rho_norm_ped_top: The location of the pedestal top.
"""
pedestal_model: Literal['set_pped_tpedratio_nped']
Pped: torax_pydantic.Pascal = 1e5
neped: torax_pydantic.Density = 0.7
Pped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(1e5)
)
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(0.7)
)
neped_is_fGW: bool = False
ion_electron_temperature_ratio: torax_pydantic.OpenUnitInterval = 1.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
ion_electron_temperature_ratio: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(1.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(0.91)
)


class SetTpedNped(torax_pydantic.BaseModelMutable):
Expand All @@ -53,16 +61,25 @@ class SetTpedNped(torax_pydantic.BaseModelMutable):
Teped: Electron temperature at the pedestal [keV].
rho_norm_ped_top: The location of the pedestal top.
"""

pedestal_model: Literal['set_tped_nped']
neped: torax_pydantic.Density = 0.7
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(0.7)
)
neped_is_fGW: bool = False
Tiped: torax_pydantic.KiloElectronVolt = 5.0
Teped: torax_pydantic.KiloElectronVolt = 5.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
Tiped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(5.0)
)
Teped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(5.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.DefaultValue(0.91)
)


class PedestalModel(torax_pydantic.BaseModelMutable):
"""Config for a time step calculator."""
"""Config for a pedestal model."""
pedestal_config: SetPpedTpedRatioNped | SetTpedNped = pydantic.Field(
discriminator='pedestal_model', default_factory=SetTpedNped,
)
Expand Down
67 changes: 63 additions & 4 deletions torax/torax_pydantic/tests/model_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,78 @@

from absl.testing import absltest
from absl.testing import parameterized
import chex
from torax.config import config_loader
from torax.torax_pydantic import model_config


class ConfigTest(parameterized.TestCase):
"""Unit tests for the `torax.config` module."""

def test_full_config_construction(self):
@parameterized.parameters(
"test_crank_nicolson",
"test_implicit",
"test_qei",
"test_pedestal",
"test_cgmheat",
"test_bohmgyrobohm_all",
"test_semiimplicit_convection",
"test_qlknnheat",
"test_fixed_dt",
"test_psiequation",
"test_psi_and_heat",
"test_absolute_generic_current_source",
"test_newton_raphson_zeroiter",
"test_bootstrap",
"test_psi_heat_dens",
"test_particle_sources_constant",
"test_particle_sources_cgm",
"test_prescribed_generic_current_source",
"test_fusion_power",
"test_all_transport_fusion_qlknn",
"test_chease",
"test_eqdsk",
"test_ohmic_power",
"test_bremsstrahlung",
"test_bremsstrahlung_time_dependent_Zimp",
"test_qei_chease_highdens",
"test_psichease_ip_parameters",
"test_psichease_ip_chease",
"test_psichease_prescribed_jtot",
"test_psichease_prescribed_johm",
"test_timedependence",
"test_prescribed_timedependent_ne",
"test_ne_qlknn_defromchie",
"test_ne_qlknn_deff_veff",
"test_all_transport_crank_nicolson",
"test_pc_method_ne",
"test_iterbaseline_mockup",
"test_iterhybrid_mockup",
"test_iterhybrid_predictor_corrector",
"test_iterhybrid_predictor_corrector_eqdsk",
"test_iterhybrid_predictor_corrector_clip_inputs",
"test_iterhybrid_predictor_corrector_zeffprofile",
"test_iterhybrid_predictor_corrector_zi2",
"test_iterhybrid_predictor_corrector_timedependent_isotopes",
"test_iterhybrid_predictor_corrector_tungsten",
"test_iterhybrid_predictor_corrector_ec_linliu",
"test_iterhybrid_predictor_corrector_constant_fraction_impurity_radiation",
"test_iterhybrid_predictor_corrector_set_pped_tpedratio_nped",
"test_iterhybrid_predictor_corrector_cyclotron",
"test_iterhybrid_newton",
"test_iterhybrid_rampup",
"test_time_dependent_circular_geo",
"test_changing_config_before",
"test_changing_config_after",
"test_psichease_ip_parameters_vloop",
"test_psichease_ip_chease_vloop",
"test_psichease_prescribed_jtot_vloop",
)
def test_full_config_construction(self, config_name):
"""Test for basic config construction."""

module = config_loader.import_module(
".tests.test_data.test_iterhybrid_newton",
f".tests.test_data.{config_name}",
config_package="torax",
)

Expand All @@ -48,14 +108,13 @@ def test_full_config_construction(self):
if "pedestal_model" in module_config["pedestal"]
else "set_tped_nped",
)

# The full model should always be serializable.
with self.subTest("json_serialization"):
config_json = config_pydantic.model_dump_json()
config_pydantic_roundtrip = model_config.ToraxConfig.model_validate_json(
config_json
)
self.assertEqual(config_pydantic, config_pydantic_roundtrip)
chex.assert_trees_all_equal(config_pydantic, config_pydantic_roundtrip)


if __name__ == "__main__":
Expand Down
3 changes: 3 additions & 0 deletions torax/torax_pydantic/torax_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Pydantic utilities and base classes."""

import functools
from typing import TypeAlias
import pydantic
from torax.torax_pydantic import interpolated_param_1d
Expand Down Expand Up @@ -47,3 +48,5 @@

TimeVaryingScalar = interpolated_param_1d.TimeVaryingScalar
TimeVaryingArray = interpolated_param_2d.TimeVaryingArray

DefaultValue = functools.partial(pydantic.Field, validate_default=True)

0 comments on commit d2c6a88

Please sign in to comment.