Skip to content

Commit

Permalink
Add integrate method to geometry module.
Browse files Browse the repository at this point in the history
This method integrates a value `x` over the rhon grid. Cell variables in TORAX are defined as the average of the face values. This method integrates that face value over the rhon grid implicitly using the trapezium rule to sum the averaged face values by the face grid spacing.

PiperOrigin-RevId: 694488518
  • Loading branch information
Nush395 authored and Torax team committed Nov 8, 2024
1 parent acd36e5 commit 9478735
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
22 changes: 20 additions & 2 deletions torax/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import jax.numpy as jnp
import numpy as np
import scipy
from torax import array_typing
from torax import constants
from torax import geometry_loader
from torax import interpolated_param
Expand Down Expand Up @@ -109,6 +110,25 @@ def face_to_cell(face: chex.Array) -> chex.Array:
return 0.5 * (face[:-1] + face[1:])


def integrate(
x: array_typing.ArrayFloat, geo: Geometry
) -> array_typing.ScalarFloat:
r"""Integrate a value `x` over the rhon grid.
Cell variables in TORAX are defined as the average of the face values. This
method integrates that face value over the rhon grid implicitly using the
trapezium rule to sum the averaged face values by the face grid spacing.
Args:
x: The cell averaged value to integrate.
geo: The geometry instance.
Returns:
Face value integrated over the rhon grid: $\int_0^1 x_{face} d\hat{rho}$
"""
return jnp.sum(x * geo.drho_norm)


@enum.unique
class GeometryType(enum.Enum):
"""Integer enum for geometry type.
Expand Down Expand Up @@ -1527,6 +1547,4 @@ def build_standard_geometry(
Phibdot=np.asarray(0.0),
_z_magnetic_axis=intermediate.z_magnetic_axis,
)


# pylint: enable=invalid-name
10 changes: 10 additions & 0 deletions torax/tests/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,16 @@ def test_build_geometry_from_eqdsk(self):
intermediate = geometry.StandardGeometryIntermediates.from_eqdsk()
geometry.build_standard_geometry(intermediate)

def test_integrate(self):
"""Test that the integrate method works as expected."""
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(6,))
geo = geometry.build_circular_geometry(n_rho=5)

np.testing.assert_array_equal(
geometry.integrate(geometry.face_to_cell(x), geo),
jax.scipy.integrate.trapezoid(x, geo.rho_face_norm),
)


def face_to_cell(n_rho, face):
cell = np.zeros(n_rho)
Expand Down

0 comments on commit 9478735

Please sign in to comment.