Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create pydantic model for stepper. #777

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 8 additions & 26 deletions torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from torax.sources import source_models as source_models_lib
from torax.stepper import linear_theta_method
from torax.stepper import nonlinear_theta_method
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.stepper import stepper as stepper_lib
from torax.time_step_calculator import chi_time_step_calculator
from torax.time_step_calculator import fixed_time_step_calculator
Expand Down Expand Up @@ -503,48 +504,29 @@ def build_stepper_builder_from_config(
Raises:
ValueError if the `stepper_type` is unknown.
"""
if isinstance(stepper_config, str):
stepper_config = {'stepper_type': stepper_config}
else:
if 'stepper_type' not in stepper_config:
raise ValueError('stepper_type must be set in the input config.')
# Shallow copy so we don't modify the input config.
stepper_config = copy.copy(stepper_config)
stepper_type = stepper_config.pop('stepper_type')
stepper_model = stepper_pydantic_model.Stepper.from_dict(stepper_config)
stepper_model = stepper_model.to_dict()
stepper_type = stepper_model['stepper_config'].pop('stepper_type')

if stepper_type == 'linear':
# Remove params from steppers with nested configs, if present.
stepper_config.pop('newton_raphson_params', None)
stepper_config.pop('optimizer_params', None)
return linear_theta_method.LinearThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
linear_theta_method.LinearRuntimeParams(),
**stepper_config,
**stepper_model['stepper_config'],
)
)
elif stepper_type == 'newton_raphson':
newton_raphson_params = stepper_config.pop('newton_raphson_params', {})
if not isinstance(newton_raphson_params, dict):
raise ValueError('newton_raphson_params must be a dict.')
newton_raphson_params.update(stepper_config)
# Remove params from other steppers with nested configs, if present.
newton_raphson_params.pop('optimizer_params', None)
return nonlinear_theta_method.NewtonRaphsonThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
nonlinear_theta_method.NewtonRaphsonRuntimeParams(),
**newton_raphson_params,
**stepper_model['stepper_config'],
)
)
elif stepper_type == 'optimizer':
optimizer_params = stepper_config.pop('optimizer_params', {})
if not isinstance(optimizer_params, dict):
raise ValueError('optimizer_params must be a dict.')
optimizer_params.update(stepper_config)
# Remove params from other steppers with nested configs, if present.
optimizer_params.pop('newton_raphson_params', None)
return nonlinear_theta_method.OptimizerThetaMethodBuilder(
runtime_params=config_args.recursive_replace(
nonlinear_theta_method.OptimizerRuntimeParams(),
**optimizer_params,
**stepper_model['stepper_config'],
)
)
raise ValueError(f'Unknown stepper type: {stepper_type}')
Expand Down
4 changes: 1 addition & 3 deletions torax/examples/iterhybrid_rampup.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {
'log_iterations': False,
},
'log_iterations': False,
},
'time_step_calculator': {
'calculator_type': 'fixed',
Expand Down
39 changes: 28 additions & 11 deletions torax/pedestal_model/pydantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import copy
from typing import Any, Literal

import pydantic
from torax.torax_pydantic import interpolated_param_1d
from torax.torax_pydantic import torax_pydantic


Expand All @@ -35,11 +35,19 @@ class SetPpedTpedRatioNped(torax_pydantic.BaseModelMutable):
rho_norm_ped_top: The location of the pedestal top.
"""
pedestal_model: Literal['set_pped_tpedratio_nped']
Pped: torax_pydantic.Pascal = 1e5
neped: torax_pydantic.Density = 0.7
Pped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(1e5)
)
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.7)
)
neped_is_fGW: bool = False
ion_electron_temperature_ratio: torax_pydantic.OpenUnitInterval = 1.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
ion_electron_temperature_ratio: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(1.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.91)
)


class SetTpedNped(torax_pydantic.BaseModelMutable):
Expand All @@ -53,16 +61,25 @@ class SetTpedNped(torax_pydantic.BaseModelMutable):
Teped: Electron temperature at the pedestal [keV].
rho_norm_ped_top: The location of the pedestal top.
"""

pedestal_model: Literal['set_tped_nped']
neped: torax_pydantic.Density = 0.7
neped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.7)
)
neped_is_fGW: bool = False
Tiped: torax_pydantic.KiloElectronVolt = 5.0
Teped: torax_pydantic.KiloElectronVolt = 5.0
rho_norm_ped_top: torax_pydantic.UnitInterval = 0.91
Tiped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(5.0)
)
Teped: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(5.0)
)
rho_norm_ped_top: interpolated_param_1d.TimeVaryingScalar = (
torax_pydantic.ValidatedDefault(0.91)
)


class PedestalModel(torax_pydantic.BaseModelMutable):
"""Config for a time step calculator."""
class Pedestal(torax_pydantic.BaseModelMutable):
"""Config for a pedestal model."""
pedestal_config: SetPpedTpedRatioNped | SetTpedNped = pydantic.Field(
discriminator='pedestal_model', default_factory=SetTpedNped,
)
Expand Down
109 changes: 109 additions & 0 deletions torax/stepper/pydantic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pydantic config for Stepper."""
from typing import Any, Literal, Union

import pydantic
from torax.fvm import enums
from torax.torax_pydantic import torax_pydantic


# pylint: disable=invalid-name
class Linear(torax_pydantic.BaseModelMutable):
"""Model for linear stepper.

Attributes:
stepper_type: The type of stepper to use, hardcoded to 'linear'.
theta_imp: The theta value in the theta method 0 = explicit, 1 = fully
implicit, 0.5 = Crank-Nicolson.
predictor_corrector: Enables predictor_corrector iterations with the linear
solver. If False, compilation is faster.
corrector_steps: The number of corrector steps for the predictor-corrector
linear solver. 0 means a pure linear solve with no corrector steps.
convection_dirichlet_mode: See `fvm.convection_terms` docstring,
`dirichlet_mode` argument.
convection_neumann_mode: See `fvm.convection_terms` docstring,
`neumann_mode` argument.
use_pereverzev: Use pereverzev terms for linear solver. Is only applied in
the nonlinear solver for the optional initial guess from the linear solver
chi_per: (deliberately) large heat conductivity for Pereverzev rule.
d_per: (deliberately) large particle diffusion for Pereverzev rule.
"""
stepper_type: Literal['linear']
theta_imp: float = 1.0
predictor_corrector: bool = True
corrector_steps: int = 1
convection_dirichlet_mode: str = 'ghost'
convection_neumann_mode: str = 'ghost'
use_pereverzev: bool = False
chi_per: float = 20.0
d_per: float = 10.0


class NewtonRaphson(Linear):
"""Model for non linear Newton-Raphson stepper.

Attributes:
stepper_type: The type of stepper to use, hardcoded to 'newton_raphson'.
log_iterations: If True, log internal iterations in Newton-Raphson solver.
initial_guess_mode: The initial guess mode for the Newton-Raphson solver.
maxiter: The maximum number of iterations for the Newton-Raphson solver.
tol: The tolerance for the Newton-Raphson solver.
coarse_tol: The coarse tolerance for the Newton-Raphson solver.
delta_reduction_factor: The delta reduction factor for the Newton-Raphson
solver.
tau_min: The minimum value of tau for the Newton-Raphson solver.
"""
stepper_type: Literal['newton_raphson']
log_iterations: bool = False
initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
maxiter: int = 30
tol: float = 1e-5
coarse_tol: float = 1e-2
delta_reduction_factor: float = 0.5
tau_min: float = 0.01


class Optimizer(Linear):
"""A basic version of the pedestal model that uses direct specification.

Attributes:
stepper_type: The type of stepper to use, hardcoded to 'optimizer'.
initial_guess_mode: The initial guess mode for the optimizer.
maxiter: The maximum number of iterations for the optimizer.
tol: The tolerance for the optimizer.
"""
stepper_type: Literal['optimizer']
initial_guess_mode: enums.InitialGuessMode = enums.InitialGuessMode.LINEAR
maxiter: int = 100
tol: float = 1e-12


StepperConfig = Union[Linear, NewtonRaphson, Optimizer]


class Stepper(torax_pydantic.BaseModelMutable):
"""Config for a stepper."""
stepper_config: StepperConfig = pydantic.Field(discriminator='stepper_type')

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(cls, data: dict[str, Any]) -> dict[str, Any]:
# If we are running with the standard class constructor we don't need to do
# any custom validation.
if 'stepper_config' in data:
return data

return {'stepper_config': data}
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@
# (deliberately) large particle diffusion for Pereverzev rule
'd_per': 15,
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'chi',
Expand Down
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_rampup.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,6 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'fixed',
Expand Down
1 change: 0 additions & 1 deletion torax/tests/test_data/test_iterhybrid_rampup_short.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@
'd_per': 15,
# use_pereverzev is only used for the linear solver
'use_pereverzev': True,
'newton_raphson_params': {},
},
'time_step_calculator': {
'calculator_type': 'fixed',
Expand Down
4 changes: 1 addition & 3 deletions torax/tests/test_data/test_newton_raphson_zeroiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@
'stepper_type': 'newton_raphson',
'predictor_corrector': False,
'use_pereverzev': True,
'newton_raphson_params': {
'maxiter': 0,
},
'maxiter': 0,
},
'time_step_calculator': {
'calculator_type': 'chi',
Expand Down
11 changes: 7 additions & 4 deletions torax/torax_pydantic/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"""Pydantic config for Torax."""

from torax.geometry import pydantic_model as geometry_pydantic_model
from torax.pedestal_model import pydantic_model as pedestal_model_config
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
from torax.stepper import pydantic_model as stepper_pydantic_model
from torax.time_step_calculator import config as time_step_calculator_config
from torax.torax_pydantic import model_base

Expand All @@ -24,10 +25,12 @@ class ToraxConfig(model_base.BaseModelMutable):
"""Base config class for Torax.

Attributes:
time_step_calculator: Config for the time step calculator.
geometry: Config for the geometry.
pedestal: Config for the pedestal model.
stepper: Config for the stepper.
time_step_calculator: Config for the time step calculator.
"""

geometry: geometry_pydantic_model.Geometry
pedestal: pedestal_model_config.PedestalModel
pedestal: pedestal_pydantic_model.Pedestal
stepper: stepper_pydantic_model.Stepper
time_step_calculator: time_step_calculator_config.TimeStepCalculator
Loading