Skip to content

Commit

Permalink
more minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
phuslage committed Sep 1, 2023
1 parent 645add2 commit b1c9c59
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions src/simsopt/geo/strain_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from simsopt._core.derivative import derivative_dec
from simsopt.geo.curveobjectives import Lp_torsion_pure

__all__ = ['LPBinormalCurvatureStrainPenalty', 'LPTorsionalStrainPenalty', 'StrainOpt']
__all__ = ['LPBinormalCurvatureStrainPenalty',
'LPTorsionalStrainPenalty', 'StrainOpt']


class LPBinormalCurvatureStrainPenalty(Optimizable):
Expand All @@ -26,14 +27,17 @@ def __init__(self, framedcurve, width=1e-3, p=2, threshold=0):
self.framedcurve = framedcurve
self.strain = StrainOpt(framedcurve, width)
self.width = width
self.p = p
self.threshold = threshold
self.J_jax = jit(lambda binorm, gammadash: Lp_torsion_pure(binorm, gammadash, p, threshold))
self.grad0 = jit(lambda binorm, gammadash: grad(self.J_jax, argnums=0)(binorm, gammadash))
self.grad1 = jit(lambda binorm, gammadash: grad(self.J_jax, argnums=1)(binorm, gammadash))
self.p = p
self.threshold = threshold
self.J_jax = jit(lambda binorm, gammadash: Lp_torsion_pure(
binorm, gammadash, p, threshold))
self.grad0 = jit(lambda binorm, gammadash: grad(
self.J_jax, argnums=0)(binorm, gammadash))
self.grad1 = jit(lambda binorm, gammadash: grad(
self.J_jax, argnums=1)(binorm, gammadash))
super().__init__(depends_on=[framedcurve])

def J(self):
def J(self):
"""
This returns the value of the quantity.
"""
Expand All @@ -44,9 +48,12 @@ def dJ(self):
"""
This returns the derivative of the quantity with respect to the curve and rotation dofs.
"""
grad0 = self.grad0(self.strain.binormal_curvature_strain(), self.framedcurve.curve.gammadash())
grad1 = self.grad1(self.strain.binormal_curvature_strain(), self.framedcurve.curve.gammadash())
vjp0 = self.strain.binormstrain_vjp(self.framedcurve.frame_binormal_curvature(), self.width, grad0)
grad0 = self.grad0(self.strain.binormal_curvature_strain(),
self.framedcurve.curve.gammadash())
grad1 = self.grad1(self.strain.binormal_curvature_strain(),
self.framedcurve.curve.gammadash())
vjp0 = self.strain.binormstrain_vjp(
self.framedcurve.frame_binormal_curvature(), self.width, grad0)
return self.framedcurve.dframe_binormal_curvature_by_dcoeff_vjp(vjp0) \
+ self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1)

Expand All @@ -68,9 +75,10 @@ def __init__(self, framedcurve, width=1e-3, p=2, threshold=0):
self.framedcurve = framedcurve
self.strain = StrainOpt(framedcurve, width)
self.width = width
self.p = p
self.threshold = threshold
self.J_jax = jit(lambda torsion, gammadash: Lp_torsion_pure(torsion, gammadash, p, threshold))
self.p = p
self.threshold = threshold
self.J_jax = jit(lambda torsion, gammadash: Lp_torsion_pure(
torsion, gammadash, p, threshold))
self.grad0 = jit(lambda torsion, gammadash: grad(
self.J_jax, argnums=0)(torsion, gammadash))
self.grad1 = jit(lambda torsion, gammadash: grad(
Expand All @@ -88,9 +96,12 @@ def dJ(self):
"""
This returns the derivative of the quantity with respect to the curve and rotation dofs.
"""
grad0 = self.grad0(self.strain.torsional_strain(), self.framedcurve.curve.gammadash())
grad1 = self.grad1(self.strain.torsional_strain(), self.framedcurve.curve.gammadash())
vjp0 = self.strain.torstrain_vjp(self.framedcurve.frame_torsion(), self.width, grad0)
grad0 = self.grad0(self.strain.torsional_strain(),
self.framedcurve.curve.gammadash())
grad1 = self.grad1(self.strain.torsional_strain(),
self.framedcurve.curve.gammadash())
vjp0 = self.strain.torstrain_vjp(
self.framedcurve.frame_torsion(), self.width, grad0)
return self.framedcurve.dframe_torsion_by_dcoeff_vjp(vjp0) \
+ self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1)

Expand All @@ -100,7 +111,7 @@ def dJ(self):
class StrainOpt(Optimizable):
r"""
This class evaluates the torsional and binormal curvature strains on HTS, based on
a filamentary model of the coil and the orientation of the HTX tape.
a filamentary model of the coil and the orientation of the HTS tape.
As defined in,
Expand Down

0 comments on commit b1c9c59

Please sign in to comment.