diff --git a/src/simsopt/geo/strain_optimization.py b/src/simsopt/geo/strain_optimization.py index e3499bc1c..7957e52fc 100644 --- a/src/simsopt/geo/strain_optimization.py +++ b/src/simsopt/geo/strain_optimization.py @@ -8,7 +8,8 @@ from simsopt._core.derivative import derivative_dec from simsopt.geo.curveobjectives import Lp_torsion_pure -__all__ = ['LPBinormalCurvatureStrainPenalty', 'LPTorsionalStrainPenalty', 'StrainOpt'] +__all__ = ['LPBinormalCurvatureStrainPenalty', + 'LPTorsionalStrainPenalty', 'StrainOpt'] class LPBinormalCurvatureStrainPenalty(Optimizable): @@ -26,14 +27,17 @@ def __init__(self, framedcurve, width=1e-3, p=2, threshold=0): self.framedcurve = framedcurve self.strain = StrainOpt(framedcurve, width) self.width = width - self.p = p - self.threshold = threshold - self.J_jax = jit(lambda binorm, gammadash: Lp_torsion_pure(binorm, gammadash, p, threshold)) - self.grad0 = jit(lambda binorm, gammadash: grad(self.J_jax, argnums=0)(binorm, gammadash)) - self.grad1 = jit(lambda binorm, gammadash: grad(self.J_jax, argnums=1)(binorm, gammadash)) + self.p = p + self.threshold = threshold + self.J_jax = jit(lambda binorm, gammadash: Lp_torsion_pure( + binorm, gammadash, p, threshold)) + self.grad0 = jit(lambda binorm, gammadash: grad( + self.J_jax, argnums=0)(binorm, gammadash)) + self.grad1 = jit(lambda binorm, gammadash: grad( + self.J_jax, argnums=1)(binorm, gammadash)) super().__init__(depends_on=[framedcurve]) - def J(self): + def J(self): """ This returns the value of the quantity. """ @@ -44,9 +48,12 @@ def dJ(self): """ This returns the derivative of the quantity with respect to the curve and rotation dofs. """ - grad0 = self.grad0(self.strain.binormal_curvature_strain(), self.framedcurve.curve.gammadash()) - grad1 = self.grad1(self.strain.binormal_curvature_strain(), self.framedcurve.curve.gammadash()) - vjp0 = self.strain.binormstrain_vjp(self.framedcurve.frame_binormal_curvature(), self.width, grad0) + grad0 = self.grad0(self.strain.binormal_curvature_strain(), + self.framedcurve.curve.gammadash()) + grad1 = self.grad1(self.strain.binormal_curvature_strain(), + self.framedcurve.curve.gammadash()) + vjp0 = self.strain.binormstrain_vjp( + self.framedcurve.frame_binormal_curvature(), self.width, grad0) return self.framedcurve.dframe_binormal_curvature_by_dcoeff_vjp(vjp0) \ + self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1) @@ -68,9 +75,10 @@ def __init__(self, framedcurve, width=1e-3, p=2, threshold=0): self.framedcurve = framedcurve self.strain = StrainOpt(framedcurve, width) self.width = width - self.p = p - self.threshold = threshold - self.J_jax = jit(lambda torsion, gammadash: Lp_torsion_pure(torsion, gammadash, p, threshold)) + self.p = p + self.threshold = threshold + self.J_jax = jit(lambda torsion, gammadash: Lp_torsion_pure( + torsion, gammadash, p, threshold)) self.grad0 = jit(lambda torsion, gammadash: grad( self.J_jax, argnums=0)(torsion, gammadash)) self.grad1 = jit(lambda torsion, gammadash: grad( @@ -88,9 +96,12 @@ def dJ(self): """ This returns the derivative of the quantity with respect to the curve and rotation dofs. """ - grad0 = self.grad0(self.strain.torsional_strain(), self.framedcurve.curve.gammadash()) - grad1 = self.grad1(self.strain.torsional_strain(), self.framedcurve.curve.gammadash()) - vjp0 = self.strain.torstrain_vjp(self.framedcurve.frame_torsion(), self.width, grad0) + grad0 = self.grad0(self.strain.torsional_strain(), + self.framedcurve.curve.gammadash()) + grad1 = self.grad1(self.strain.torsional_strain(), + self.framedcurve.curve.gammadash()) + vjp0 = self.strain.torstrain_vjp( + self.framedcurve.frame_torsion(), self.width, grad0) return self.framedcurve.dframe_torsion_by_dcoeff_vjp(vjp0) \ + self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1) @@ -100,7 +111,7 @@ def dJ(self): class StrainOpt(Optimizable): r""" This class evaluates the torsional and binormal curvature strains on HTS, based on - a filamentary model of the coil and the orientation of the HTX tape. + a filamentary model of the coil and the orientation of the HTS tape. As defined in,