diff --git a/examples/2_Intermediate/stage_two_optimization_jax.py b/examples/2_Intermediate/stage_two_optimization_jax.py index 06936c3a7..e1f13d8a0 100644 --- a/examples/2_Intermediate/stage_two_optimization_jax.py +++ b/examples/2_Intermediate/stage_two_optimization_jax.py @@ -38,10 +38,10 @@ ncoils = 4 # Major radius for the initial circular coils: -R0 = 1.0 +R0 = 1.5 # Minor radius for the initial circular coils: -R1 = 0.5 +R1 = 1.0 # Number of Fourier modes describing each Cartesian component of each coil: order = 5 @@ -49,7 +49,7 @@ # Weight on the curve lengths in the objective function. We use the `Weight` # class here to later easily adjust the scalar value and rerun the optimization # without having to rebuild the objective. -LENGTH_WEIGHT = Weight(1e-8) +LENGTH_WEIGHT = Weight(1e-10) # Threshold and weight for the coil-to-coil distance penalty in the objective function: CC_THRESHOLD = 0.1 @@ -68,7 +68,7 @@ MSC_WEIGHT = 1e-6 # Weight for the Coil Coil forces term -FORCES_WEIGHT = 1e-14 # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCES_WEIGHT = 1e-12 # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons # And this term weights the NetForce^2 ~ 10^10-10^12 # Number of iterations to perform: diff --git a/src/simsopt/field/biotsavart.py b/src/simsopt/field/biotsavart.py index cca7cbff0..dc5663788 100644 --- a/src/simsopt/field/biotsavart.py +++ b/src/simsopt/field/biotsavart.py @@ -554,13 +554,36 @@ def net_forces_squared_pure(self, curve_dofs, currents): # Minimize the sum of the net force magnitudes ^2 of every coil return jnp.sum(jnp.linalg.norm(self.coil_coil_forces_pure(curve_dofs, currents), axis=-1) ** 2) -class CoilCoilNetForces(Optimizable): - r""" - CurveLength is a class that computes the length of a curve, i.e. + def coil_coil_torques_pure(self, curve_dofs, currents): + """ + T = mu0 / 4pi * \oint_{C_1} \oint_{C_2} (dl_2 x r_12)(dl1 * r_12)/|r_12|^3 + """ + eps = 1e-20 # small number to avoid blow up in the denominator when i = j + gammas = self.get_gammas(curve_dofs) + gammadashs = self.get_gammadashs(curve_dofs) + Ii_Ij = (currents[None, :] * currents[:, None])[:, :, None] + # gamma and gammadash are shape (ncoils, nquadpoints, 3) + r_ij = gammas[None, :, None, :, :] - gammas[:, None, :, None, :] # Note, do not use the i = j indices + gammadash_prod = jnp.sum(gammadashs[None, :, None, :, :] * gammadashs[:, None, :, None, :], axis=-1) + rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 - .. math:: - J = \int_{\text{curve}}~dl. + # Double sum over each of the closed curves + T = Ii_Ij * jnp.sum(jnp.sum((gammadash_prod / rij_norm3)[:, :, :, :, None] * r_ij, axis=3), axis=2) + net_forces = -jnp.sum(T, axis=1) * 1e-7 / jnp.shape(gammas)[1] ** 2 + return net_torques + + def net_torques_squared_pure(self, curve_dofs, currents): + # Minimize the sum of the net force magnitudes ^2 of every coil + return jnp.sum(jnp.linalg.norm(self.coil_coil_torques_pure(curve_dofs, currents), axis=-1) ** 2) + + def coil_coil_torques(self): + r""" + This function implements the curvature, :math:`\kappa(\varphi)`. + """ + return self.coil_coil_torques_pure(self.get_curve_dofs(), self.get_currents()) +class CoilCoilNetForces(Optimizable): + r""" """ def __init__(self, biot_savart): self.biot_savart = biot_savart @@ -594,4 +617,41 @@ def dJ(self): current_derivs = [coils[i].current.vjp(np.array([dF_by_dcurrents[i]])) for i in range(len(coils))] return sum(curve_derivs + current_derivs) + return_fn_map = {'J': J, 'dJ': dJ} + +class CoilCoilNetTorques(Optimizable): + r""" + """ + def __init__(self, biot_savart): + self.biot_savart = biot_savart + self.grad_curvedofs = jit(jacfwd(self.biot_savart.net_torques_squared_pure, argnums=0)) + self.grad_currents = jit(jacfwd(self.biot_savart.net_torques_squared_pure, argnums=1)) + super().__init__(depends_on=[biot_savart]) + + def J(self): + """ + This returns the value of the quantity. + """ + return self.biot_savart.net_torques_squared_pure( + self.biot_savart.get_curve_dofs(), + self.biot_savart.get_currents() + ) + + @derivative_dec + def dJ(self): + """ + This returns the derivative of the quantity with respect to the curve dofs. + """ + dT_by_dcurvedofs = self.grad_curvedofs( + self.biot_savart.get_curve_dofs(), + self.biot_savart.get_currents()) + dT_by_dcurrents = self.grad_currents( + self.biot_savart.get_curve_dofs(), + self.biot_savart.get_currents()) + + coils = self.biot_savart._coils + curve_derivs = [Derivative({coils[i].curve: dT_by_dcurvedofs[i]}) for i in range(len(dT_by_dcurvedofs))] + current_derivs = [coils[i].current.vjp(np.array([dT_by_dcurrents[i]])) for i in range(len(coils))] + return sum(curve_derivs + current_derivs) + return_fn_map = {'J': J, 'dJ': dJ} \ No newline at end of file