Create pydantic model for stepper.
Drive-by:
- Remove a layer of nesting from stepper configs (by removing the nested
  newton_raphson_params and optimizer_params dicts from them).

Follow-up: remove the builder objects completely.

PiperOrigin-RevId: 731728066
Nush395 authored and Torax team committed Feb 28, 2025
1 parent 0554b3c commit c13d51d
Showing 11 changed files with 220 additions and 58 deletions.
34 changes: 8 additions & 26 deletions torax/config/build_sim.py
@@ -32,6 +32,7 @@
from torax.sources import source_models as source_models_lib
from torax.stepper import linear_theta_method
from torax.stepper import nonlinear_theta_method
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import fixed_time_step_calculator
@@ -503,48 +504,29 @@ def build_stepper_builder_from_config(
Raises:
ValueError if the `stepper_type` is unknown.
"""
if isinstance(stepper_config, str):
stepper_config = {'stepper_type': stepper_config}
else:
if 'stepper_type' not in stepper_config:
raise ValueError('stepper_type must be set in the input config.')
# Shallow copy so we don't modify the input config.
stepper_config = copy.copy(stepper_config)
stepper_type = stepper_config.pop('stepper_type')
stepper_model = stepper_pydantic_model.Stepper.from_dict(stepper_config)
stepper_model = stepper_model.to_dict()
stepper_type = stepper_model['stepper_config'].pop('stepper_type')

if stepper_type == 'linear':
# Remove params from steppers with nested configs, if present.
stepper_config.pop('newton_raphson_params', None)
stepper_config.pop('optimizer_params', None)
return linear_theta_method.LinearThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
linear_theta_method.LinearRuntimeParams(),
**stepper_config,
**stepper_model['stepper_config'],
)
)
elif stepper_type == 'newton_raphson':
newton_raphson_params = stepper_config.pop('newton_raphson_params', {})
if not isinstance(newton_raphson_params, dict):
raise ValueError('newton_raphson_params must be a dict.')
newton_raphson_params.update(stepper_config)
# Remove params from other steppers with nested configs, if present.
newton_raphson_params.pop('optimizer_params', None)
return nonlinear_theta_method.NewtonRaphsonThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
nonlinear_theta_method.NewtonRaphsonRuntimeParams(),
**newton_raphson_params,
**stepper_model['stepper_config'],
)
)
elif stepper_type == 'optimizer':
optimizer_params = stepper_config.pop('optimizer_params', {})
if not isinstance(optimizer_params, dict):
raise ValueError('optimizer_params must be a dict.')
optimizer_params.update(stepper_config)
# Remove params from other steppers with nested configs, if present.
optimizer_params.pop('newton_raphson_params', None)
return nonlinear_theta_method.OptimizerThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
nonlinear_theta_method.OptimizerRuntimeParams(),
**optimizer_params,
**stepper_model['stepper_config'],
)
)
raise ValueError(f'Unknown stepper type: {stepper_type}')
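
With this change, stepper config validation moves out of hand-rolled dict
surgery and into the pydantic model. A minimal sketch of the new flow, using
only calls that appear in the diff above (from_dict/to_dict are assumed to
wrap pydantic's model_validate/model_dump):

# A flat user config; the nested newton_raphson_params dict is gone.
raw = {'stepper_type': 'newton_raphson', 'log_iterations': True}

# One validation step type-checks every field and fills in defaults.
stepper_model = stepper_pydantic_model.Stepper.from_dict(raw)
params = stepper_model.to_dict()['stepper_config']

# The discriminator doubles as the dispatch key for the builders.
stepper_type = params.pop('stepper_type')  # 'newton_raphson'
# params now holds validated defaults (maxiter=30, tol=1e-5, ...) plus the
# user override, ready for config_args.recursive_replace(...).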
4 changes: 1 addition & 3 deletions torax/examples/iterhybrid_rampup.py
@@ -192,9 +192,7 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {
'log_iterations': False,
},
'log_iterations': False,
},
'time_step_calculator': {
'calculator_type': 'fixed',
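The same flattening applies to every example and test config in this commit:
solver options that previously lived in a nested dict now sit directly in the
stepper section. Schematically (shapes taken from this diff):

# Before: solver params nested one level down.
stepper_old = {
    'stepper_type': 'newton_raphson',
    'use_pereverzev': True,
    'newton_raphson_params': {'log_iterations': False},
}

# After: one flat namespace, validated by the Stepper pydantic model.
stepper_new = {
    'stepper_type': 'newton_raphson',
    'use_pereverzev': True,
    'log_iterations': False,
}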
39 changes: 28 additions & 11 deletions torax/pedestal_model/pydantic_model.py
@@ -16,8 +16,8 @@

import copy
from typing import Any, Literal

import pydantic
from torax.torax_pydantic import interpolated_param_1d
from torax.torax_pydantic import torax_pydantic


@@ -35,11 +35,19 @@ class SetPpedTpedRatioNped(torax_pydantic.BaseModelMutable):
rho_norm_ped_top: The location of the pedestal top.
"""
pedestal_model: Literal['set_pped_tpedratio_nped']
Pped: torax_pydantic.Pascal = 1e5
neped: torax_pydantic.Density = 0.7
Pped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(1e5)
)
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.7)
)
neped_is_fGW: bool = False
ion_electron_temperature_ratio: torax_pydantic.OpenUnitInterval = 1.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
ion_electron_temperature_ratio: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(1.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.91)
)


class SetTpedNped(torax_pydantic.BaseModelMutable):
@@ -53,16 +61,25 @@ class SetTpedNped(torax_pydantic.BaseModelMutable):
Teped: Electron temperature at the pedestal [keV].
rho_norm_ped_top: The location of the pedestal top.
"""

pedestal_model: Literal['set_tped_nped']
neped: torax_pydantic.Density = 0.7
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.7)
)
neped_is_fGW: bool = False
Tiped: torax_pydantic.KiloElectronVolt = 5.0
Teped: torax_pydantic.KiloElectronVolt = 5.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
Tiped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(5.0)
)
Teped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(5.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.91)
)


class PedestalModel(torax_pydantic.BaseModelMutable):
"""Config for a time step calculator."""
class Pedestal(torax_pydantic.BaseModelMutable):
"""Config for a pedestal model."""
pedestal_config: SetPpedTpedRatioNped | SetTpedNped = pydantic.Field(
discriminator='pedestal_model', default_factory=SetTpedNped,
)
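
The float fields above become TimeVaryingScalar fields, so pedestal parameters
can now be time-dependent. A hedged sketch of what this enables, assuming
TimeVaryingScalar accepts bare scalars (as the ValidatedDefault calls imply)
as well as torax's usual {time: value} mapping:

# Bare scalars still validate:
ped = SetTpedNped(pedestal_model='set_tped_nped', Teped=5.0)

# Assuming the {time: value} convention, a pedestal temperature ramp can be
# written directly in the config:
ped_ramp = SetTpedNped(
    pedestal_model='set_tped_nped',
    Teped={0.0: 3.0, 80.0: 5.0},
)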
109 changes: 109 additions & 0 deletions torax/stepper/pydantic_model.py
@@ -0,0 +1,109 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pydantic config for Stepper."""
from typing import Any, Literal, Union

import pydantic
from torax.fvm import enums
from torax.torax_pydantic import torax_pydantic


# pylint: disable=invalid-name
class Linear(torax_pydantic.BaseModelMutable):
"""Model for linear stepper.
Attributes:
stepper_type: The type of stepper to use, hardcoded to 'linear'.
theta_imp: The theta value in the theta method 0 = explicit, 1 = fully
implicit, 0.5 = Crank-Nicolson.
predictor_corrector: Enables predictor_corrector iterations with the linear
solver. If False, compilation is faster.
corrector_steps: The number of corrector steps for the predictor-corrector
linear solver. 0 means a pure linear solve with no corrector steps.
convection_dirichlet_mode: See `fvm.convection_terms` docstring,
`dirichlet_mode` argument.
convection_neumann_mode: See `fvm.convection_terms` docstring,
`neumann_mode` argument.
use_pereverzev: Use Pereverzev terms in the linear solver. Only applied in
the nonlinear solver for the optional initial guess from the linear solver.
chi_per: (deliberately) large heat conductivity for Pereverzev rule.
d_per: (deliberately) large particle diffusion for Pereverzev rule.
"""
stepper_type: Literal['linear']
theta_imp: float = 1.0
predictor_corrector: bool = True
corrector_steps: int = 1
convection_dirichlet_mode: str = 'ghost'
convection_neumann_mode: str = 'ghost'
use_pereverzev: bool = False
chi_per: float = 20.0
d_per: float = 10.0


class NewtonRaphson(Linear):
"""Model for non linear Newton-Raphson stepper.
Attributes:
stepper_type: The type of stepper to use, hardcoded to 'newton_raphson'.
log_iterations: If True, log internal iterations in Newton-Raphson solver.
initial_guess_mode: The initial guess mode for the Newton-Raphson solver.
maxiter: The maximum number of iterations for the Newton-Raphson solver.
tol: The tolerance for the Newton-Raphson solver.
coarse_tol: The coarse tolerance for the Newton-Raphson solver.
delta_reduction_factor: The delta reduction factor for the Newton-Raphson
solver.
tau_min: The minimum value of tau for the Newton-Raphson solver.
"""
stepper_type: Literal['newton_raphson']
log_iterations: bool = False
initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
maxiter: int = 30
tol: float = 1e-5
coarse_tol: float = 1e-2
delta_reduction_factor: float = 0.5
tau_min: float = 0.01


class Optimizer(Linear):
"""A basic version of the pedestal model that uses direct specification.
Attributes:
stepper_type: The type of stepper to use, hardcoded to 'optimizer'.
initial_guess_mode: The initial guess mode for the optimizer.
maxiter: The maximum number of iterations for the optimizer.
tol: The tolerance for the optimizer.
"""
stepper_type: Literal['optimizer']
initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
maxiter: int = 100
tol: float = 1e-12


StepperConfig = Union[Linear, NewtonRaphson, Optimizer]


class Stepper(torax_pydantic.BaseModelMutable):
"""Config for a stepper."""
stepper_config: StepperConfig = pydantic.Field(discriminator='stepper_type')

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(cls, data: dict[str, Any]) -> dict[str, Any]:
# If we are running with the standard class constructor we don't need to do
# any custom validation.
if 'stepper_config' in data:
return data

return {'stepper_config': data}
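
The _conform_data validator lets users keep writing flat stepper dicts while
the discriminated union picks the concrete model from stepper_type. A small
usage sketch (from_dict is assumed to forward to pydantic validation, as its
use in build_sim.py suggests):

# A flat dict is wrapped into {'stepper_config': ...} by _conform_data.
stepper = Stepper.from_dict(
    {'stepper_type': 'newton_raphson', 'maxiter': 0, 'use_pereverzev': True}
)

# The discriminator selected NewtonRaphson; use_pereverzev is accepted here
# because NewtonRaphson inherits every Linear field.
assert isinstance(stepper.stepper_config, NewtonRaphson)
assert stepper.stepper_config.maxiter == 0
assert stepper.stepper_config.tol == 1e-5  # untouched default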
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_newton.py
@@ -188,7 +188,6 @@
# (deliberately) large particle diffusion for Pereverzev rule
'd_per': 15,
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'chi',
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_rampup.py
@@ -192,7 +192,6 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'fixed',
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_rampup_short.py
@@ -191,7 +191,6 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'fixed',
4 changes: 1 addition & 3 deletions torax/tests/test_data/test_newton_raphson_zeroiter.py
@@ -69,9 +69,7 @@
'stepper_type': 'newton_raphson',
'predictor_corrector': False,
'use_pereverzev': True,
'newton_raphson_params': {
'maxiter': 0,
},
'maxiter': 0,
},
'time_step_calculator': {
'calculator_type': 'chi',
11 changes: 7 additions & 4 deletions torax/torax_pydantic/model_config.py
@@ -15,7 +15,8 @@
"""Pydantic config for Torax."""

from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import pydantic_model as pedestal_model_config
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.time_step_calculator import config as time_step_calculator_config
from torax.torax_pydantic import model_base

@@ -24,10 +25,12 @@ class ToraxConfig(model_base.BaseModelMutable):
"""Base config class for Torax.
Attributes:
time_step_calculator: Config for the time step calculator.
geometry: Config for the geometry.
pedestal: Config for the pedestal model.
stepper: Config for the stepper.
time_step_calculator: Config for the time step calculator.
"""

geometry: geometry_pydantic_model.Geometry
pedestal: pedestal_model_config.PedestalModel
pedestal: pedestal_pydantic_model.Pedestal
stepper: stepper_pydantic_model.Stepper
time_step_calculator: time_step_calculator_config.TimeStepCalculator
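
With the new field, a top-level Torax config gains a stepper section alongside
the renamed pedestal one. An illustrative sketch, assuming ToraxConfig shares
the from_dict helper used for Stepper, that Pedestal wraps flat dicts the same
way Stepper does, and that the geometry payload below is a valid placeholder:

config = model_config.ToraxConfig.from_dict({
    'geometry': {'geometry_type': 'circular'},  # assumed geometry payload
    'pedestal': {'pedestal_model': 'set_tped_nped'},
    'stepper': {'stepper_type': 'linear', 'use_pereverzev': True},
    'time_step_calculator': {'calculator_type': 'fixed'},
})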