From 5a5385748c08c9c4a486d26c634124404552e481 Mon Sep 17 00:00:00 2001 From: mishapadidar Date: Wed, 8 Jan 2025 12:29:34 -0500 Subject: [PATCH] adding JaxCurveRZFourier with unit tests --- src/simsopt/geo/curverzfourier.py | 133 ++++++++++++++++++- src/simsopt/geo/curvexyzfouriersymmetries.py | 12 +- tests/geo/test_curve.py | 44 +++++- tests/geo/test_curve_objectives.py | 6 +- 4 files changed, 183 insertions(+), 12 deletions(-) diff --git a/src/simsopt/geo/curverzfourier.py b/src/simsopt/geo/curverzfourier.py index fc8eb219c..cba080723 100644 --- a/src/simsopt/geo/curverzfourier.py +++ b/src/simsopt/geo/curverzfourier.py @@ -1,9 +1,10 @@ import numpy as np +import jax.numpy as jnp import simsoptpp as sopp -from .curve import Curve +from .curve import Curve, JaxCurve -__all__ = ['CurveRZFourier'] +__all__ = ['CurveRZFourier', 'JaxCurveRZFourier'] class CurveRZFourier(sopp.CurveRZFourier, Curve): @@ -53,3 +54,131 @@ def set_dofs(self, dofs): """ self.local_x = dofs sopp.CurveRZFourier.set_dofs(self, dofs) + + + +def jaxCurveRZFourier_pure(dofs, quadpoints, order, nfp, stellsym): + """The gamma() method for the JaxCurveRZFourier class. The code is + written as a 'pure' jax function and is ammenable to autodifferentation. + + Args: + dofs (array): 1D-array of dofs. Should be ordered as in the JaxCurveRZFourier + class. + quadpoints (array): 1D-array of quadrature points in [0, 1/nfp]. + order (int): Maximum mode number of the expansion. + nfp (int): Number of field periods. + stellsym (bool): True for stellarator symmetry, False otherwise. + + Returns: + gamma: (nphi, 3) jax numpy array of points on the curve. + """ + + phi1d = 2 * jnp.pi * quadpoints + phi, m = jnp.meshgrid(phi1d, jnp.arange(order + 1), indexing='ij') + + if stellsym: + rc = dofs[:order+1] + zs = dofs[order+1:] + + r = jnp.sum(rc[None, :] * jnp.cos(m * nfp * phi), axis=1) + z = jnp.sum(zs[None, :] * jnp.sin(m[:, 1:] * nfp * phi[:, 1:]), axis=1) + + else: + rc = dofs[0 : order+1] + rs = dofs[order+1 : 2*order+1] + zc = dofs[2*order+1: 3*order+2] + zs = dofs[3*order+2: ] + + r = (jnp.sum(rc[None, :] * jnp.cos(m * nfp * phi), axis=1) + + jnp.sum(rs[None, :] * jnp.sin(m[:, 1:] * nfp * phi[:, 1:]), axis=1) + ) + z = (jnp.sum(zc[None, :] * jnp.cos(m * nfp * phi), axis=1) + + jnp.sum(zs[None, :] * jnp.sin(m[:, 1:] * nfp * phi[:, 1:]), axis=1) + ) + + x = r * jnp.cos(phi1d) + y = r * jnp.sin(phi1d) + + gamma = jnp.zeros((len(quadpoints),3)) + gamma = gamma.at[:, 0].add(x) + gamma = gamma.at[:, 1].add(y) + gamma = gamma.at[:, 2].add(z) + return gamma + + +class JaxCurveRZFourier(JaxCurve): + r'''A ``CurveRZFourier`` that is based off of the ``JaxCurve`` class. + ``CurveRZFourier``is a curve that is represented in cylindrical coordinates using the following Fourier series: + + .. math:: + r(\phi) &= \sum_{m=0}^{\text{order}} r_{c,m}\cos(n_{\text{fp}} m \phi) + \sum_{m=1}^{\text{order}} r_{s,m}\sin(n_{\text{fp}} m \phi) \\ + z(\phi) &= \sum_{m=0}^{\text{order}} z_{c,m}\cos(n_{\text{fp}} m \phi) + \sum_{m=1}^{\text{order}} z_{s,m}\sin(n_{\text{fp}} m \phi) + + If ``stellsym = True``, then the :math:`\sin` terms for :math:`r` and the :math:`\cos` terms for :math:`z` are zero. + + For the ``stellsym = False`` case, the dofs are stored in the order + + .. math:: + [r_{c,0}, \cdots, r_{c,\text{order}}, r_{s,1}, \cdots, r_{s,\text{order}}, z_{c,0},....] + + or in the ``stellsym = True`` case they are stored + + .. math:: + [r_{c,0},...,r_{c,order},z_{s,1},...,z_{s,order}] + + Example usage of JaxCurve's can be found in examples/2_Intermediate/jax_curve_example.py. + + Args: + quadpoints (int, array): Either the number of quadrature points or a + 1D-array of quadrature points. + order (int): Maximum mode number of the expansion. + nfp (int): Number of field periods. + stellsym (bool): True for stellarator symmetry, False otherwise. + ''' + + def __init__(self, quadpoints, order, nfp, stellsym, **kwargs): + if isinstance(quadpoints, int): + quadpoints = list(np.linspace(0, 1./nfp, quadpoints, endpoint=False)) + elif isinstance(quadpoints, np.ndarray): + quadpoints = list(quadpoints) + pure = lambda dofs, points: jaxCurveRZFourier_pure( + dofs, points, order, nfp, stellsym) + + self.order = order + self.nfp = nfp + self.stellsym = stellsym + self.coefficients = np.zeros(self.num_dofs()) + if "dofs" not in kwargs: + if "x0" not in kwargs: + kwargs["x0"] = self.coefficients + else: + self.set_dofs_impl(kwargs["x0"]) + + super().__init__(quadpoints, pure, names=self._make_names(order), **kwargs) + + def _make_names(self, order): + if self.stellsym: + r_cos_names = [f'rc({i})' for i in range(0, order + 1)] + r_names = r_cos_names + z_sin_names = [f'zs({i})' for i in range(1, order + 1)] + z_names = z_sin_names + else: + r_names = ['rc(0)'] + r_cos_names = [f'rc({i})' for i in range(1, order + 1)] + r_sin_names = [f'rs({i})' for i in range(1, order + 1)] + r_names += r_cos_names + r_sin_names + z_names = ['zc(0)'] + z_cos_names = [f'zc({i})' for i in range(1, order + 1)] + z_sin_names = [f'zs({i})' for i in range(1, order + 1)] + z_names += z_cos_names + z_sin_names + + return r_names + z_names + + def num_dofs(self): + return (self.order+1) + self.order if self.stellsym else 4*self.order+2 + + def get_dofs(self): + return self.coefficients + + def set_dofs_impl(self, dofs): + self.coefficients[:] = dofs[:] \ No newline at end of file diff --git a/src/simsopt/geo/curvexyzfouriersymmetries.py b/src/simsopt/geo/curvexyzfouriersymmetries.py index b537e94f3..ad5988b80 100644 --- a/src/simsopt/geo/curvexyzfouriersymmetries.py +++ b/src/simsopt/geo/curvexyzfouriersymmetries.py @@ -73,11 +73,15 @@ class CurveXYZFourierSymmetries(JaxCurve): \hat x(\theta) &= x_{c, 0} + \sum_{m=1}^{\text{order}} \left[ x_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + x_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right] \\ \hat y(\theta) &= y_{c, 0} + \sum_{m=1}^{\text{order}} \left[ y_{c, m} \cos(2 \pi n_{\text{fp}} m \theta) + y_{s, m} \sin(2 \pi n_{\text{fp}} m \theta) \right] \\ + Example usage of JaxCurve's can be found in examples/2_Intermediate/jax_curve_example.py. + Args: - quadpoints: number of grid points/resolution along the curve, - order: how many Fourier harmonics to include in the Fourier representation, - nfp: discrete rotational symmetry number, - stellsym: stellaratory symmetry if True, not stellarator symmetric otherwise, + quadpoints (int, array): Either the number of quadrature points or + 1D-array of quadrature points. If quadpoints is an int, then the quadrature points + will be linearly space in [0,1] (not [0, 1/nfp]). + order (int): Maximum mode number of the expansion. + nfp (int): Number of field periods (discrete rotational symmetry). + stellsym (bool): True for stellarator symmetry, False otherwise. ntor: the number of times the curve wraps toroidally before biting its tail. Note, it is assumed that nfp and ntor are coprime. If they are not coprime, then then the curve actually has nfp_new:=nfp // gcd(nfp, ntor), diff --git a/tests/geo/test_curve.py b/tests/geo/test_curve.py index a997a6e84..500f74bbe 100644 --- a/tests/geo/test_curve.py +++ b/tests/geo/test_curve.py @@ -8,7 +8,7 @@ from simsopt._core.json import GSONEncoder, GSONDecoder, SIMSON from simsopt.geo.curvexyzfourier import CurveXYZFourier, JaxCurveXYZFourier -from simsopt.geo.curverzfourier import CurveRZFourier +from simsopt.geo.curverzfourier import CurveRZFourier, JaxCurveRZFourier from simsopt.geo.curveplanarfourier import CurvePlanarFourier from simsopt.geo.curvehelical import CurveHelical from simsopt.geo.curvexyzfouriersymmetries import CurveXYZFourierSymmetries @@ -71,6 +71,8 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])): curve = JaxCurveXYZFourier(x, order) elif curvetype == "CurveRZFourier": curve = CurveRZFourier(x, order, 2, True) + elif curvetype == "JaxCurveRZFourier": + curve = JaxCurveRZFourier(x, order, 2, True) elif curvetype == "CurveHelical": curve = CurveHelical(x, order, 5, 2, 1.0, 0.3) elif curvetype == "CurveHelicalInitx0": @@ -91,7 +93,7 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])): dofs[1] = 1. dofs[2*order + 3] = 1. dofs[4*order + 3] = 1. - elif curvetype in ["CurveRZFourier", "CurvePlanarFourier"]: + elif curvetype in ["CurveRZFourier", "JaxCurveRZFourier", "CurvePlanarFourier"]: dofs[0] = 1. dofs[1] = 0.1 dofs[order+1] = 0.1 @@ -136,7 +138,7 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])): class Testing(unittest.TestCase): - curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"] + curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "JaxCurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"] def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1): # returns a CurveXYZFourierSymmetries that is randomly perturbed @@ -163,6 +165,42 @@ def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1) return curve + def test_jaxcurverzfourier(self): + """ + Test the JaxCurveRZFourier class is equil to the CurveRZFourier. + """ + + # configuration parameters + quadpoints = 7 + nfp = 2 + curve = JaxCurveRZFourier(quadpoints, order=2, nfp=nfp, stellsym = True) + self.assertEqual(len(curve.dof_names), curve.num_dofs(), + "JaxCurveRZFourier, incorrect number of dofs" + ) + + curve = JaxCurveRZFourier(quadpoints, order=1, nfp=2, stellsym = True) + curve.set('rc(0)', 1.0) + curve.set('rc(1)', 0.1) + curve.set('zs(1)', 0.2) + curve_actual = CurveRZFourier(quadpoints, order=1, nfp=2, stellsym = True) + curve_actual.set('x0', 1.0) + curve_actual.set('x1', 0.1) + curve_actual.set('x2', 0.2) + # assert np.testing.assert_allclose(curve.gamma(), curve_actual.gamma(), atol=1e-14), "JaxCurveRZFourier gamma incorrect." + self.assertTrue(np.allclose(curve.gamma(), curve_actual.gamma(), atol=1e-14), + "JaxCurveRZFourier gamma incorrect.") + + curve = JaxCurveRZFourier(quadpoints, order=1, nfp=2, stellsym = False) + curve.set('rc(0)', 1.0) + curve.set('rs(1)', 0.3) + curve.set('zc(0)', 0.2) + curve_actual = CurveRZFourier(quadpoints, order=1, nfp=2, stellsym = False) + curve_actual.set('x0', 1.0) + curve_actual.set('x2', 0.3) + curve_actual.set('x3', 0.2) + self.assertTrue(np.allclose(curve.gamma(), curve_actual.gamma(), atol=1e-14), + "JaxCurveRZFourier gamma incorrect.") + def test_curvexyzsymmetries_raisesexception(self): # test ensures that an exception is raised when you try and create a curvexyzfouriersymmetries # where gcd(ntor, nfp) != 1. diff --git a/tests/geo/test_curve_objectives.py b/tests/geo/test_curve_objectives.py index 78dc54467..269c7f67c 100644 --- a/tests/geo/test_curve_objectives.py +++ b/tests/geo/test_curve_objectives.py @@ -6,7 +6,7 @@ from simsopt.geo import parameters from simsopt.geo.curve import RotatedCurve, create_equally_spaced_curves from simsopt.geo.curvexyzfourier import CurveXYZFourier, JaxCurveXYZFourier -from simsopt.geo.curverzfourier import CurveRZFourier +from simsopt.geo.curverzfourier import CurveRZFourier, JaxCurveRZFourier from simsopt.geo.curveobjectives import CurveLength, LpCurveCurvature, \ LpCurveTorsion, CurveCurveDistance, ArclengthVariation, \ MeanSquaredCurvature, CurveSurfaceDistance, LinkingNumber @@ -21,7 +21,7 @@ class Testing(unittest.TestCase): - curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier"] + curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "JaxCurveRZFourier"] def create_curve(self, curvetype, rotated): np.random.seed(1) @@ -43,7 +43,7 @@ def create_curve(self, curvetype, rotated): dofs[1] = 1. dofs[2*order+3] = 1. dofs[4*order+3] = 1. - elif curvetype in ["CurveRZFourier"]: + elif curvetype in ["CurveRZFourier", "JaxCurveRZFourier"]: dofs[0] = 1. dofs[1] = 0.1 dofs[order+1] = 0.1