Skip to content

Commit

Permalink
Added CurveCurve distance objective.
Browse files Browse the repository at this point in the history
  • Loading branch information
ejpaul committed May 7, 2024
1 parent f6cb59b commit 2d2180a
Showing 1 changed file with 91 additions and 1 deletion.
92 changes: 91 additions & 1 deletion src/simsopt/geo/curveobjectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from jax import grad, vjp, lax
import jax.numpy as jnp
import jax

from .jit import jit
from .._core.optimizable import Optimizable
Expand All @@ -12,7 +13,8 @@

__all__ = ['CurveLength', 'LpCurveCurvature', 'LpCurveTorsion',
'CurveCurveDistance', 'CurveSurfaceDistance', 'ArclengthVariation',
'MeanSquaredCurvature', 'LinkingNumber', 'FramedCurveTwist']
'MeanSquaredCurvature', 'LinkingNumber', 'FramedCurveTwist',
'MinCurveCurveDistance']


@jit
Expand Down Expand Up @@ -108,6 +110,11 @@ def Lp_torsion_pure(torsion, gammadash, p, threshold):
This function is used in a Python+Jax implementation of the formula for the torsion penalty term.
"""
arc_length = jnp.linalg.norm(gammadash, axis=1)
# jax.debug.print("arc_length: {arc_length}",arc_length=arc_length)
# jax.debug.print('p: {p}',p=p)
# jax.debug.print('threshold: {threshold}',threshold=threshold)
# jax.debug.print('binorm: {binorm}',binorm=torsion)
# jax.debug.print('integrand: {integrand}',integrand=jnp.maximum(jnp.abs(torsion)-threshold, 0)**p)
return (1./p)*jnp.mean(jnp.maximum(jnp.abs(torsion)-threshold, 0)**p * arc_length)


Expand Down Expand Up @@ -684,3 +691,86 @@ def dJ(self):
grad += self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1)

return grad

def max_distance_pure(g1, g2, dmax, p):
"""
This returns 0 if all points of g1 have at least one point of g2 at a distance smaller or equal to dmax
Otherwise, returns the sum of |g2-g1_i|-dmax where only points further than dmax are considered.
The minimum distance between a point g1_i and g2 is obtained using the p-norm, with p < -1.
"""
dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2))

# Estimate min of dists using p-norm. The minimum is taken along the axis=1. mindists is then an array of length g1.size, where mindists[i]=min_j(|g1[i]-g2[j]|)
mindists = jnp.sum(dists**p, axis=1)**(1./p)

# We now evaluate if any of mindists is larger than dmax. If yes, we add the value of (mindists[i]-dmax)**2 to the output.
# We normalize by the number of quadrature points along the first curve g1.
return jnp.sum(jnp.maximum(mindists-dmax, 0)**2) / g1.shape[0]


class MinCurveCurveDistance(Optimizable):
"""
This class can be used to constrain a curve to remain close
to another curve.
"""
def __init__(self, curve1, curve2, maximum_distance, p=-10):
self.curve1 = curve1
self.curve2 = curve2
self.maximum_distance = maximum_distance
self.p = p
self.J_jax = lambda g1, g2: max_distance_pure(g1, g2, self.maximum_distance, p)
self.this_grad_0 = jit(lambda g1, g2: grad(self.J_jax, argnums=0)(g1, g2))
self.this_grad_1 = jit(lambda g1, g2: grad(self.J_jax, argnums=1)(g1, g2))

Optimizable.__init__(self, depends_on=[curve1, curve2])

def max_distance(self):
"""
returns the max distance between curve1 and curve2
"""
g1 = self.curve1.gamma()
g2 = self.curve2.gamma()
dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2))
mindists = jnp.min(dists,axis=1)

return jnp.max(mindists)

def min_dists(self):
"""
returns the an array of the minimum distance between curve1 and curve2
"""
g1 = self.curve1.gamma()
g2 = self.curve2.gamma()
dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2))
print(np.shape(dists))
mindists = jnp.min(dists,axis=1)

return mindists

def min_dists_p(self):
"""
returns the an array of the minimum distance between curve1 and curve2 (approximated w/ p norm)
"""
p = self.p
g1 = self.curve1.gamma()
g2 = self.curve2.gamma()
dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2))
mindists = jnp.sum(dists**p, axis=1)**(1./p)

return mindists

def J(self):
g1 = self.curve1.gamma()
g2 = self.curve2.gamma()

return self.J_jax( g1, g2 )

@derivative_dec
def dJ(self):
g1 = self.curve1.gamma()
g2 = self.curve2.gamma()

grad0 = self.this_grad_0(g1, g2)
grad1 = self.this_grad_1(g1, g2)

return self.curve1.dgamma_by_dcoeff_vjp( grad0 ) + self.curve2.dgamma_by_dcoeff_vjp( grad1 )

0 comments on commit 2d2180a

Please sign in to comment.