Skip to content

Commit

Permalink
Started attempt at torques.
Browse files Browse the repository at this point in the history
  • Loading branch information
akaptano committed Sep 27, 2024
1 parent af2d133 commit 75c082b
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
8 changes: 4 additions & 4 deletions examples/2_Intermediate/stage_two_optimization_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@
ncoils = 4

# Major radius for the initial circular coils:
R0 = 1.0
R0 = 1.5

# Minor radius for the initial circular coils:
R1 = 0.5
R1 = 1.0

# Number of Fourier modes describing each Cartesian component of each coil:
order = 5

# Weight on the curve lengths in the objective function. We use the `Weight`
# class here to later easily adjust the scalar value and rerun the optimization
# without having to rebuild the objective.
LENGTH_WEIGHT = Weight(1e-8)
LENGTH_WEIGHT = Weight(1e-10)

# Threshold and weight for the coil-to-coil distance penalty in the objective function:
CC_THRESHOLD = 0.1
Expand All @@ -68,7 +68,7 @@
MSC_WEIGHT = 1e-6

# Weight for the Coil Coil forces term
FORCES_WEIGHT = 1e-14 # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
FORCES_WEIGHT = 1e-12 # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
# And this term weights the NetForce^2 ~ 10^10-10^12

# Number of iterations to perform:
Expand Down
70 changes: 65 additions & 5 deletions src/simsopt/field/biotsavart.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,13 +554,36 @@ def net_forces_squared_pure(self, curve_dofs, currents):
# Minimize the sum of the net force magnitudes ^2 of every coil
return jnp.sum(jnp.linalg.norm(self.coil_coil_forces_pure(curve_dofs, currents), axis=-1) ** 2)

class CoilCoilNetForces(Optimizable):
r"""
CurveLength is a class that computes the length of a curve, i.e.
def coil_coil_torques_pure(self, curve_dofs, currents):
"""
T = mu0 / 4pi * \oint_{C_1} \oint_{C_2} (dl_2 x r_12)(dl1 * r_12)/|r_12|^3
"""
eps = 1e-20 # small number to avoid blow up in the denominator when i = j
gammas = self.get_gammas(curve_dofs)
gammadashs = self.get_gammadashs(curve_dofs)
Ii_Ij = (currents[None, :] * currents[:, None])[:, :, None]
# gamma and gammadash are shape (ncoils, nquadpoints, 3)
r_ij = gammas[None, :, None, :, :] - gammas[:, None, :, None, :] # Note, do not use the i = j indices
gammadash_prod = jnp.sum(gammadashs[None, :, None, :, :] * gammadashs[:, None, :, None, :], axis=-1)
rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3

.. math::
J = \int_{\text{curve}}~dl.
# Double sum over each of the closed curves
T = Ii_Ij * jnp.sum(jnp.sum((gammadash_prod / rij_norm3)[:, :, :, :, None] * r_ij, axis=3), axis=2)
net_forces = -jnp.sum(T, axis=1) * 1e-7 / jnp.shape(gammas)[1] ** 2
return net_torques

def net_torques_squared_pure(self, curve_dofs, currents):
# Minimize the sum of the net force magnitudes ^2 of every coil
return jnp.sum(jnp.linalg.norm(self.coil_coil_torques_pure(curve_dofs, currents), axis=-1) ** 2)

def coil_coil_torques(self):
r"""
This function implements the curvature, :math:`\kappa(\varphi)`.
"""
return self.coil_coil_torques_pure(self.get_curve_dofs(), self.get_currents())

class CoilCoilNetForces(Optimizable):
r"""
"""
def __init__(self, biot_savart):
self.biot_savart = biot_savart
Expand Down Expand Up @@ -594,4 +617,41 @@ def dJ(self):
current_derivs = [coils[i].current.vjp(np.array([dF_by_dcurrents[i]])) for i in range(len(coils))]
return sum(curve_derivs + current_derivs)

return_fn_map = {'J': J, 'dJ': dJ}

class CoilCoilNetTorques(Optimizable):
r"""
"""
def __init__(self, biot_savart):
self.biot_savart = biot_savart
self.grad_curvedofs = jit(jacfwd(self.biot_savart.net_torques_squared_pure, argnums=0))
self.grad_currents = jit(jacfwd(self.biot_savart.net_torques_squared_pure, argnums=1))
super().__init__(depends_on=[biot_savart])

def J(self):
"""
This returns the value of the quantity.
"""
return self.biot_savart.net_torques_squared_pure(
self.biot_savart.get_curve_dofs(),
self.biot_savart.get_currents()
)

@derivative_dec
def dJ(self):
"""
This returns the derivative of the quantity with respect to the curve dofs.
"""
dT_by_dcurvedofs = self.grad_curvedofs(
self.biot_savart.get_curve_dofs(),
self.biot_savart.get_currents())
dT_by_dcurrents = self.grad_currents(
self.biot_savart.get_curve_dofs(),
self.biot_savart.get_currents())

coils = self.biot_savart._coils
curve_derivs = [Derivative({coils[i].curve: dT_by_dcurvedofs[i]}) for i in range(len(dT_by_dcurvedofs))]
current_derivs = [coils[i].current.vjp(np.array([dT_by_dcurrents[i]])) for i in range(len(coils))]
return sum(curve_derivs + current_derivs)

return_fn_map = {'J': J, 'dJ': dJ}

0 comments on commit 75c082b

Please sign in to comment.