Skip to content

Commit

Permalink
Added normal factor to the CurveCWSFourier class
Browse files Browse the repository at this point in the history
  • Loading branch information
abaillod committed Sep 20, 2024
1 parent f9593e9 commit a86369d
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/simsopt/geo/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def normal(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp):

return n

def nfactor(gamma2d, surf_dofs, qpts, mpol, ntor, nfp, direction='z'):
def nfactor(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp, direction='z'):
"""Compute the scalar product between the unitary vector normal to the surface and some direction.
Args:
Expand All @@ -1128,9 +1128,9 @@ def nfactor(gamma2d, surf_dofs, qpts, mpol, ntor, nfp, direction='z'):
- Scalar product between the unitary normal vector and the access direction.
"""
if direction=='z':
return normal(gamma2d, surf_dofs, qpts, mpol, ntor, nfp)[:,2]
return normal(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp)[:,2]
elif direction=='r':
return normal(gamma2d, surf_dofs, qpts, mpol, ntor, nfp)[:,0]
return normal(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp)[:,0]

class CurveCWSFourier( Curve, sopp.Curve ):
"""Curve that lies on a surface
Expand Down Expand Up @@ -1213,6 +1213,14 @@ def __init__(self, quadpoints, order, surf, G=0, H=0, **kwargs):
self.dtorsion_by_dcoeff_vjp_jax = jit(lambda cdofs, sdofs, v: vjp(lambda x: torsion_pure(self.gammadash_jax(x, sdofs), self.gammadashdash_jax(x, sdofs), self.gammadashdashdash_jax(x, sdofs)), cdofs)[1](v)[0])
self.dtorsion_by_dsurf_vjp_jax = jit(lambda cdofs, sdofs, v: vjp(lambda x: torsion_pure(self.gammadash_jax(cdofs, x), self.gammadashdash_jax(cdofs, x), self.gammadashdashdash_jax(cdofs, x)), sdofs)[1](v)[0])


# NORMAL
self.snz = lambda cdofs, sdofs: nfactor(cdofs, quadpoints, order, G, H, sdofs, self.surf.mpol, self.surf.ntor, self.surf.nfp, direction='z')
self.snr = lambda cdofs, sdofs: nfactor(cdofs, quadpoints, order, G, H, sdofs, self.surf.mpol, self.surf.ntor, self.surf.nfp, direction='r')

self.dsnz_by_dcoeff_vjp_jax = lambda cdofs, sdofs, v: vjp(lambda x: self.snz(x, sdofs), cdofs)[1](v)[0]
self.dsnr_by_dcoeff_vjp_jax = lambda cdofs, sdofs, v: vjp(lambda x: self.snr(x, sdofs), cdofs)[1](v)[0]

def set_dofs(self, dofs):
self.local_x = dofs
sopp.Curve.set_dofs(self, dofs)
Expand Down Expand Up @@ -1243,6 +1251,7 @@ def _make_names(self):

return dofs_name


# GAMMA
# =====
def gamma_impl(self, gamma, quadpoints):
Expand Down Expand Up @@ -1578,10 +1587,7 @@ def zfactor(self):
def dzfactor_by_dcoeff_vjp(self, v):
cdofs = self.get_dofs()
sdofs = self.surf.get_dofs()
dndcurve = self.dsnz_by_dcurve_vjp_jax(cdofs, sdofs, v)
dndcoef = self.curve2d.dgamma_by_dcoeff_vjp(dndcurve)
dndsurf = Derivative({self.surf: self.dsnz_by_dsurf_vjp_jax(cdofs, sdofs, v)})
return dndcoef + dndsurf
return Derivative({self: self.dsnz_by_dcoeff_vjp_jax(cdofs, sdofs, v)})

def rfactor(self):
cdofs = self.get_dofs()
Expand All @@ -1591,7 +1597,4 @@ def rfactor(self):
def drfactor_by_dcoeff_vjp(self, v):
cdofs = self.get_dofs()
sdofs = self.surf.get_dofs()
dndcurve = self.dsnr_by_dcurve_vjp_jax(cdofs, sdofs, v)
dndcoef = self.curve2d.dgamma_by_dcoeff_vjp(dndcurve)
dndsurf = Derivative({self.surf: self.dsnr_by_dsurf_vjp_jax(cdofs, sdofs, v)})
return dndcoef + dndsurf
return Derivative({self: self.dsnr_by_dcoeff_vjp_jax(cdofs, sdofs, v)})

0 comments on commit a86369d

Please sign in to comment.