Skip to content

Commit

Permalink
Fixed NANs in derivative of self-force
Browse files Browse the repository at this point in the history
  • Loading branch information
landreman committed Sep 30, 2023
1 parent 02c678c commit 5ac31ea
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 88 deletions.
2 changes: 1 addition & 1 deletion examples/3_Advanced/coil_forces.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from simsopt.objectives import SquaredFlux, QuadraticPenalty
from simsopt.geo import CurveLength, CurveCurveDistance, CurveSurfaceDistance
from simsopt.field import BiotSavart
from simsopt.field.forces import ForceOpt
from simsopt.field.force import ForceOpt


# File for the desired boundary magnetic surface:
Expand Down
49 changes: 27 additions & 22 deletions src/simsopt/field/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ def self_force_rect(coil, a, b):
def force_opt_pure(gamma, gammadash, gammadashdash,
current, phi, B_ext, regularization):
"""Cost function for force optimization. Optimize for peak self force on the coil (so far)"""
t = gammadash / jnp.linalg.norm(gammadash)
tangent = gammadash / jnp.linalg.norm(gammadash, axis=1)[:, None]
B_self = B_regularized_pure(
gamma, gammadash, gammadashdash, phi, phi, current, regularization)
gamma, gammadash, gammadashdash, phi, current, regularization)
B_tot = B_self + B_ext
force = coil_force_pure(B_tot, current, t)
force = coil_force_pure(B_tot, current, tangent)
f_norm = jnp.linalg.norm(force, axis=1)
result = jnp.max(f_norm)
# result = jnp.sum(f_norm)
Expand All @@ -58,34 +58,40 @@ def force_opt_pure(gamma, gammadash, gammadashdash,
class ForceOpt(Optimizable):
"""Optimizable class to optimize forces on a coil"""

def __init__(self, coil, coils, a=0.05):
def __init__(self, coil, coils, regularization):
self.coil = coil
self.curve = coil.curve
self.coils = coils
self.a = a
self.B_ext = BiotSavart(coils).set_points(self.curve.gamma()).B()
self.B_self = 0
self.B = 0
self.J_jax = jit(lambda gamma, gammadash, gammadashdash,
current, phi, B_ext: force_opt_pure(gamma, gammadash, gammadashdash,
current, phi, B_ext))

self.thisgrad0 = jit(lambda gamma, gammadash, gammadashdash, current, phi, B_ext: grad(
self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, phi, B_ext))
self.thisgrad1 = jit(lambda gamma, gammadash, gammadashdash, current, phi, B_ext: grad(
self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, phi, B_ext))
self.thisgrad2 = jit(lambda gamma, gammadash, gammadashdash, current, phi, B_ext: grad(
self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, phi, B_ext))
self.regularization = regularization
self.B_ext = BiotSavart(coils).set_points(self.coil.curve.gamma()).B()
self.J_jax = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
force_opt_pure(gamma, gammadash, gammadashdash, current, phi, B_ext, regularization)
)

self.thisgrad0 = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, phi, B_ext)
)
self.thisgrad1 = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, phi, B_ext)
)
self.thisgrad2 = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, phi, B_ext)
)

super().__init__(depends_on=[coil])
# The version in the next line is needed
#eventually to get derivatives with respect to the other source coils:
#super().__init__(depends_on=[coil] + coils)

def J(self):
gamma = self.coil.curve.gamma()
d1gamma = self.coil.curve.gammadash()
d2gamma = self.coil.curve.gammadashdash()
current = self.coil.current.get_value()
phi = self.coil.curve.quadpoints
phi = self.coil.curve.quadpoints
B_ext = self.B_ext
return self.J_jax(gamma, d1gamma, d2gamma, current, phi, B_ext)

Expand All @@ -96,7 +102,6 @@ def dJ(self):
d2gamma = self.coil.curve.gammadashdash()
current = self.coil.current.get_value()
phi = self.coil.curve.quadpoints
phi = self.coil.curve.quadpoints
B_ext = self.B_ext

grad0 = self.thisgrad0(gamma, d1gamma, d2gamma,
Expand All @@ -107,7 +112,7 @@ def dJ(self):
current, phi, B_ext)

return (
self.coil.curve.dgamma_by_dcoeff_vjp(grad0)
self.coil.curve.dgamma_by_dcoeff_vjp(grad0)
+ self.coil.curve.dgammadash_by_dcoeff_vjp(grad1)
+ self.coil.curve.dgammadashdash_by_dcoeff_vjp(grad2)
)
Expand Down
8 changes: 2 additions & 6 deletions src/simsopt/field/selffield.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, reg
for j in range(n_quad):
dr = r_c - r_c[j]
first_term = (
jnp.cross(rc_prime[j], dr) / ((jnp.linalg.norm(dr, axis=1)**2 + regularization) ** 1.5)[:, None]
jnp.cross(rc_prime[j], dr) / ((jnp.sum(dr * dr, axis=1) + regularization) ** 1.5)[:, None]
)
cos_fac = 2 - 2 * jnp.cos(phi[j] - phi)
denominator2 = cos_fac * jnp.linalg.norm(rc_prime, axis=1)**2 + regularization
denominator2 = cos_fac * jnp.sum(rc_prime * rc_prime, axis=1) + regularization
factor2 = 0.5 * cos_fac / denominator2**1.5
second_term = jnp.cross(rc_prime_prime, rc_prime) * factor2[:, None]
integral_term += dphi * (first_term + second_term)
Expand All @@ -86,10 +86,6 @@ def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, reg

def B_regularized(coil, regularization):
"""Calculate the regularized field on a coil following the Landreman and Hurwitz method"""
phi = coil.curve.quadpoints * 2 * np.pi
r_c = coil.curve.gamma()
rc_prime = coil.curve.gammadash() / 2 / np.pi
rc_prime_prime = coil.curve.gammadashdash() / 4 / np.pi**2
return B_regularized_pure(
coil.curve.gamma(),
coil.curve.gammadash(),
Expand Down
Loading

0 comments on commit 5ac31ea

Please sign in to comment.