Skip to content

Commit

Permalink
adding more jax friendly functions to the JaxCurve class: kappa, tors…
Browse files Browse the repository at this point in the history
…ion, kappadash, frenet_frame. These can be used for autodiff
  • Loading branch information
mishapadidar committed Dec 11, 2024
1 parent d6bea2b commit ba10493
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/simsopt/geo/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ def kappa_pure(d1gamma, d2gamma):
kappagrad0 = jit(lambda d1gamma, d2gamma: jacfwd(lambda d1g: kappa_pure(d1g, d2gamma))(d1gamma))
kappagrad1 = jit(lambda d1gamma, d2gamma: jacfwd(lambda d2g: kappa_pure(d1gamma, d2g))(d2gamma))

@jit
def kappadash_pure(d1gamma, d2gamma, d3gamma):
r"""
A jax-friendly function for computing :math:`\kappa'(\phi)`, where :math:`\kappa` is the curvature.
"""
norm = lambda a: jnp.linalg.norm(a, axis=1)
inner = lambda a, b: jnp.sum(a*b, axis=1)
cross = lambda a, b: jnp.cross(a, b, axis=1)
dkappa_by_dphi = inner(cross(d1gamma, d2gamma), cross(d1gamma, d3gamma))/(norm(cross(d1gamma, d2gamma)) * norm(d1gamma)**3) \
- 3 * inner(d1gamma, d2gamma) * norm(cross(d1gamma, d2gamma))/norm(d1gamma)**5
return dkappa_by_dphi

@jit
def torsion_pure(d1gamma, d2gamma, d3gamma):
Expand All @@ -54,6 +65,30 @@ def torsion_pure(d1gamma, d2gamma, d3gamma):
torsionvjp1 = jit(lambda d1gamma, d2gamma, d3gamma, v: vjp(lambda d2g: torsion_pure(d1gamma, d2g, d3gamma), d2gamma)[1](v)[0])
torsionvjp2 = jit(lambda d1gamma, d2gamma, d3gamma, v: vjp(lambda d3g: torsion_pure(d1gamma, d2gamma, d3g), d3gamma)[1](v)[0])

@jit
def frenet_frame_pure(gammadash, gammadashdash):
r"""
A jax-friendly function for computing the Frenet frame.
This function returns the Frenet frame, :math:`(\mathbf{t}, \mathbf{n}, \mathbf{b})`,
associated to the curve.
"""

# gammadash = self.gammadash_jax(x)
# gammadashdash = self.gammadashdash_jax(x)
l = jnp.linalg.norm(gammadash, axis=1)
# l = self.incremental_arclength_jax(x)
norm = lambda a: jnp.linalg.norm(a, axis=1)
inner = lambda a, b: jnp.sum(a*b, axis=1)
# N = len(self.quadpoints)
# t, n, b = (jnp.zeros((N, 3)), jnp.zeros((N, 3)), jnp.zeros((N, 3)))
t = (1./l[:, None]) * gammadash

tdash = (1./l[:, None])**2 * (l[:, None] * gammadashdash
- (inner(gammadash, gammadashdash)/l)[:, None] * gammadash
)
n = (1./norm(tdash))[:, None] * tdash
b = jnp.cross(t, n, axis=1)
return t, n, b

class Curve(Optimizable):
"""
Expand Down Expand Up @@ -445,6 +480,13 @@ def __init__(self, quadpoints, gamma_pure, **kwargs):

self.dtorsion_by_dcoeff_vjp_jax = jit(lambda x, v: vjp(lambda d: torsion_pure(self.gammadash_jax(d), self.gammadashdash_jax(d), self.gammadashdashdash_jax(d)), x)[1](v)[0])

# jax can differentiate through these
self.incremental_arclength_jax = jit(lambda x: incremental_arclength_pure(self.gammadash_jax(x)))
self.kappa_jax = jit(lambda x: kappa_pure(self.gammadash_jax(x), self.gammadashdash_jax(x)))
self.torsion_jax = jit(lambda x: torsion_pure(self.gammadash_jax(x), self.gammadashdash_jax(x), self.gammadashdashdash_jax(x)))
self.kappadash_jax = jit(lambda x: kappadash_pure(self.gammadash_jax(x), self.gammadashdash_jax(x), self.gammadashdashdash_jax(x)))
self.frenet_frame_jax = jit(lambda x: frenet_frame_pure(self.gammadash_jax(x), self.gammadashdash_jax(x)))

def set_dofs(self, dofs):
self.local_x = dofs
sopp.Curve.set_dofs(self, dofs)
Expand Down

0 comments on commit ba10493

Please sign in to comment.