From 2d2180ab2807f56cdc7fc00620f9fd67f282d262 Mon Sep 17 00:00:00 2001
From: Elizabeth
Date: Tue, 7 May 2024 13:33:57 -0400
Subject: [PATCH] Added CurveCurve distance objective.

---
 src/simsopt/geo/curveobjectives.py | 91 ++++++++++++++++++++++++++++++-
 1 file changed, 90 insertions(+), 1 deletion(-)

diff --git a/src/simsopt/geo/curveobjectives.py b/src/simsopt/geo/curveobjectives.py
index 694814787..7b5e52009 100644
--- a/src/simsopt/geo/curveobjectives.py
+++ b/src/simsopt/geo/curveobjectives.py
@@ -3,6 +3,7 @@
 import numpy as np
 from jax import grad, vjp, lax
 import jax.numpy as jnp
+import jax
 
 from .jit import jit
 from .._core.optimizable import Optimizable
@@ -12,7 +13,8 @@
 
 __all__ = ['CurveLength', 'LpCurveCurvature', 'LpCurveTorsion',
            'CurveCurveDistance', 'CurveSurfaceDistance', 'ArclengthVariation',
-           'MeanSquaredCurvature', 'LinkingNumber', 'FramedCurveTwist']
+           'MeanSquaredCurvature', 'LinkingNumber', 'FramedCurveTwist',
+           'MinCurveCurveDistance']
 
 
 @jit
@@ -108,6 +110,11 @@ def Lp_torsion_pure(torsion, gammadash, p, threshold):
     This function is used in a Python+Jax implementation of the formula for the torsion penalty term.
     """
     arc_length = jnp.linalg.norm(gammadash, axis=1)
+    # jax.debug.print("arc_length: {arc_length}", arc_length=arc_length)
+    # jax.debug.print('p: {p}', p=p)
+    # jax.debug.print('threshold: {threshold}', threshold=threshold)
+    # jax.debug.print('binorm: {binorm}', binorm=torsion)
+    # jax.debug.print('integrand: {integrand}', integrand=jnp.maximum(jnp.abs(torsion)-threshold, 0)**p)
     return (1./p)*jnp.mean(jnp.maximum(jnp.abs(torsion)-threshold, 0)**p * arc_length)
 
 
@@ -684,3 +691,85 @@ def dJ(self):
         grad += self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1)
 
         return grad
+
+def max_distance_pure(g1, g2, dmax, p):
+    """
+    Returns 0 if every point of g1 is within a distance dmax of at least one point of g2.
+    Otherwise, returns the mean over the points g1_i of (min_j|g1_i - g2_j| - dmax)**2, where
+    only points farther than dmax contribute. The minimum is approximated with a p-norm, p < -1.
+    """
+    dists = jnp.sqrt(jnp.sum((g1[:, None, :] - g2[None, :, :])**2, axis=2))
+
+    # Approximate the minimum of dists along axis=1 with a p-norm. mindists is an array of length g1.shape[0], with mindists[i] ~= min_j(|g1[i]-g2[j]|)
+    mindists = jnp.sum(dists**p, axis=1)**(1./p)
+
+    # Wherever mindists[i] exceeds dmax, (mindists[i]-dmax)**2 is added to the output.
+    # The sum is normalized by the number of quadrature points along the first curve g1.
+    return jnp.sum(jnp.maximum(mindists-dmax, 0)**2) / g1.shape[0]
+
+
+class MinCurveCurveDistance(Optimizable):
+    """
+    This class can be used to constrain a curve (curve1) to remain within a
+    distance ``maximum_distance`` of another curve (curve2).
+ """ + def __init__(self, curve1, curve2, maximum_distance, p=-10): + self.curve1 = curve1 + self.curve2 = curve2 + self.maximum_distance = maximum_distance + self.p = p + self.J_jax = lambda g1, g2: max_distance_pure(g1, g2, self.maximum_distance, p) + self.this_grad_0 = jit(lambda g1, g2: grad(self.J_jax, argnums=0)(g1, g2)) + self.this_grad_1 = jit(lambda g1, g2: grad(self.J_jax, argnums=1)(g1, g2)) + + Optimizable.__init__(self, depends_on=[curve1, curve2]) + + def max_distance(self): + """ + returns the max distance between curve1 and curve2 + """ + g1 = self.curve1.gamma() + g2 = self.curve2.gamma() + dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2)) + mindists = jnp.min(dists,axis=1) + + return jnp.max(mindists) + + def min_dists(self): + """ + returns the an array of the minimum distance between curve1 and curve2 + """ + g1 = self.curve1.gamma() + g2 = self.curve2.gamma() + dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2)) + print(np.shape(dists)) + mindists = jnp.min(dists,axis=1) + + return mindists + + def min_dists_p(self): + """ + returns the an array of the minimum distance between curve1 and curve2 (approximated w/ p norm) + """ + p = self.p + g1 = self.curve1.gamma() + g2 = self.curve2.gamma() + dists = jnp.sqrt(jnp.sum( (g1[:, None, :] - g2[None, :, :])**2, axis=2)) + mindists = jnp.sum(dists**p, axis=1)**(1./p) + + return mindists + + def J(self): + g1 = self.curve1.gamma() + g2 = self.curve2.gamma() + + return self.J_jax( g1, g2 ) + + @derivative_dec + def dJ(self): + g1 = self.curve1.gamma() + g2 = self.curve2.gamma() + + grad0 = self.this_grad_0(g1, g2) + grad1 = self.this_grad_1(g1, g2) + + return self.curve1.dgamma_by_dcoeff_vjp( grad0 ) + self.curve2.dgamma_by_dcoeff_vjp( grad1 )