Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
abaillod committed Jul 8, 2024
1 parent ee33d5a commit 103f8de
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions src/simsopt/geo/curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,17 +1020,15 @@ def __init__(self, quadpoints, order, G=0, H=0, dofs=None):
pure,
x0=np.concatenate(self.modes),
external_dof_setter=Curve2D.set_dofs_impl,
names=self._make_names(),
gamma_pure=pure
names=self._make_names()
)
else:
super().__init__(
quadpoints,
pure,
dofs=dofs,
external_dof_setter=Curve2D.set_dofs_impl,
names=self._make_names(),
gamma_pure=pure
names=self._make_names()
)

self.dgamma_by_dpoint_pure = jit(lambda d, p: jacfwd(pure, argnums=1)(d, p))
Expand Down Expand Up @@ -1247,13 +1245,10 @@ def __init__(self, curve2d, surf):
self.dgammadash_by_dsurf_vjp_jax = jit(lambda gamma2d, surf_dofs, v: vjp(self.gammadash_jax, gamma2d, surf_dofs)[1](v)[1])

# GAMMADASHDASH
self.dxdt_times_dc2ddl = lambda gamma, sdofs, p: jvp(lambda g: self.gamma_pure(g, sdofs, p), (gamma,), (self.curve2d.gammadash(),))[1]
self.dxxdtt_times_dc2ddlsq = lambda gamma, sdofs, p: jvp(lambda g: self.dxdt_times_dc2ddl(g, sdofs, p), (gamma,), (self.curve2d.gammadash(),))[1]

self.gammadashdash_pure = lambda g, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdash(),))[1] + self.dxxdtt_times_dc2ddlsq(g, sdofs, p)



self.gdashtmp = lambda g1, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g1,), (self.curve2d.gammadash(),))[1]
self.gdashdash_part1 = lambda gamma, sdofs, p: jvp(lambda g2: self.gdashtmp(g2, sdofs, p), (gamma,), (self.curve2d.gammadash(),))[1]
self.gdashdash_part2 = lambda g, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdash(),))[1]
self.gammadashdash_pure = lambda g, sdofs, p: self.gdashdash_part1(g, sdofs, p) + self.gdashdash_part2(g, sdofs, p)

self.gammadashdash_jax = jit(lambda g, sdofs: self.gammadashdash_pure(g, sdofs, self.quadpoints))
self.dgammadashdash_by_dcurve_jax = jit(lambda gamma2d, surf_dofs: jacfwd(self.gammadashdash_jax, argnums=0)(gamma2d, surf_dofs))
Expand All @@ -1263,7 +1258,11 @@ def __init__(self, curve2d, surf):


# GAMMADASHDASHDASH
self.gammadashdashdash_pure = lambda g, sdofs, p: jvp(lambda g2: self.gammadashdash_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdashdash(),))[1]
self.gdashdashdash_part1 = lambda g1, sdofs, p: jvp(lambda g2: self.gdashdash_part1(g2, sdofs, p), (g1,), (self.curve2d.gammadash(),))[1]
self.gdashdashdash_part2 = lambda gamma, sdofs, p: 3.0 * jvp(lambda g2: self.gdashtmp(g2, sdofs, p), (gamma,), (self.curve2d.gammadashdash(),))[1]
self.gdashdashdash_part3 = lambda g, sdofs, p: jvp(lambda g2: self.gamma_pure(g2, sdofs, p), (g,), (self.curve2d.gammadashdashdash(),))[1]
self.gammadashdashdash_pure = lambda g, sdofs, p: self.gdashdashdash_part1(g, sdofs, p) + self.gdashdashdash_part2(g, sdofs, p) + self.gdashdashdash_part3(g, sdofs, p)

self.gammadashdashdash_jax = jit(lambda g, sdofs: self.gammadashdashdash_pure(g, sdofs, self.quadpoints))
self.dgammadashdashdash_by_dcurve_jax = jit(lambda gamma2d, surf_dofs: jacfwd(self.gammadashdashdash_jax, argnums=0)(gamma2d, surf_dofs))
self.dgammadashdashdash_by_dsurf_jax = jit(lambda gamma2d, surf_dofs: jacfwd(self.gammadashdashdash_jax, argnums=1)(gamma2d, surf_dofs))
Expand Down

0 comments on commit 103f8de

Please sign in to comment.