diff --git a/src/simsopt/geo/curveobjectives.py b/src/simsopt/geo/curveobjectives.py index 116682f73..326ec1ba2 100644 --- a/src/simsopt/geo/curveobjectives.py +++ b/src/simsopt/geo/curveobjectives.py @@ -260,11 +260,15 @@ def ws_distance_pure(gammac, lc, gammas, ns, maximum_distance): This function is used in a Python+Jax implementation of the curve-surface distance formula. """ + nss = gammas.size + ncc = gammac.size dists = jnp.sqrt(jnp.sum( (gammac[:, None, :] - gammas[None, :, :])**2, axis=2)) integralweight = jnp.linalg.norm(lc, axis=1)[:, None] \ * jnp.linalg.norm(ns, axis=1)[None, :] - return jnp.mean(integralweight * jnp.maximum(dists-maximum_distance, 0)**2) + return jnp.mean( + integralweight * jnp.maximum(dists**2-maximum_distance**2, 0)**2 + ) / (nss*ncc)**2 class WindingSurface(Optimizable): r"""Used to constrain coils to remain on a surface