Skip to content

Commit

Permalink
field alignment optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslage committed Nov 10, 2023
1 parent 1b85d0f commit 6892915
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 54 deletions.
44 changes: 44 additions & 0 deletions examples/3_Advanced/field_alignment_opt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from simsopt.configs import get_w7x_data
from scipy.optimize import minimize
from simsopt.geo import FrameRotation, FramedCurveCentroid
from simsopt.field import Coil
from simsopt.field.fieldalignment import CrtitcalCurrentOpt, critical_current
from simsopt.util import in_github_actions
import numpy as np

MAXITER = 50 if in_github_actions else 400

curves, currents, ma = get_w7x_data()

curve = curves[0]
current = currents[0]
coils = [Coil(c, curr) for c, curr in zip(curves[1:], currents[1:])]

rot_order = 10 # order of the Fourier expression for the rotation of the filament pack

curve.fix_all() # fix curve DOFs -> only optimize winding angle
current.fix_all()
rotation = FrameRotation(curve.quadpoints, rot_order)

framedcurve = FramedCurveCentroid(curve, rotation)
coil = Coil(framedcurve, current)
JF = CrtitcalCurrentOpt(coil, coils, a=0.05, b=0.05)
print("Minimum Ic before Optimization:",
np.min(critical_current(coil, a=0.05, b=0.05)))


def fun(dofs):
JF.x = dofs
J = JF.J()
grad = JF.dJ()
return J, grad


f = fun
dofs = JF.x

res = minimize(fun, dofs, jac=True, method='L-BFGS-B',
options={'maxiter': MAXITER, 'maxcor': 10, 'gtol': 1e-20, 'ftol': 1e-20}, tol=1e-20)

print("Minimum Ic after Optimization:",
np.min(critical_current(coil, a=0.05, b=0.05)))
113 changes: 67 additions & 46 deletions src/simsopt/field/fieldalignment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Implements optimization of the angle between the magnetic field and the ReBCO c-axis."""
import math
import pickle
from scipy import constants
import numpy as np
import jax.numpy as jnp
from jax import grad
from simsopt.field import BiotSavart, Coil
from simsopt.field.selffield import B_regularized, regularization_rect, regularization_circ, B_regularized_pure
from simsopt.geo import FramedCurve
from simsopt.geo.framedcurve import rotated_frenet_frame
from simsopt.field import BiotSavart
from simsopt.field.selffield import regularization_rect, B_regularized_pure
from simsopt.geo.framedcurve import rotated_centroid_frame
from simsopt.geo.jit import jit
from simsopt._core import Optimizable
from simsopt._core.derivative import derivative_dec
Expand Down Expand Up @@ -75,14 +72,14 @@ def c_axis_angle_pure(coil, B):
def critical_current_pure(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b):

regularization = regularization_rect(a, b)
tangent, normal, binormal = rotated_frenet_frame(
gamma, gammadash, gammadashdash, alpha)
tangent, normal, binormal = rotated_centroid_frame(
gamma, gammadash, alpha)

# Fit parameters for reduced Kim-like model of the critical current (doi:10.1088/0953-2048/24/6/065005)
xi = -0.3
k = 0.7
xi = -0.7
k = 0.3
B_0 = 42.6e-3
Ic_0 = 1
Ic_0 = 1 # 1.3e11,

field = B_regularized_pure(
gamma, gammadash, gammadashdash, quadpoints, current, regularization)
Expand All @@ -102,71 +99,95 @@ def critical_current(framedcoil, a, b, JANUS=True):
return Ic


def critical_current_obj_pure(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b):
def critical_current_obj_pure(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p):
Ic0 = 1
Ic = critical_current_pure(
gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b)
obj = np.min(Ic)
obj = (1./p)*jnp.mean((Ic-Ic0)**p)
return obj


def critical_current_obj(framedcoil, a, b, JANUS=True):
def critical_current_obj(framedcoil, a, b, p=4, JANUS=True):
"""Objective for field alignement optimization: Target minimum of the critical current along the coil"""
return np.min(critical_current(framedcoil, a, b))
obj = critical_current_obj_pure(framedcoil.curve.curve.gamma, framedcoil.curve.curve.d1gamma, framedcoil.curve.curve.d2gamma,
framedcoil.curve.rotation.alpha(framedcoil.curve.quadpoints), framedcoil.curve.quadpoints, framedcoil.current.current, a, b, p)
return obj


class CrtitcalCurrentOpt(Optimizable):
"""Optimizable class to optimize the critical on a ReBCO coil"""

def __init__(self, framedcoil, coils, a=0.05):
self.coil = coil
self.curve = coil.curve
def __init__(self, framedcoil, coils, a=0.05, b=0.05, p=2):
self.coil = framedcoil
self.curve = framedcoil.curve
self.coils = coils
self.a = a
self.B_ext = BiotSavart(coils).set_points(self.curve.gamma()).B()
self.b = b
self.p = p
self.B_ext = BiotSavart(coils).set_points(
framedcoil.curve.curve.gamma()).B()
self.B_self = 0
self.B = 0
self.alpha = coil.curve.rotation.alpha(coil.curve.quadpoints)
self.quadpoints = coil.curve.quadpoints
self.alpha = framedcoil.curve.rotation.alpha(
framedcoil.curve.quadpoints)
self.quadpoints = framedcoil.curve.quadpoints
self.J_jax = jit(lambda gamma, gammadash, gammadashdash, alpha, quadpoints, current, a,
b: critical_current_obj_pure(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b))
b, p: critical_current_obj_pure(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p))

self.thisgrad0 = jit(lambda gamma, gammadash, gammadashdash, current, phi, phidash, B_ext: grad(
self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, phi, phidash, B_ext))
self.thisgrad1 = jit(lambda gamma, gammadash, gammadashdash, current, phi, phidash, B_ext: grad(
self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, phi, phidash, B_ext))
self.thisgrad2 = jit(lambda gamma, gammadash, gammadashdash, current, phi, phidash, B_ext: grad(
self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, phi, phidash, B_ext))
self.thisgrad0 = jit(lambda gamma, gammadash, gammadashdash, alpha, quadpoints, current, a,
b, p: grad(
self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a,
b, p))
self.thisgrad1 = jit(lambda gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p: grad(
self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p))
self.thisgrad2 = jit(lambda gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p: grad(
self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p))
self.thisgrad3 = jit(lambda gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p: grad(
self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, alpha, quadpoints, current, a, b, p))

super().__init__(depends_on=[coil])
super().__init__(depends_on=[framedcoil])

def J(self):
gamma = self.coil.curve.gamma()
d1gamma = self.coil.curve.gammadash()
d2gamma = self.coil.curve.gammadashdash()
gamma = self.coil.curve.curve.gamma()
d1gamma = self.coil.curve.curve.gammadash()
d2gamma = self.coil.curve.curve.gammadashdash()
current = self.coil.current.get_value()
phi = self.coil.curve.quadpoints
phidash = self.coil.curve.quadpoints
phi = self.coil.curve.curve.quadpoints
phidash = self.coil.curve.curve.quadpoints
alpha = self.coil.curve.rotation.alpha(phi)
a = self.a
b = self.b
p = self.p

B_ext = self.B_ext
return self.J_jax(gamma, d1gamma, d2gamma, current, phi, phidash, B_ext)
return self.J_jax(gamma, d1gamma, d2gamma, alpha, phi, current, a, b, p)

@derivative_dec
def dJ(self):
gamma = self.coil.curve.gamma()
d1gamma = self.coil.curve.gammadash()
d2gamma = self.coil.curve.gammadashdash()
gamma = self.coil.curve.curve.gamma()
d1gamma = self.coil.curve.curve.gammadash()
d2gamma = self.coil.curve.curve.gammadashdash()
current = self.coil.current.get_value()
phi = self.coil.curve.quadpoints
phidash = self.coil.curve.quadpoints
alpha = self.coil.curve.rotation.alpha(phi)
a = self.a
b = self.b
p = self.p

B_ext = self.B_ext

grad0 = self.thisgrad0(gamma, d1gamma, d2gamma,
current, phi, phidash, B_ext)
grad1 = self.thisgrad0(gamma, d1gamma, d2gamma,
current, phi, phidash, B_ext)
grad2 = self.thisgrad0(gamma, d1gamma, d2gamma,
current, phi, phidash, B_ext)

return self.coil.curve.dgamma_by_dcoeff_vjp(grad0) + self.coil.curve.dgammadash_by_dcoeff_vjp(grad1) \
+ self.coil.curve.dgammadashdash_by_dcoeff_vjp(grad2)
alpha, phi, current, a, b, p)
grad1 = self.thisgrad1(gamma, d1gamma, d2gamma,
alpha, phi, current, a, b, p)
grad2 = self.thisgrad2(gamma, d1gamma, d2gamma,
alpha, phi, current, a, b, p)
grad3 = self.thisgrad3(gamma, d1gamma, d2gamma,
alpha, phi, current, a, b, p)

return self.coil.curve.curve.dgamma_by_dcoeff_vjp(grad0) + self.coil.curve.curve.dgammadash_by_dcoeff_vjp(grad1) \
+ self.coil.curve.curve.dgammadashdash_by_dcoeff_vjp(
grad2) + self.coil.curve.rotation.dalpha_by_dcoeff_vjp(phi, grad3)

return_fn_map = {'J': J, 'dJ': dJ}
21 changes: 13 additions & 8 deletions src/simsopt/field/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def self_force(coil, regularization):
Compute the self-force of a coil.
"""
I = coil.current.get_value()
tangent = coil.curve.gammadash() / np.linalg.norm(coil.curve.gammadash(), axis=1)[:, None]
tangent = coil.curve.gammadash() / np.linalg.norm(coil.curve.gammadash(),
axis=1)[:, None]
B = B_regularized(coil, regularization)
return coil_force_pure(B, I, tangent)

Expand Down Expand Up @@ -62,27 +63,31 @@ def __init__(self, coil, coils, regularization):
self.regularization = regularization
self.B_ext = BiotSavart(coils).set_points(self.coil.curve.gamma()).B()
self.J_jax = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
force_opt_pure(gamma, gammadash, gammadashdash, current, phi, B_ext, regularization)
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
force_opt_pure(gamma, gammadash, gammadashdash,
current, phi, B_ext, regularization)
)

self.thisgrad0 = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, phi, B_ext)
grad(self.J_jax, argnums=0)(gamma, gammadash,
gammadashdash, current, phi, B_ext)
)
self.thisgrad1 = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, phi, B_ext)
grad(self.J_jax, argnums=1)(gamma, gammadash,
gammadashdash, current, phi, B_ext)
)
self.thisgrad2 = jit(
lambda gamma, gammadash, gammadashdash, current, phi, B_ext:
grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, phi, B_ext)
grad(self.J_jax, argnums=2)(gamma, gammadash,
gammadashdash, current, phi, B_ext)
)

super().__init__(depends_on=[coil])
# The version in the next line is needed
#eventually to get derivatives with respect to the other source coils:
#super().__init__(depends_on=[coil] + coils)
# eventually to get derivatives with respect to the other source coils:
# super().__init__(depends_on=[coil] + coils)

def J(self):
gamma = self.coil.curve.gamma()
Expand Down

0 comments on commit 6892915

Please sign in to comment.