Skip to content

Commit

Permalink
adding JaxCurveRZFourier with unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mishapadidar committed Jan 8, 2025
1 parent 59f9003 commit 5a53857
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 12 deletions.
133 changes: 131 additions & 2 deletions src/simsopt/geo/curverzfourier.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import jax.numpy as jnp

import simsoptpp as sopp
from .curve import Curve
from .curve import Curve, JaxCurve

__all__ = ['CurveRZFourier']
__all__ = ['CurveRZFourier', 'JaxCurveRZFourier']


class CurveRZFourier(sopp.CurveRZFourier, Curve):
Expand Down Expand Up @@ -53,3 +54,131 @@ def set_dofs(self, dofs):
"""
self.local_x = dofs
sopp.CurveRZFourier.set_dofs(self, dofs)



def jaxCurveRZFourier_pure(dofs, quadpoints, order, nfp, stellsym):
"""The gamma() method for the JaxCurveRZFourier class. The code is
written as a 'pure' jax function and is ammenable to autodifferentation.
Args:
dofs (array): 1D-array of dofs. Should be ordered as in the JaxCurveRZFourier
class.
quadpoints (array): 1D-array of quadrature points in [0, 1/nfp].
order (int): Maximum mode number of the expansion.
nfp (int): Number of field periods.
stellsym (bool): True for stellarator symmetry, False otherwise.
Returns:
gamma: (nphi, 3) jax numpy array of points on the curve.
"""

phi1d = 2 * jnp.pi * quadpoints
phi, m = jnp.meshgrid(phi1d, jnp.arange(order + 1), indexing='ij')

if stellsym:
rc = dofs[:order+1]
zs = dofs[order+1:]

r = jnp.sum(rc[None, :] * jnp.cos(m * nfp * phi), axis=1)
z = jnp.sum(zs[None, :] * jnp.sin(m[:, 1:] * nfp * phi[:, 1:]), axis=1)

else:
rc = dofs[0 : order+1]
rs = dofs[order+1 : 2*order+1]
zc = dofs[2*order+1: 3*order+2]
zs = dofs[3*order+2: ]

r = (jnp.sum(rc[None, :] * jnp.cos(m * nfp * phi), axis=1) +
jnp.sum(rs[None, :] * jnp.sin(m[:, 1:] * nfp * phi[:, 1:]), axis=1)
)
z = (jnp.sum(zc[None, :] * jnp.cos(m * nfp * phi), axis=1) +
jnp.sum(zs[None, :] * jnp.sin(m[:, 1:] * nfp * phi[:, 1:]), axis=1)
)

x = r * jnp.cos(phi1d)
y = r * jnp.sin(phi1d)

gamma = jnp.zeros((len(quadpoints),3))
gamma = gamma.at[:, 0].add(x)
gamma = gamma.at[:, 1].add(y)
gamma = gamma.at[:, 2].add(z)
return gamma


class JaxCurveRZFourier(JaxCurve):
r'''A ``CurveRZFourier`` that is based off of the ``JaxCurve`` class.
``CurveRZFourier``is a curve that is represented in cylindrical coordinates using the following Fourier series:
.. math::
r(\phi) &= \sum_{m=0}^{\text{order}} r_{c,m}\cos(n_{\text{fp}} m \phi) + \sum_{m=1}^{\text{order}} r_{s,m}\sin(n_{\text{fp}} m \phi) \\
z(\phi) &= \sum_{m=0}^{\text{order}} z_{c,m}\cos(n_{\text{fp}} m \phi) + \sum_{m=1}^{\text{order}} z_{s,m}\sin(n_{\text{fp}} m \phi)
If ``stellsym = True``, then the :math:`\sin` terms for :math:`r` and the :math:`\cos` terms for :math:`z` are zero.
For the ``stellsym = False`` case, the dofs are stored in the order
.. math::
[r_{c,0}, \cdots, r_{c,\text{order}}, r_{s,1}, \cdots, r_{s,\text{order}}, z_{c,0},....]
or in the ``stellsym = True`` case they are stored
.. math::
[r_{c,0},...,r_{c,order},z_{s,1},...,z_{s,order}]
Example usage of JaxCurve's can be found in examples/2_Intermediate/jax_curve_example.py.
Args:
quadpoints (int, array): Either the number of quadrature points or a
1D-array of quadrature points.
order (int): Maximum mode number of the expansion.
nfp (int): Number of field periods.
stellsym (bool): True for stellarator symmetry, False otherwise.
'''

def __init__(self, quadpoints, order, nfp, stellsym, **kwargs):
if isinstance(quadpoints, int):
quadpoints = list(np.linspace(0, 1./nfp, quadpoints, endpoint=False))
elif isinstance(quadpoints, np.ndarray):
quadpoints = list(quadpoints)
pure = lambda dofs, points: jaxCurveRZFourier_pure(
dofs, points, order, nfp, stellsym)

self.order = order
self.nfp = nfp
self.stellsym = stellsym
self.coefficients = np.zeros(self.num_dofs())
if "dofs" not in kwargs:
if "x0" not in kwargs:
kwargs["x0"] = self.coefficients
else:
self.set_dofs_impl(kwargs["x0"])

super().__init__(quadpoints, pure, names=self._make_names(order), **kwargs)

def _make_names(self, order):
if self.stellsym:
r_cos_names = [f'rc({i})' for i in range(0, order + 1)]
r_names = r_cos_names
z_sin_names = [f'zs({i})' for i in range(1, order + 1)]
z_names = z_sin_names
else:
r_names = ['rc(0)']
r_cos_names = [f'rc({i})' for i in range(1, order + 1)]
r_sin_names = [f'rs({i})' for i in range(1, order + 1)]
r_names += r_cos_names + r_sin_names
z_names = ['zc(0)']
z_cos_names = [f'zc({i})' for i in range(1, order + 1)]
z_sin_names = [f'zs({i})' for i in range(1, order + 1)]
z_names += z_cos_names + z_sin_names

return r_names + z_names

def num_dofs(self):
return (self.order+1) + self.order if self.stellsym else 4*self.order+2

def get_dofs(self):
return self.coefficients

def set_dofs_impl(self, dofs):
self.coefficients[:] = dofs[:]
12 changes: 8 additions & 4 deletions src/simsopt/geo/curvexyzfouriersymmetries.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,15 @@ class CurveXYZFourierSymmetries(JaxCurve):
\hat x(\theta) &= x_{c, 0} + \sum_{m=1}^{\text{order}} \left[ x_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + x_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right] \\
\hat y(\theta) &= y_{c, 0} + \sum_{m=1}^{\text{order}} \left[ y_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + y_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right] \\
Example usage of JaxCurve's can be found in examples/2_Intermediate/jax_curve_example.py.
Args:
quadpoints: number of grid points/resolution along the curve,
order: how many Fourier harmonics to include in the Fourier representation,
nfp: discrete rotational symmetry number,
stellsym: stellaratory symmetry if True, not stellarator symmetric otherwise,
quadpoints (int, array): Either the number of quadrature points or
1D-array of quadrature points. If quadpoints is an int, then the quadrature points
will be linearly space in [0,1] (not [0, 1/nfp]).
order (int): Maximum mode number of the expansion.
nfp (int): Number of field periods (discrete rotational symmetry).
stellsym (bool): True for stellarator symmetry, False otherwise.
ntor: the number of times the curve wraps toroidally before biting its tail. Note,
it is assumed that nfp and ntor are coprime. If they are not coprime,
then then the curve actually has nfp_new:=nfp // gcd(nfp, ntor),
Expand Down
44 changes: 41 additions & 3 deletions tests/geo/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from simsopt._core.json import GSONEncoder, GSONDecoder, SIMSON
from simsopt.geo.curvexyzfourier import CurveXYZFourier, JaxCurveXYZFourier
from simsopt.geo.curverzfourier import CurveRZFourier
from simsopt.geo.curverzfourier import CurveRZFourier, JaxCurveRZFourier
from simsopt.geo.curveplanarfourier import CurvePlanarFourier
from simsopt.geo.curvehelical import CurveHelical
from simsopt.geo.curvexyzfouriersymmetries import CurveXYZFourierSymmetries
Expand Down Expand Up @@ -71,6 +71,8 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
curve = JaxCurveXYZFourier(x, order)
elif curvetype == "CurveRZFourier":
curve = CurveRZFourier(x, order, 2, True)
elif curvetype == "JaxCurveRZFourier":
curve = JaxCurveRZFourier(x, order, 2, True)
elif curvetype == "CurveHelical":
curve = CurveHelical(x, order, 5, 2, 1.0, 0.3)
elif curvetype == "CurveHelicalInitx0":
Expand All @@ -91,7 +93,7 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
dofs[1] = 1.
dofs[2*order + 3] = 1.
dofs[4*order + 3] = 1.
elif curvetype in ["CurveRZFourier", "CurvePlanarFourier"]:
elif curvetype in ["CurveRZFourier", "JaxCurveRZFourier", "CurvePlanarFourier"]:
dofs[0] = 1.
dofs[1] = 0.1
dofs[order+1] = 0.1
Expand Down Expand Up @@ -136,7 +138,7 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):

class Testing(unittest.TestCase):

curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"]
curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "JaxCurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"]

def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1):
# returns a CurveXYZFourierSymmetries that is randomly perturbed
Expand All @@ -163,6 +165,42 @@ def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1)

return curve

def test_jaxcurverzfourier(self):
"""
Test the JaxCurveRZFourier class is equil to the CurveRZFourier.
"""

# configuration parameters
quadpoints = 7
nfp = 2
curve = JaxCurveRZFourier(quadpoints, order=2, nfp=nfp, stellsym = True)
self.assertEqual(len(curve.dof_names), curve.num_dofs(),
"JaxCurveRZFourier, incorrect number of dofs"
)

curve = JaxCurveRZFourier(quadpoints, order=1, nfp=2, stellsym = True)
curve.set('rc(0)', 1.0)
curve.set('rc(1)', 0.1)
curve.set('zs(1)', 0.2)
curve_actual = CurveRZFourier(quadpoints, order=1, nfp=2, stellsym = True)
curve_actual.set('x0', 1.0)
curve_actual.set('x1', 0.1)
curve_actual.set('x2', 0.2)
# assert np.testing.assert_allclose(curve.gamma(), curve_actual.gamma(), atol=1e-14), "JaxCurveRZFourier gamma incorrect."
self.assertTrue(np.allclose(curve.gamma(), curve_actual.gamma(), atol=1e-14),
"JaxCurveRZFourier gamma incorrect.")

curve = JaxCurveRZFourier(quadpoints, order=1, nfp=2, stellsym = False)
curve.set('rc(0)', 1.0)
curve.set('rs(1)', 0.3)
curve.set('zc(0)', 0.2)
curve_actual = CurveRZFourier(quadpoints, order=1, nfp=2, stellsym = False)
curve_actual.set('x0', 1.0)
curve_actual.set('x2', 0.3)
curve_actual.set('x3', 0.2)
self.assertTrue(np.allclose(curve.gamma(), curve_actual.gamma(), atol=1e-14),
"JaxCurveRZFourier gamma incorrect.")

def test_curvexyzsymmetries_raisesexception(self):
# test ensures that an exception is raised when you try and create a curvexyzfouriersymmetries
# where gcd(ntor, nfp) != 1.
Expand Down
6 changes: 3 additions & 3 deletions tests/geo/test_curve_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from simsopt.geo import parameters
from simsopt.geo.curve import RotatedCurve, create_equally_spaced_curves
from simsopt.geo.curvexyzfourier import CurveXYZFourier, JaxCurveXYZFourier
from simsopt.geo.curverzfourier import CurveRZFourier
from simsopt.geo.curverzfourier import CurveRZFourier, JaxCurveRZFourier

Check failure on line 9 in tests/geo/test_curve_objectives.py

View workflow job for this annotation

GitHub Actions / CI (3.9)

Ruff (F401)

tests/geo/test_curve_objectives.py:9:56: F401 `simsopt.geo.curverzfourier.JaxCurveRZFourier` imported but unused
from simsopt.geo.curveobjectives import CurveLength, LpCurveCurvature, \
LpCurveTorsion, CurveCurveDistance, ArclengthVariation, \
MeanSquaredCurvature, CurveSurfaceDistance, LinkingNumber
Expand All @@ -21,7 +21,7 @@

class Testing(unittest.TestCase):

curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier"]
curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "JaxCurveRZFourier"]

def create_curve(self, curvetype, rotated):
np.random.seed(1)
Expand All @@ -43,7 +43,7 @@ def create_curve(self, curvetype, rotated):
dofs[1] = 1.
dofs[2*order+3] = 1.
dofs[4*order+3] = 1.
elif curvetype in ["CurveRZFourier"]:
elif curvetype in ["CurveRZFourier", "JaxCurveRZFourier"]:
dofs[0] = 1.
dofs[1] = 0.1
dofs[order+1] = 0.1
Expand Down

0 comments on commit 5a53857

Please sign in to comment.