diff --git a/src/simsopt/geo/curve.py b/src/simsopt/geo/curve.py index 2d07d7499..966657f44 100644 --- a/src/simsopt/geo/curve.py +++ b/src/simsopt/geo/curve.py @@ -1112,7 +1112,7 @@ def normal(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp): return n -def nfactor(gamma2d, surf_dofs, qpts, mpol, ntor, nfp, direction='z'): +def nfactor(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp, direction='z'): """Compute the scalar product between the unitary vector normal to the surface and some direction. Args: @@ -1128,9 +1128,9 @@ def nfactor(gamma2d, surf_dofs, qpts, mpol, ntor, nfp, direction='z'): - Scalar product between the unitary normal vector and the access direction. """ if direction=='z': - return normal(gamma2d, surf_dofs, qpts, mpol, ntor, nfp)[:,2] + return normal(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp)[:,2] elif direction=='r': - return normal(gamma2d, surf_dofs, qpts, mpol, ntor, nfp)[:,0] + return normal(curve_dofs, qpts, order, G, H, surf_dofs, mpol, ntor, nfp)[:,0] class CurveCWSFourier( Curve, sopp.Curve ): """Curve that lies on a surface @@ -1213,6 +1213,14 @@ def __init__(self, quadpoints, order, surf, G=0, H=0, **kwargs): self.dtorsion_by_dcoeff_vjp_jax = jit(lambda cdofs, sdofs, v: vjp(lambda x: torsion_pure(self.gammadash_jax(x, sdofs), self.gammadashdash_jax(x, sdofs), self.gammadashdashdash_jax(x, sdofs)), cdofs)[1](v)[0]) self.dtorsion_by_dsurf_vjp_jax = jit(lambda cdofs, sdofs, v: vjp(lambda x: torsion_pure(self.gammadash_jax(cdofs, x), self.gammadashdash_jax(cdofs, x), self.gammadashdashdash_jax(cdofs, x)), sdofs)[1](v)[0]) + + # NORMAL + self.snz = lambda cdofs, sdofs: nfactor(cdofs, quadpoints, order, G, H, sdofs, self.surf.mpol, self.surf.ntor, self.surf.nfp, direction='z') + self.snr = lambda cdofs, sdofs: nfactor(cdofs, quadpoints, order, G, H, sdofs, self.surf.mpol, self.surf.ntor, self.surf.nfp, direction='r') + + self.dsnz_by_dcoeff_vjp_jax = lambda cdofs, sdofs, v: vjp(lambda x: self.snz(x, sdofs), cdofs)[1](v)[0] + self.dsnr_by_dcoeff_vjp_jax = lambda cdofs, sdofs, v: vjp(lambda x: self.snr(x, sdofs), cdofs)[1](v)[0] + def set_dofs(self, dofs): self.local_x = dofs sopp.Curve.set_dofs(self, dofs) @@ -1243,6 +1251,7 @@ def _make_names(self): return dofs_name + # GAMMA # ===== def gamma_impl(self, gamma, quadpoints): @@ -1578,10 +1587,7 @@ def zfactor(self): def dzfactor_by_dcoeff_vjp(self, v): cdofs = self.get_dofs() sdofs = self.surf.get_dofs() - dndcurve = self.dsnz_by_dcurve_vjp_jax(cdofs, sdofs, v) - dndcoef = self.curve2d.dgamma_by_dcoeff_vjp(dndcurve) - dndsurf = Derivative({self.surf: self.dsnz_by_dsurf_vjp_jax(cdofs, sdofs, v)}) - return dndcoef + dndsurf + return Derivative({self: self.dsnz_by_dcoeff_vjp_jax(cdofs, sdofs, v)}) def rfactor(self): cdofs = self.get_dofs() @@ -1591,7 +1597,4 @@ def rfactor(self): def drfactor_by_dcoeff_vjp(self, v): cdofs = self.get_dofs() sdofs = self.surf.get_dofs() - dndcurve = self.dsnr_by_dcurve_vjp_jax(cdofs, sdofs, v) - dndcoef = self.curve2d.dgamma_by_dcoeff_vjp(dndcurve) - dndsurf = Derivative({self.surf: self.dsnr_by_dsurf_vjp_jax(cdofs, sdofs, v)}) - return dndcoef + dndsurf + return Derivative({self: self.dsnr_by_dcoeff_vjp_jax(cdofs, sdofs, v)})