From 81a6d35714a0baf7c72397d86eea442df2412735 Mon Sep 17 00:00:00 2001 From: Matt Landreman Date: Sat, 2 Mar 2024 15:35:25 -0500 Subject: [PATCH] Refactor LinkingNumber --- src/simsopt/geo/curveobjectives.py | 30 ++++++++++-------------------- src/simsoptpp/python_distance.cpp | 10 +++++----- 2 files changed, 15 insertions(+), 25 deletions(-) diff --git a/src/simsopt/geo/curveobjectives.py b/src/simsopt/geo/curveobjectives.py index bb81ec318..4bed06736 100644 --- a/src/simsopt/geo/curveobjectives.py +++ b/src/simsopt/geo/curveobjectives.py @@ -501,26 +501,16 @@ def __init__(self, curves): """ def J(self): - ncoils = len(self.curves) - linkNum = np.zeros([ncoils + 1, ncoils + 1]) - i = 0 - for c1 in self.curves[:(ncoils + 1)]: - j = 0 - i = i + 1 - for c2 in self.curves[:(ncoils + 1)]: - j = j + 1 - if i < j: - R1 = c1.gamma() - R2 = c2.gamma() - dS = c1.quadpoints[1] - c1.quadpoints[0] - dT = c2.quadpoints[1] - c1.quadpoints[0] - dR1 = c1.gammadash() - dR2 = c2.gammadash() - - integrals = sopp.linkNumber(R1, R2, dR1, dR2) * dS * dT - linkNum[i-1][j-1] = 1/(4*np.pi) * (integrals) - linkNumSum = sum(sum(abs(linkNum))) - return round(linkNumSum) + ncurves = len(self.curves) + gammas = [c.gamma() for c in self.curves] + dphis = [c.quadpoints[1] - c.quadpoints[0] for c in self.curves] + gammadashs = [c.gammadash() for c in self.curves] + link_num = 0.0 + for i in range(1, ncurves): + for j in range(i): + link_num += sopp.linkNumber(gammas[i], gammas[j], gammadashs[i], gammadashs[j]) * dphis[i] * dphis[j] + + return link_num @derivative_dec def dJ(self): diff --git a/src/simsoptpp/python_distance.cpp b/src/simsoptpp/python_distance.cpp index 36fba8eb1..0eac7a04e 100644 --- a/src/simsoptpp/python_distance.cpp +++ b/src/simsoptpp/python_distance.cpp @@ -184,18 +184,18 @@ void init_distance(py::module_ &m){ const double *curve2dash_ptr = curve2dash.data(); double difference[3] = { 0 }; double total = 0; + double dr, det; for(int i=0; i < linknphi1; i++){ for(int j=0; j < linknphi2; j++){ difference[0] = (curve1_ptr[3*i+0] - curve2_ptr[3*j+0]); difference[1] = (curve1_ptr[3*i+1] - curve2_ptr[3*j+1]); difference[2] = (curve1_ptr[3*i+2] - curve2_ptr[3*j+2]); - double denom = pow(std::sqrt(difference[0]*difference[0] + difference[1]*difference[1] + difference[2]*difference[2]), 3); - double det = curve1dash_ptr[3*i+0]*(curve2dash_ptr[3*j+1]*difference[2] - curve2dash_ptr[3*j+2]*difference[1]) - curve1dash_ptr[3*i+1]*(curve2dash_ptr[3*j+0]*difference[2] - curve2dash_ptr[3*j+2]*difference[0]) + curve1dash_ptr[3*i+2]*(curve2dash_ptr[3*j+0]*difference[1] - curve2dash_ptr[3*j+1]*difference[0]); - double r = det/denom; - total += r; + dr = std::sqrt(difference[0]*difference[0] + difference[1]*difference[1] + difference[2]*difference[2]); + det = curve1dash_ptr[3*i+0]*(curve2dash_ptr[3*j+1]*difference[2] - curve2dash_ptr[3*j+2]*difference[1]) - curve1dash_ptr[3*i+1]*(curve2dash_ptr[3*j+0]*difference[2] - curve2dash_ptr[3*j+2]*difference[0]) + curve1dash_ptr[3*i+2]*(curve2dash_ptr[3*j+0]*difference[1] - curve2dash_ptr[3*j+1]*difference[0]); + total += det / (dr * dr * dr); } } - return total; + return std::round(std::abs(total) / (4 * M_PI)); }); }