diff --git a/src/simsopt/geo/curve.py b/src/simsopt/geo/curve.py index f0dce4e4c..2ee29aa96 100644 --- a/src/simsopt/geo/curve.py +++ b/src/simsopt/geo/curve.py @@ -1020,8 +1020,7 @@ def __init__(self, quadpoints, order, G=0, H=0, dofs=None): pure, x0=np.concatenate(self.modes), external_dof_setter=Curve2D.set_dofs_impl, - names=self._make_names(), - gamma_pure=pure + names=self._make_names() ) else: super().__init__( @@ -1029,8 +1028,7 @@ def __init__(self, quadpoints, order, G=0, H=0, dofs=None): pure, dofs=dofs, external_dof_setter=Curve2D.set_dofs_impl, - names=self._make_names(), - gamma_pure=pure + names=self._make_names() ) self.dgamma_by_dpoint_pure = jit(lambda d, p: jacfwd(pure, argnums=1)(d, p)) @@ -1247,13 +1245,10 @@ def __init__(self, curve2d, surf): self.dgammadash_by_dsurf_vjp_jax = jit(lambda gamma2d, surf_dofs, v: vjp(self.gammadash_jax, gamma2d, surf_dofs)[1](v)[1]) # GAMMADASHDASH - self.dxdt_times_dc2ddl = lambda gamma, sdofs, p: jvp(lambda g: self.gamma_pure(g, sdofs, p), (gamma,), (self.curve2d.gammadash(),))[1] - self.dxxdtt_times_dc2ddlsq = lambda gamma, sdofs, p: jvp(lambda g: self.dxdt_times_dc2ddl(g, sdofs, p), (gamma,), (self.curve2d.gammadash(),))[1] - - self.gammadashdash_pure = lambda g, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdash(),))[1] + self.dxxdtt_times_dc2ddlsq(g, sdofs, p) - - - + self.gdashtmp = lambda g1, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g1,), (self.curve2d.gammadash(),))[1] + self.gdashdash_part1 = lambda gamma, sdofs, p: jvp(lambda g2: self.gdashtmp(g2, sdofs, p), (gamma,), (self.curve2d.gammadash(),))[1] + self.gdashdash_part2 = lambda g, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdash(),))[1] + self.gammadashdash_pure = lambda g, sdofs, p: self.gdashdash_part1(g, sdofs, p) + self.gdashdash_part2(g, sdofs, p) self.gammadashdash_jax = jit(lambda g, sdofs: self.gammadashdash_pure(g, sdofs, self.quadpoints)) self.dgammadashdash_by_dcurve_jax = jit(lambda gamma2d, surf_dofs: jacfwd(self.gammadashdash_jax, argnums=0)(gamma2d, surf_dofs)) @@ -1263,7 +1258,11 @@ def __init__(self, curve2d, surf): # GAMMADASHDASHDASH - self.gammadashdashdash_pure = lambda g, sdofs, p: jvp(lambda g2: self.gammadashdash_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdashdash(),))[1] + self.gdashdashdash_part1 = lambda g1, sdofs, p: jvp(lambda g2: self.gdashdash_part1(g2, sdofs, p), (g1,), (self.curve2d.gammadash(),))[1] + self.gdashdashdash_part2 = lambda gamma, sdofs, p: 3.0 * jvp(lambda g2: self.gdashtmp(g2, sdofs, p), (gamma,), (self.curve2d.gammadashdash(),))[1] + self.gdashdashdash_part3 = lambda g, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdashdash(),))[1] + self.gammadashdashdash_pure = lambda g, sdofs, p: self.gdashdashdash_part1(g, sdofs, p) + self.gdashdashdash_part2(g, sdofs, p) + self.gdashdashdash_part3(g, sdofs, p) + self.gammadashdashdash_jax = jit(lambda g, sdofs: self.gammadashdashdash_pure(g, sdofs, self.quadpoints)) self.dgammadashdashdash_by_dcurve_jax = jit(lambda gamma2d, surf_dofs: jacfwd(self.gammadashdashdash_jax, argnums=0)(gamma2d, surf_dofs)) self.dgammadashdashdash_by_dsurf_jax = jit(lambda gamma2d, surf_dofs: jacfwd(self.gammadashdashdash_jax, argnums=1)(gamma2d, surf_dofs))