diff --git a/examples/3_Advanced/QH_reactorscale_DA.py b/examples/3_Advanced/QH_reactorscale_DA.py index fef898a02..396cc388d 100644 --- a/examples/3_Advanced/QH_reactorscale_DA.py +++ b/examples/3_Advanced/QH_reactorscale_DA.py @@ -142,7 +142,7 @@ def initialize_coils_QH(TEST_DIR, s): aa = 0.05 bb = 0.05 -Nx = 7 +Nx = 6 Ny = Nx Nz = Nx # Create the initial coils: @@ -264,20 +264,20 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): LENGTH_WEIGHT = Weight(0.001) LENGTH_TARGET = 80 -LINK_WEIGHT = 1e3 +LINK_WEIGHT = 1e4 CC_THRESHOLD = 0.8 CC_WEIGHT = 1e1 CS_THRESHOLD = 1.5 CS_WEIGHT = 1e2 # Weight for the Coil Coil forces term -# FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# TORQUE_WEIGHT2 = Weight(4e-27) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons # Directory for output OUT_DIR = ("./QH_reactorscale_TForder{:d}_n{:d}_p{:.2e}_c{:.2e}_lw{:.2e}_lt{:.2e}_lkw{:.2e}" + \ "_cct{:.2e}_ccw{:.2e}_cst{:.2e}_csw{:.2e}_fw{:.2e}_fww{:2e}_tw{:.2e}_tww{:2e}/").format( @@ -309,8 +309,8 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): NetTorques=coil_net_torques(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns) ) # Force and Torque calculations spawn a bunch of spurious BiotSavart child objects -- erase them! -# for c in (coils + coils_TF): -# c._children = set() +for c in (coils + coils_TF): + c._children = set() pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData) @@ -348,30 +348,17 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # interlink. linkNum = LinkingNumber(curves + curves_TF, downsample=4) -##### Note need coils_TF + coils below!!!!!!! -# Jforce2 = sum([LpCurveForce(c, coils + coils_TF, -# regularization=regularization_rect(base_a_list[i], base_b_list[i]), -# p=2, threshold=1e8) for i, c in enumerate(base_coils + base_coils_TF)]) -# Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) -# Jforce = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) - -regularization_list = np.ones(len(coils)) * regularization_rect(aa, bb) -regularization_list2 = np.ones(len(coils_TF)) * regularization_rect(a, b) -# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] -# Jforce = MixedSquaredMeanForce(coils, coils_TF) -Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100, downsample=1) for i, c in enumerate(base_coils + base_coils_TF)]) -Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) -Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) -# Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) - - -Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils_TF)]) - -# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) - +# Currently, all force terms involve all the coils +all_coils = coils + coils_TF +all_base_coils = base_coils + base_coils_TF +Jforce = sum([LpCurveForce(c, all_coils, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100, downsample=1 + ) for i, c in enumerate(all_base_coils)]) +Jforce2 = sum([SquaredMeanForce(c, all_coils, downsample=4) for c in all_base_coils]) -# Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] -# Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] +# Errors creep in when downsample = 2 +Jtorque = sum([LpCurveTorque(c, all_coils, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100, downsample=4 + ) for i, c in enumerate(all_base_coils)]) +Jtorque2 = sum([SquaredMeanTorque(c, all_coils, downsample=1) for c in all_base_coils]) JF = Jf \ + CC_WEIGHT * Jccdist \ @@ -424,10 +411,10 @@ def fun(dofs): cc_val = CC_WEIGHT * Jccdist.J() cs_val = CS_WEIGHT * Jcsdist.J() link_val = LINK_WEIGHT * linkNum.J() - forces_val = FORCE_WEIGHT.value * Jforce.J() - # forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() - # torques_val = TORQUE_WEIGHT.value * Jtorque.J() - # torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() + forces_val = Jforce.J() + forces_val2 = Jforce2.J() + torques_val = Jtorque.J() + torques_val2 = Jtorque2.J() BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) ) / np.mean(btot.AbsB()) @@ -439,14 +426,14 @@ def fun(dofs): valuestr += f", ccObj={cc_val:.2e}" valuestr += f", csObj={cs_val:.2e}" valuestr += f", Lk1Obj={link_val:.2e}" - valuestr += f", forceObj={forces_val:.2e}" - # valuestr += f", forceObj2={forces_val2:.2e}" - # valuestr += f", torqueObj={torques_val:.2e}" - # valuestr += f", torqueObj2={torques_val2:.2e}" - outstr += f", F={Jforce.J():.2e}" - # outstr += f", Fnet={Jforce2.J():.2e}" - # outstr += f", T={Jtorque.J():.2e}" - # outstr += f", Tnet={Jtorque2.J():.2e}" + valuestr += f", forceObj={FORCE_WEIGHT.value * forces_val:.2e}" + valuestr += f", forceObj2={FORCE_WEIGHT2.value * forces_val2:.2e}" + valuestr += f", torqueObj={TORQUE_WEIGHT.value * torques_val:.2e}" + valuestr += f", torqueObj2={TORQUE_WEIGHT2.value * torques_val2:.2e}" + outstr += f", F={forces_val:.2e}" + outstr += f", Fnet={forces_val2:.2e}" + outstr += f", T={torques_val:.2e}" + outstr += f", Tnet={torques_val2:.2e}" outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" outstr += f", Link Number = {linkNum.J()}" outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" @@ -456,8 +443,8 @@ def fun(dofs): sio = io.StringIO() sortby = SortKey.CUMULATIVE ps = pstats.Stats(pr, stream=sio).sort_stats(sortby) - ps.print_stats(20) - print(sio.getvalue()) + # ps.print_stats(20) + # print(sio.getvalue()) # for c in (coils + coils_TF): # c._children = set() # exit() diff --git a/examples/3_Advanced/QH_reactorscale_notfixed.py b/examples/3_Advanced/QH_reactorscale_notfixed.py index cb4e01450..6a9d19b83 100644 --- a/examples/3_Advanced/QH_reactorscale_notfixed.py +++ b/examples/3_Advanced/QH_reactorscale_notfixed.py @@ -332,13 +332,13 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # coil-coil and coil-plasma distances should be between all coils ### Jcc below removed the dipoles! -Jccdist = CurveCurveDistance(curves_TF, CC_THRESHOLD, num_basecurves=len(coils_TF)) +Jccdist = CurveCurveDistance(curves_TF, CC_THRESHOLD, num_basecurves=len(coils_TF), downsample=4) Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) # Jcsdist2 = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) # While the coil array is not moving around, they cannot # interlink. -linkNum = LinkingNumber(curves_TF, downsample=4) +linkNum = LinkingNumber(curves_TF, downsample=8) ##### Note need coils_TF + coils below!!!!!!! # Jforce2 = sum([LpCurveForce(c, coils_TF, @@ -352,7 +352,7 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jforce = MixedSquaredMeanForce(coils, coils_TF) # Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2, p=4, threshold=4e5 * 100, downsample=4) -Jforce = sum([LpCurveForce(c, coils_TF, regularization_rect(a, b), p=4, threshold=4e5 * 100, downsample=4) for i, c in enumerate(base_coils_TF)]) +Jforce = sum([LpCurveForce(c, coils_TF, regularization_rect(a, b), p=4, threshold=4e5 * 100, downsample=1) for i, c in enumerate(base_coils_TF)]) # Jforce2 = sum([SquaredMeanForce(c, coils_TF) for c in (base_coils_TF)]) # Jtorque = sum([LpCurveTorque(c, coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils_TF)]) # Jtorque = sum([LpCurveTorque(c, coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils_TF)]) @@ -459,7 +459,7 @@ def fun(dofs): # cc_val = CC_WEIGHT * Jccdist.J() # cs_val = CS_WEIGHT * Jcsdist.J() # link_val = LINK_WEIGHT * linkNum.J() - forces_val = FORCE_WEIGHT.value * Jforce.J() + # forces_val = FORCE_WEIGHT.value * Jforce.J() # print(JF._children, btot._children, btot.coils[-1]._children, # btot.coils[-1].curve._children, btot._coils[-1]._children, # Jforce._children) @@ -478,7 +478,7 @@ def fun(dofs): # valuestr += f", ccObj={cc_val:.2e}" # valuestr += f", csObj={cs_val:.2e}" # valuestr += f", Lk1Obj={link_val:.2e}" - valuestr += f", forceObj={forces_val:.2e}" + # valuestr += f", forceObj={forces_val:.2e}" # valuestr += f", forceObj2={forces_val2:.2e}" # valuestr += f", torqueObj={torques_val:.2e}" # valuestr += f", torqueObj2={torques_val2:.2e}" @@ -494,7 +494,7 @@ def fun(dofs): print(valuestr) pr.disable() sio = io.StringIO() - sortby = SortKey.CUMULATIVE + sortby = SortKey.TIME ps = pstats.Stats(pr, stream=sio).sort_stats(sortby) ps.print_stats(20) print(sio.getvalue()) diff --git a/src/simsopt/_core/derivative.py b/src/simsopt/_core/derivative.py index 6b62d848a..9681bbfd6 100644 --- a/src/simsopt/_core/derivative.py +++ b/src/simsopt/_core/derivative.py @@ -115,11 +115,16 @@ def __add__(self, other): x = self.data y = other.data z = copy_numpy_dict(x) + # for k, yk in y.items(): + # if k in z: + # z[k] += yk + # else: + # z[k] = yk for k in y: if k in z: z[k] += y[k] else: - z[k] = y[k].copy() + z[k] = y[k].copy() # why copy here but not in subtract? return Derivative(z) def __sub__(self, other): diff --git a/src/simsopt/field/biotsavart.py b/src/simsopt/field/biotsavart.py index abd52d666..8df27f578 100644 --- a/src/simsopt/field/biotsavart.py +++ b/src/simsopt/field/biotsavart.py @@ -125,45 +125,45 @@ def B_vjp(self, v): sopp.biot_savart_vjp_graph(points, gammas, gammadashs, currents, v, res_gamma, res_gammadash, [], [], []) dB_by_dcoilcurrents = self.dB_by_dcoilcurrents() - res_current = [np.sum(v * self.dB_by_dcoilcurrents()[i]) for i in range(len(dB_by_dcoilcurrents))] + res_current = [np.sum(v * dB_by_dcoilcurrents[i]) for i in range(len(dB_by_dcoilcurrents))] return sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]) - def B_vjp_pure(self, v): - r""" - Assume the field was evaluated at points :math:`\mathbf{x}_i, i\in \{1, \ldots, n\}` and denote the value of the field at those points by - :math:`\{\mathbf{B}_i\}_{i=1}^n`. - These values depend on the shape of the coils, i.e. on the dofs :math:`\mathbf{c}_k` of each coil. - This function returns the vector Jacobian product of this dependency, i.e. - - .. math:: - - \{ \sum_{i=1}^{n} \mathbf{v}_i \cdot \partial_{\mathbf{c}_k} \mathbf{B}_i \}_k. - - """ - coils = self._coils - t1 = time.time() - gammas = [coil.curve.gamma() for coil in coils] - gammadashs = [coil.curve.gammadash() for coil in coils] - currents = [coil.current.get_value() for coil in coils] - res_gamma = [np.zeros_like(gamma) for gamma in gammas] - res_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs] - - points = self.get_points_cart_ref() - sopp.biot_savart_vjp_graph(points, gammas, gammadashs, currents, np.array(v), - res_gamma, res_gammadash, [], [], []) - # t2 = time.time() - # print(t2 - t1) - # t1 = time.time() - # dB_by_dcoilcurrents = self.dB_by_dcoilcurrents() - # # res_current = np.sum(np.sum(v[None, :, :] * np.array(self.dB_by_dcoilcurrents()), axis=-1), axis=-1) - res_current = [jnp.sum(v * self.dB_by_dcoilcurrents()[i]) for i in range(len(coils))] - # t2 = time.time() - # print(t2 - t1) - # t1 = time.time() - # sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]) - # t2 = time.time() - # print(t2 - t1) - return sum([coils[i].vjp(res_gamma[i], res_gammadash[i], jnp.asarray([res_current[i]])) for i in range(len(coils))]) + # def B_vjp_pure(self, v): + # r""" + # Assume the field was evaluated at points :math:`\mathbf{x}_i, i\in \{1, \ldots, n\}` and denote the value of the field at those points by + # :math:`\{\mathbf{B}_i\}_{i=1}^n`. + # These values depend on the shape of the coils, i.e. on the dofs :math:`\mathbf{c}_k` of each coil. + # This function returns the vector Jacobian product of this dependency, i.e. + + # .. math:: + + # \{ \sum_{i=1}^{n} \mathbf{v}_i \cdot \partial_{\mathbf{c}_k} \mathbf{B}_i \}_k. + + # """ + # coils = self._coils + # t1 = time.time() + # gammas = [coil.curve.gamma() for coil in coils] + # gammadashs = [coil.curve.gammadash() for coil in coils] + # currents = [coil.current.get_value() for coil in coils] + # res_gamma = [np.zeros_like(gamma) for gamma in gammas] + # res_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs] + + # points = self.get_points_cart_ref() + # sopp.biot_savart_vjp_graph(points, gammas, gammadashs, currents, np.array(v), + # res_gamma, res_gammadash, [], [], []) + # # t2 = time.time() + # # print(t2 - t1) + # # t1 = time.time() + # # dB_by_dcoilcurrents = self.dB_by_dcoilcurrents() + # # # res_current = np.sum(np.sum(v[None, :, :] * np.array(self.dB_by_dcoilcurrents()), axis=-1), axis=-1) + # res_current = [jnp.sum(v * self.dB_by_dcoilcurrents()[i]) for i in range(len(coils))] + # # t2 = time.time() + # # print(t2 - t1) + # # t1 = time.time() + # # sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]) + # # t2 = time.time() + # # print(t2 - t1) + # return sum([coils[i].vjp(res_gamma[i], res_gammadash[i], jnp.asarray([res_current[i]])) for i in range(len(coils))]) def dA_by_dcoilcurrents(self, compute_derivatives=0): points = self.get_points_cart_ref() diff --git a/src/simsopt/field/force.py b/src/simsopt/field/force.py index af198cd6b..2ca5701b0 100644 --- a/src/simsopt/field/force.py +++ b/src/simsopt/field/force.py @@ -113,17 +113,16 @@ def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regulari where :math:`\vec{F}` is the Lorentz force, :math:`F_0` is a threshold force, and :math:`\ell` is arclength along the coil. """ - B_mutual = B_mutual[::downsample, :] gamma = gamma[::downsample, :] gammadash = gammadash[::downsample, :] gammadashdash = gammadashdash[::downsample, :] quadpoints = quadpoints[::downsample] - B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] tangent = gammadash / gammadash_norm - force = jnp.cross(current * tangent, B_self + B_mutual) - force_norm = jnp.linalg.norm(force, axis=1)[:, None] - return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm)) * (1. / p) / jnp.shape(gamma)[0] + return (jnp.sum(jnp.maximum( + jnp.linalg.norm(jnp.cross( + current * tangent, B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) + B_mutual + ), axis=1)[:, None] - threshold, 0)**p * gammadash_norm) / jnp.shape(gamma)[0]) * (1. / p) class LpCurveForce(Optimizable): @@ -140,7 +139,6 @@ class LpCurveForce(Optimizable): def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0, downsample=1): self.coil = coil - self.allcoils = allcoils self.othercoils = [c for c in allcoils if c is not coil] self.biotsavart = BiotSavart(self.othercoils) quadpoints = self.coil.curve.quadpoints @@ -186,99 +184,73 @@ def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0, downsam super().__init__(depends_on=allcoils) def J(self): - # biotsavart._children = set() - self.biotsavart.set_points(self.coil.curve.gamma()) - - args = [ - self.coil.curve.gamma(), - self.coil.curve.gammadash(), - self.coil.curve.gammadashdash(), - self.coil.current.get_value(), - self.biotsavart.B(), - self.downsample - ] + gamma = self.coil.curve.gamma() + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) + J = self.J_jax(gamma, self.coil.curve.gammadash(), self.coil.curve.gammadashdash(), + self.coil.current.get_value(), self.biotsavart.B(), self.downsample) #### ABSOLUTELY ESSENTIAL LINES BELOW # Otherwise optimizable references multiply # like crazy as number of coils increases self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() for c in self.othercoils: c._children = set() c.curve._children = set() c.current._children = set() - - return self.J_jax(*args) + return J @derivative_dec def dJ(self): - - # biotsavart._children = set() - self.biotsavart.set_points(self.coil.curve.gamma()) + gamma = self.coil.curve.gamma() + gammadash = self.coil.curve.gammadash() + gammadashdash = self.coil.curve.gammadashdash() + current = self.coil.current.get_value() + self.biotsavart.set_points(gamma) args = [ - self.coil.curve.gamma(), - self.coil.curve.gammadash(), - self.coil.curve.gammadashdash(), - self.coil.current.get_value(), + gamma, + gammadash, + gammadashdash, + current, self.biotsavart.B(), - self.downsample + 1 ] dJ_dB = self.dJ_dB_mutual(*args) dB_dX = self.biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) - - # coils = self.othercoils - # gammas = [coil.curve.gamma() for coil in coils] - # gammadashs = [coil.curve.gammadash() for coil in coils] - # currents = [coil.current.get_value() for coil in coils] - # res_gamma = [np.zeros_like(gamma) for gamma in gammas] - # res_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs] - # points = self.coil.curve.gamma() - # sopp.biot_savart_vjp_graph(points, gammas, gammadashs, currents, dJ_dB, - # res_gamma, res_gammadash, [], [], []) - # dB_by_dcoilcurrents = self.biotsavart.dB_by_dcoilcurrents() - # res_current = [np.sum(dJ_dB * dB_by_dcoilcurrents[i]) for i in range(len(dB_by_dcoilcurrents))] - # res_current = np.zeros(len(coils)) - # B_vjp = sum([coils[i].vjp(res_gamma[i], - # res_gammadash[i], - # np.asarray([res_current[i]])) for i in range(len(coils))]) - - - # print(self.othercoils[0]._children) - # print(B_vjp, B_vjp._children) + B_vjp = self.biotsavart.B_vjp(dJ_dB) + + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) + + args2 = [ + gamma, + gammadash, + gammadashdash, + current, + self.biotsavart.B(), + self.downsample + ] dJ = ( - self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) - + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) - + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) - + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + self.biotsavart.B_vjp(dJ_dB) + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args2) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args2)) + + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args2)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args2)])) + + B_vjp ) #### ABSOLUTELY ESSENTIAL LINES BELOW # Otherwise optimizable references multiply # like crazy as number of coils increases self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() for c in self.othercoils: c._children = set() c.curve._children = set() c.current._children = set() - # dJ = ( - # self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) - # + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) - # + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - # + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) - # + biotsavart.B_vjp(dJ_dB) - # # ) - - #### Needed if JaxCurves are used? - # self.biotsavart._children = set() - # for c in self.othercoils: - # c._children = set() - # c.curve._children = set() - # c.current._children = set() - - # print(B_vjp, coils[0]._children) - # print(biotsavart.coils[0]._children, self.coil._children, self.coil.curve._children) - return dJ return_fn_map = {'J': J, 'dJ': dJ} @@ -384,31 +356,12 @@ def dJ(self): dB_dX = self.biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) - coils = self.othercoils - gammas = [coil.curve.gamma() for coil in coils] - gammadashs = [coil.curve.gammadash() for coil in coils] - currents = [coil.current.get_value() for coil in coils] - res_gamma = [np.zeros_like(gamma) for gamma in gammas] - res_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs] - points = self.coil.curve.gamma() - sopp.biot_savart_vjp_graph(points, gammas, gammadashs, currents, v, - res_gamma, res_gammadash, [], [], []) - dB_by_dcoilcurrents = self.biotsavart.dB_by_dcoilcurrents() - res_current = [np.sum(v * dB_by_dcoilcurrents[i]) for i in range(len(dB_by_dcoilcurrents))] - # B_vjp = sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]) - # for c in coils: - # c._children = set() - B_vjp = sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]) - - print(self.othercoils[0]._children) - print(B_vjp, B_vjp._children) return ( self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + B_vjp - # + self.biotsavart.B_vjp(dJ_dB) + + self.biotsavart.B_vjp(dJ_dB) ) return_fn_map = {'J': J, 'dJ': dJ} @@ -1176,14 +1129,15 @@ def dJ(self): return_fn_map = {'J': J, 'dJ': dJ} -@jit -def squared_mean_force_pure(current, gammadash, B_mutual): +# @jit +def squared_mean_force_pure(current, gammadash, B_mutual, downsample): r""" """ # B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) # gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] # tangent = gammadash / gammadash_norm - return (current * jnp.linalg.norm(jnp.sum(jnp.cross(gammadash, B_mutual), axis=0))) ** 2 # / jnp.sum(gammadash_norm) # factor for the integral + gammadash = gammadash[::downsample, :] + return (current * jnp.linalg.norm(jnp.sum(jnp.cross(gammadash, B_mutual), axis=0) / gammadash.shape[0])) ** 2 # / jnp.sum(gammadash_norm) # factor for the integral class SquaredMeanForce(Optimizable): r"""Optimizable class to minimize the net Lorentz force on a coil. @@ -1197,70 +1151,109 @@ class SquaredMeanForce(Optimizable): along the coil. """ - def __init__(self, coil, allcoils): + def __init__(self, coil, allcoils, downsample=1): self.coil = coil - self.allcoils = allcoils - self.othercoils = [c for c in self.allcoils if c is not self.coil] + self.othercoils = [c for c in allcoils if c is not self.coil] + self.downsample = downsample + self.biotsavart = BiotSavart(self.othercoils) + args = {"static_argnums": (3,)} self.J_jax = jit( - lambda current, gammadash, B_mutual: - squared_mean_force_pure(current, gammadash, B_mutual) + lambda current, gammadash, B_mutual, downsample: + squared_mean_force_pure(current, gammadash, B_mutual, downsample), + **args ) self.dJ_dcurrent = jit( - lambda current, gammadash, B_mutual: - grad(self.J_jax, argnums=0)(current, gammadash, B_mutual) + lambda current, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=0)(current, gammadash, B_mutual, downsample), + **args ) self.dJ_dgammadash = jit( - lambda current, gammadash, B_mutual: - grad(self.J_jax, argnums=1)(current, gammadash, B_mutual) + lambda current, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=1)(current, gammadash, B_mutual, downsample), + **args ) self.dJ_dB = jit( - lambda current, gammadash, B_mutual: - grad(self.J_jax, argnums=2)(current, gammadash, B_mutual) + lambda current, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=2)(current, gammadash, B_mutual, downsample), + **args ) super().__init__(depends_on=allcoils) def J(self): - biotsavart = BiotSavart(self.othercoils) - biotsavart.set_points(self.coil.curve.gamma()) + gamma = self.coil.curve.gamma() + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) args = [ self.coil.current.get_value(), self.coil.curve.gammadash(), - biotsavart.B(), + self.biotsavart.B(), + self.downsample, ] + #### ABSOLUTELY ESSENTIAL LINES BELOW + # Otherwise optimizable references multiply + # like crazy as number of coils increases + self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() for c in self.othercoils: c._children = set() + c.curve._children = set() + c.current._children = set() return self.J_jax(*args) @derivative_dec def dJ(self): - biotsavart = BiotSavart(self.othercoils) - biotsavart.set_points(self.coil.curve.gamma()) + gamma = self.coil.curve.gamma() + gammadash = self.coil.curve.gammadash() + current = self.coil.current.get_value() + + self.biotsavart.set_points(gamma) args = [ - self.coil.current.get_value(), - self.coil.curve.gammadash(), - biotsavart.B(), + current, + gammadash, + self.biotsavart.B(), + 1, ] dJ_dB = self.dJ_dB(*args) - dB_dX = biotsavart.dB_by_dX() + dB_dX = self.biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + B_vjp = self.biotsavart.B_vjp(dJ_dB) + + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) + + args2 = [ + current, + gammadash, + self.biotsavart.B(), + self.downsample, + ] dJ = ( self.coil.curve.dgamma_by_dcoeff_vjp(dJ_dX) - + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) - + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + biotsavart.B_vjp(dJ_dB) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args2)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args2)])) + + B_vjp ) + #### ABSOLUTELY ESSENTIAL LINES BELOW + # Otherwise optimizable references multiply + # like crazy as number of coils increases + self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() for c in self.othercoils: c._children = set() + c.curve._children = set() + c.current._children = set() return dJ @@ -1279,79 +1272,123 @@ class SquaredMeanTorque(Optimizable): """ # @jit - def squared_mean_torque_pure(self, current, gamma, gammadash, B_mutual): + def squared_mean_torque_pure(self, current, gamma, gammadash, B_mutual, downsample): r""" """ - return (current * jnp.linalg.norm(jnp.sum(jnp.cross(gamma - self.coil.curve.center(gamma, gammadash), jnp.cross(gammadash, B_mutual)), axis=0))) ** 2 # / jnp.sum(gammadash_norm) # factor for the integral + gamma = gamma[::downsample, :] + gammadash = gammadash[::downsample, :] + return (current * jnp.linalg.norm(jnp.sum(jnp.cross(gamma - self.coil.curve.center(gamma, gammadash), jnp.cross(gammadash, B_mutual)), axis=0) / gamma.shape[0])) ** 2 # / jnp.sum(gammadash_norm) # factor for the integral - def __init__(self, coil, allcoils): + def __init__(self, coil, allcoils, downsample=1): self.coil = coil - self.allcoils = allcoils self.othercoils = [c for c in allcoils if c is not coil] + self.biotsavart = BiotSavart(self.othercoils) + self.downsample = downsample + args = {"static_argnums": (4,)} self.J_jax = jit( - lambda current, gamma, gammadash, B_mutual: - self.squared_mean_torque_pure(current, gamma, gammadash, B_mutual) + lambda current, gamma, gammadash, B_mutual, downsample: + self.squared_mean_torque_pure(current, gamma, gammadash, B_mutual, downsample), + **args ) self.dJ_dcurrent = jit( - lambda current, gamma, gammadash, B_mutual: - grad(self.J_jax, argnums=0)(current, gamma, gammadash, B_mutual) + lambda current, gamma, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=0)(current, gamma, gammadash, B_mutual, downsample), + **args ) self.dJ_dgamma = jit( - lambda current, gamma, gammadash, B_mutual: - grad(self.J_jax, argnums=1)(current, gamma, gammadash, B_mutual) + lambda current, gamma, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=1)(current, gamma, gammadash, B_mutual, downsample), + **args ) self.dJ_dgammadash = jit( - lambda current, gamma, gammadash, B_mutual: - grad(self.J_jax, argnums=2)(current, gamma, gammadash, B_mutual) + lambda current, gamma, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=2)(current, gamma, gammadash, B_mutual, downsample), + **args ) self.dJ_dB = jit( - lambda current, gamma, gammadash, B_mutual: - grad(self.J_jax, argnums=3)(current, gamma, gammadash, B_mutual) + lambda current, gamma, gammadash, B_mutual, downsample: + grad(self.J_jax, argnums=3)(current, gamma, gammadash, B_mutual, downsample), + **args ) super().__init__(depends_on=allcoils) def J(self): - biotsavart = BiotSavart(self.othercoils) - biotsavart.set_points(self.coil.curve.gamma()) + gamma = self.coil.curve.gamma() + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) args = [ self.coil.current.get_value(), - self.coil.curve.gamma(), + gamma, self.coil.curve.gammadash(), - biotsavart.B(), + self.biotsavart.B(), + self.downsample ] + J = self.J_jax(*args) + #### ABSOLUTELY ESSENTIAL LINES BELOW + # Otherwise optimizable references multiply + # like crazy as number of coils increases + self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() for c in self.othercoils: - c._children = set() - return self.J_jax(*args) + c._children = set() + c.curve._children = set() + c.current._children = set() + return J @derivative_dec def dJ(self): - biotsavart = BiotSavart(self.othercoils) - biotsavart.set_points(self.coil.curve.gamma()) + current = self.coil.current.get_value() + gamma = self.coil.curve.gamma() + gammadash = self.coil.curve.gammadash() + self.biotsavart.set_points(gamma) args = [ - self.coil.current.get_value(), - self.coil.curve.gamma(), - self.coil.curve.gammadash(), - biotsavart.B(), + current, + gamma, + gammadash, + self.biotsavart.B(), + 1 ] dJ_dB = self.dJ_dB(*args) - dB_dX = biotsavart.dB_by_dX() + dB_dX = self.biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + B_vjp = self.biotsavart.B_vjp(dJ_dB) + + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) + args2 = [ + current, + gamma, + gammadash, + self.biotsavart.B(), + self.downsample + ] + + dJ = ( + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args2) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args2)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args2)])) + + B_vjp + ) + #### ABSOLUTELY ESSENTIAL LINES BELOW + # Otherwise optimizable references multiply + # like crazy as number of coils increases + self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() for c in self.othercoils: c._children = set() + c.curve._children = set() + c.current._children = set() - return ( - self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) - + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) - + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + biotsavart.B_vjp(dJ_dB) - ) + return dJ return_fn_map = {'J': J, 'dJ': dJ} @@ -1472,7 +1509,7 @@ class LpCurveTorque(Optimizable): """ # @jit - def lp_torque_pure(self, gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold): + def lp_torque_pure(self, gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold, downsample): r"""Pure function for minimizing the Lorentz force on a coil. The function is @@ -1483,90 +1520,144 @@ def lp_torque_pure(self, gamma, gammadash, gammadashdash, quadpoints, current, r where :math:`\vec{T}` is the Lorentz torque, :math:`T_0` is a threshold torque, and :math:`\ell` is arclength along the coil. """ + gamma = gamma[::downsample, :] + gammadash = gammadash[::downsample, :] + gammadashdash = gammadashdash[::downsample, :] + quadpoints = quadpoints[::downsample] B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] tangent = gammadash / gammadash_norm force = jnp.cross(current * tangent, B_self + B_mutual) torque = jnp.cross(gamma - self.coil.curve.center(gamma, gammadash), force) torque_norm = jnp.linalg.norm(torque, axis=1)[:, None] - return (jnp.sum(jnp.maximum(torque_norm - threshold, 0)**p * gammadash_norm)) * (1 / p) #/ jnp.sum(gammadash_norm) + return (jnp.sum(jnp.maximum(torque_norm - threshold, 0)**p * gammadash_norm) / gamma.shape[0]) * (1 / p) #/ jnp.sum(gammadash_norm) - def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0): + def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0, downsample=1): self.coil = coil - self.allcoils = allcoils self.othercoils = [c for c in allcoils if c is not coil] self.biotsavart = BiotSavart(self.othercoils) quadpoints = self.coil.curve.quadpoints - center = self.coil.curve.center + self.downsample = downsample + args = {"static_argnums": (5,)} self.J_jax = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - self.lp_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + self.lp_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold, downsample), + **args ) self.dJ_dgamma = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dgammadash = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dgammadashdash = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dcurrent = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dB_mutual = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) super().__init__(depends_on=allcoils) def J(self): - self.biotsavart.set_points(self.coil.curve.gamma()) + gamma = self.coil.curve.gamma() + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) args = [ - self.coil.curve.gamma(), + gamma, self.coil.curve.gammadash(), self.coil.curve.gammadashdash(), self.coil.current.get_value(), self.biotsavart.B(), + self.downsample ] + J = self.J_jax(*args) + #### ABSOLUTELY ESSENTIAL LINES BELOW + # Otherwise optimizable references multiply + # like crazy as number of coils increases + self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() + for c in self.othercoils: + c._children = set() + c.curve._children = set() + c.current._children = set() - return self.J_jax(*args) + return J @derivative_dec def dJ(self): - self.biotsavart.set_points(self.coil.curve.gamma()) + gamma = self.coil.curve.gamma() + self.biotsavart.set_points(gamma) + gammadash = self.coil.curve.gammadash() + gammadashdash = self.coil.curve.gammadashdash() + current = self.coil.current.get_value() args = [ - self.coil.curve.gamma(), - self.coil.curve.gammadash(), - self.coil.curve.gammadashdash(), - self.coil.current.get_value(), + gamma, + gammadash, + gammadashdash, + current, self.biotsavart.B(), + 1 ] dJ_dB = self.dJ_dB_mutual(*args) dB_dX = self.biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + B_vjp = self.biotsavart.B_vjp(dJ_dB) - return ( - self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) - + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) - + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) - + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + self.biotsavart.B_vjp(dJ_dB) + self.biotsavart.set_points(np.array(gamma[::self.downsample, :])) + + args2 = [ + gamma, + gammadash, + gammadashdash, + current, + self.biotsavart.B(), + self.downsample + ] + + dJ = ( + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args2) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args2)) + + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args2)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args2)])) + + B_vjp ) + #### ABSOLUTELY ESSENTIAL LINES BELOW + # Otherwise optimizable references multiply + # like crazy as number of coils increases + self.biotsavart._children = set() + self.coil._children = set() + self.coil.curve._children = set() + self.coil.current._children = set() + for c in self.othercoils: + c._children = set() + c.curve._children = set() + c.current._children = set() + + return dJ return_fn_map = {'J': J, 'dJ': dJ} diff --git a/src/simsopt/field/selffield.py b/src/simsopt/field/selffield.py index ab6fff0c0..ccc70b30f 100755 --- a/src/simsopt/field/selffield.py +++ b/src/simsopt/field/selffield.py @@ -33,7 +33,9 @@ def regularization_rect(a, b): """Regularization for a rectangular conductor""" return a * b * rectangular_xsection_delta(a, b) +from ..geo.jit import jit +@jit def B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization): """The term in the regularized Biot-Savart law in which the near-singularity has been integrated analytically. @@ -51,7 +53,7 @@ def B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization): 0.5 * (-2 + jnp.log(64 * norm_rc_prime * norm_rc_prime / regularization)) / (norm_rc_prime**3) )[:, None] - +@jit def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization): # The factors of 2π in the next few lines come from the fact that simsopt # uses a curve parameter that goes up to 1 rather than 2π. @@ -59,26 +61,24 @@ def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, reg rc = gamma rc_prime = gammadash / 2 / jnp.pi rc_prime_prime = gammadashdash / 4 / jnp.pi**2 - n_quad = phi.shape[0] - dphi = 2 * jnp.pi / n_quad + dphi = 2 * jnp.pi / phi.shape[0] - analytic_term = B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization) + # analytic_term = B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization) dr = rc[:, None] - rc[None, :] first_term = jnp.cross(rc_prime[None, :], dr) / ((jnp.sum(dr * dr, axis=2) + regularization) ** 1.5)[:, :, None] cos_fac = 2.0 - 2.0 * jnp.cos(phi[None, :] - phi[:, None]) second_term = jnp.cross(rc_prime_prime, rc_prime)[:, None, :] * ( 0.5 * cos_fac / (cos_fac * jnp.sum(rc_prime * rc_prime, axis=1)[:, None] + regularization)**1.5)[:, :, None] + first_term = jnp.cross(rc_prime[None, :], dr) / ((jnp.sum(dr * dr, axis=2) + regularization) ** 1.5)[:, :, None] + # cos_fac = 2.0 - 2.0 * jnp.cos(phi[None, :] - phi[:, None]) + # integral_term = dphi * jnp.sum(jnp.cross(rc_prime, (dr / ((jnp.sum(dr * dr, axis=2) + regularization) ** 1.5)[:, :, None]) \ + # - rc_prime_prime * ( + # 0.5 * cos_fac / (cos_fac * jnp.sum(rc_prime * rc_prime, axis=1)[:, None] + regularization)**1.5)[:, :, None]), + # axis=1 + # ) integral_term = dphi * jnp.sum(first_term + second_term, 1) - # print(jnp.any(jnp.isnan(first_term))) - # print(jnp.any(jnp.isnan(second_term))) - # print(jnp.any(jnp.isnan(integral_term))) - # print(jnp.any(jnp.isnan(analytic_term))) - - # print(jnp.max(jnp.abs(first_term))) - # print(jnp.max(jnp.abs(second_term))) - # print(jnp.max(jnp.abs(integral_term))) - # print(jnp.max(jnp.abs(analytic_term))) - return current * Biot_savart_prefactor * (analytic_term + integral_term) + return current * Biot_savart_prefactor * ( + B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization) + integral_term) def B_regularized(coil, regularization): diff --git a/src/simsopt/geo/curve.py b/src/simsopt/geo/curve.py index b3deda774..07f2253d5 100644 --- a/src/simsopt/geo/curve.py +++ b/src/simsopt/geo/curve.py @@ -408,6 +408,12 @@ def dkappadash_by_dcoeff(self): ) return dkappadash_by_dcoeff + def center(self, gamma, gammadash): + # Compute the centroid of the curve + arclength = jnp.linalg.norm(gammadash, axis=-1) + barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / gamma.shape[0] / np.pi + return barycenter + class JaxCurve(sopp.Curve, Curve): def __init__(self, quadpoints, gamma_pure, **kwargs): @@ -696,10 +702,8 @@ def num_dofs(self): def center(self, gamma, gammadash): # Compute the centroid of the curve - quadpoints = self.quadpoints - N = len(quadpoints) arclength = jnp.linalg.norm(gammadash, axis=-1) - barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi + barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / gamma.shape[0] / np.pi return barycenter def gamma_impl(self, gamma, quadpoints): diff --git a/src/simsopt/geo/curveobjectives.py b/src/simsopt/geo/curveobjectives.py index fac00be34..0a1ddea24 100644 --- a/src/simsopt/geo/curveobjectives.py +++ b/src/simsopt/geo/curveobjectives.py @@ -148,10 +148,14 @@ def dJ(self): return_fn_map = {'J': J, 'dJ': dJ} -def cc_distance_pure(gamma1, l1, gamma2, l2, minimum_distance): +def cc_distance_pure(gamma1, l1, gamma2, l2, minimum_distance, downsample=1): """ This function is used in a Python+Jax implementation of the curve-curve distance formula. """ + gamma1 = gamma1[::downsample, :] + gamma2 = gamma2[::downsample, :] + l1 = l1[::downsample, :] + l2 = l2[::downsample, :] dists = jnp.sqrt(jnp.sum((gamma1[:, None, :] - gamma2[None, :, :])**2, axis=2)) alen = jnp.linalg.norm(l1, axis=1)[:, None] * jnp.linalg.norm(l2, axis=1)[None, :] return jnp.sum(alen * jnp.maximum(minimum_distance-dists, 0)**2)/(gamma1.shape[0]*gamma2.shape[0]) @@ -179,15 +183,16 @@ class CurveCurveDistance(Optimizable): """ - def __init__(self, curves, minimum_distance, num_basecurves=None): + def __init__(self, curves, minimum_distance, num_basecurves=None, downsample=1): self.curves = curves self.minimum_distance = minimum_distance - - self.J_jax = jit(lambda gamma1, l1, gamma2, l2: cc_distance_pure(gamma1, l1, gamma2, l2, minimum_distance)) - self.thisgrad0 = jit(lambda gamma1, l1, gamma2, l2: grad(self.J_jax, argnums=0)(gamma1, l1, gamma2, l2)) - self.thisgrad1 = jit(lambda gamma1, l1, gamma2, l2: grad(self.J_jax, argnums=1)(gamma1, l1, gamma2, l2)) - self.thisgrad2 = jit(lambda gamma1, l1, gamma2, l2: grad(self.J_jax, argnums=2)(gamma1, l1, gamma2, l2)) - self.thisgrad3 = jit(lambda gamma1, l1, gamma2, l2: grad(self.J_jax, argnums=3)(gamma1, l1, gamma2, l2)) + self.downsample = downsample + args = {"static_argnums": (4,)} + self.J_jax = jit(lambda gamma1, l1, gamma2, l2, dsample: cc_distance_pure(gamma1, l1, gamma2, l2, minimum_distance, dsample), **args) + self.thisgrad0 = jit(lambda gamma1, l1, gamma2, l2, dsample: grad(self.J_jax, argnums=0)(gamma1, l1, gamma2, l2, dsample), **args) + self.thisgrad1 = jit(lambda gamma1, l1, gamma2, l2, dsample: grad(self.J_jax, argnums=1)(gamma1, l1, gamma2, l2, dsample), **args) + self.thisgrad2 = jit(lambda gamma1, l1, gamma2, l2, dsample: grad(self.J_jax, argnums=2)(gamma1, l1, gamma2, l2, dsample), **args) + self.thisgrad3 = jit(lambda gamma1, l1, gamma2, l2, dsample: grad(self.J_jax, argnums=3)(gamma1, l1, gamma2, l2, dsample), **args) self.candidates = None self.num_basecurves = num_basecurves or len(curves) super().__init__(depends_on=curves) @@ -198,20 +203,22 @@ def recompute_bell(self, parent=None): def compute_candidates(self): if self.candidates is None: candidates = sopp.get_pointclouds_closer_than_threshold_within_collection( - [c.gamma() for c in self.curves], self.minimum_distance, self.num_basecurves) + [c.gamma()[::self.downsample, :] for c in self.curves], self.minimum_distance, self.num_basecurves) self.candidates = candidates def shortest_distance_among_candidates(self): self.compute_candidates() from scipy.spatial.distance import cdist - return min([self.minimum_distance] + [np.min(cdist(self.curves[i].gamma(), self.curves[j].gamma())) for i, j in self.candidates]) + return min([self.minimum_distance] + [np.min(cdist(self.curves[i].gamma()[::self.downsample, :], + self.curves[j].gamma()[::self.downsample, :])) for i, j in self.candidates]) def shortest_distance(self): self.compute_candidates() if len(self.candidates) > 0: return self.shortest_distance_among_candidates() from scipy.spatial.distance import cdist - return min([np.min(cdist(self.curves[i].gamma(), self.curves[j].gamma())) for i in range(len(self.curves)) for j in range(i)]) + return min([np.min(cdist(self.curves[i].gamma()[::self.downsample, :], + self.curves[j].gamma()[::self.downsample, :])) for i in range(len(self.curves)) for j in range(i)]) def J(self): """ @@ -224,7 +231,7 @@ def J(self): l1 = self.curves[i].gammadash() gamma2 = self.curves[j].gamma() l2 = self.curves[j].gammadash() - res += self.J_jax(gamma1, l1, gamma2, l2) + res += self.J_jax(gamma1, l1, gamma2, l2, self.downsample) return res @@ -242,10 +249,10 @@ def dJ(self): l1 = self.curves[i].gammadash() gamma2 = self.curves[j].gamma() l2 = self.curves[j].gammadash() - dgamma_by_dcoeff_vjp_vecs[i] += self.thisgrad0(gamma1, l1, gamma2, l2) - dgammadash_by_dcoeff_vjp_vecs[i] += self.thisgrad1(gamma1, l1, gamma2, l2) - dgamma_by_dcoeff_vjp_vecs[j] += self.thisgrad2(gamma1, l1, gamma2, l2) - dgammadash_by_dcoeff_vjp_vecs[j] += self.thisgrad3(gamma1, l1, gamma2, l2) + dgamma_by_dcoeff_vjp_vecs[i] += self.thisgrad0(gamma1, l1, gamma2, l2, self.downsample) + dgammadash_by_dcoeff_vjp_vecs[i] += self.thisgrad1(gamma1, l1, gamma2, l2, self.downsample) + dgamma_by_dcoeff_vjp_vecs[j] += self.thisgrad2(gamma1, l1, gamma2, l2, self.downsample) + dgammadash_by_dcoeff_vjp_vecs[j] += self.thisgrad3(gamma1, l1, gamma2, l2, self.downsample) res = [self.curves[i].dgamma_by_dcoeff_vjp(dgamma_by_dcoeff_vjp_vecs[i]) + self.curves[i].dgammadash_by_dcoeff_vjp(dgammadash_by_dcoeff_vjp_vecs[i]) for i in range(len(self.curves))] return sum(res) diff --git a/src/simsopt/geo/curveplanarfourier.py b/src/simsopt/geo/curveplanarfourier.py index eb75f3d1a..57bcfc86b 100644 --- a/src/simsopt/geo/curveplanarfourier.py +++ b/src/simsopt/geo/curveplanarfourier.py @@ -79,10 +79,8 @@ def set_dofs(self, dofs): def center(self, gamma, gammadash): # Compute the centroid of the curve - quadpoints = self.quadpoints - N = len(quadpoints) arclength = jnp.linalg.norm(gammadash, axis=-1) - barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi + barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / gamma.shape[0] / np.pi return barycenter def jaxplanarcurve_pure(dofs, quadpoints, order): @@ -151,14 +149,6 @@ def set_dofs_impl(self, dofs): """ self.dof_list = np.array(dofs) - def center(self, gamma, gammadash): - # Compute the centroid of the curve - quadpoints = self.quadpoints - N = len(quadpoints) - arclength = jnp.linalg.norm(gammadash, axis=-1) - barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi - return barycenter - def set_quadpoints(self, quadpoints): self.quadpoints = quadpoints diff --git a/src/simsopt/geo/curvexyzfourier.py b/src/simsopt/geo/curvexyzfourier.py index f41142135..044a7e7c5 100644 --- a/src/simsopt/geo/curvexyzfourier.py +++ b/src/simsopt/geo/curvexyzfourier.py @@ -72,6 +72,12 @@ def set_dofs(self, dofs): self.local_x = dofs sopp.CurveXYZFourier.set_dofs(self, dofs) + def center(self, gamma, gammadash): + # Compute the centroid of the curve + arclength = jnp.linalg.norm(gammadash, axis=-1) + barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / gamma.shape[0] / np.pi + return barycenter + @staticmethod def load_curves_from_file(filename, order=None, ppp=20, delimiter=','): """ @@ -188,14 +194,6 @@ def load_curves_from_makegrid_file(filename: str, order: int, ppp=20): coils[ic].local_x = np.concatenate(dofs) return coils - def center(self, gamma, gammadash): - # Compute the centroid of the curve - quadpoints = self.quadpoints - N = len(quadpoints) - arclength = jnp.linalg.norm(gammadash, axis=-1) - barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi - return barycenter - def jaxfouriercurve_pure(dofs, quadpoints, order): k = jnp.shape(dofs)[0]//3 coeffs = [dofs[:k], dofs[k:(2*k)], dofs[(2*k):]] @@ -260,12 +258,4 @@ def set_dofs_impl(self, dofs): self.coefficients[i][2*j-1] = dofs[counter] counter += 1 self.coefficients[i][2*j] = dofs[counter] - counter += 1 - - def center(self, gamma, gammadash): - # Compute the centroid of the curve - quadpoints = self.quadpoints - N = len(quadpoints) - arclength = jnp.linalg.norm(gammadash, axis=-1) - barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi - return barycenter \ No newline at end of file + counter += 1 \ No newline at end of file diff --git a/src/simsoptpp/biot_savart_vjp_py.cpp b/src/simsoptpp/biot_savart_vjp_py.cpp index 97f91589c..d99a10b8c 100644 --- a/src/simsoptpp/biot_savart_vjp_py.cpp +++ b/src/simsoptpp/biot_savart_vjp_py.cpp @@ -67,6 +67,7 @@ void biot_savart_vjp_graph(Array& points, vector& gammas, vector& auto pointsx = AlignedPaddedVec(points.shape(0), 0); auto pointsy = AlignedPaddedVec(points.shape(0), 0); auto pointsz = AlignedPaddedVec(points.shape(0), 0); + #pragma omp parallel for schedule(static) for (int i = 0; i < points.shape(0); ++i) { pointsx[i] = points(i, 0); pointsy[i] = points(i, 1); @@ -77,24 +78,23 @@ void biot_savart_vjp_graph(Array& points, vector& gammas, vector& bool compute_dB = res_grad_gamma.size() > 0; Array dummy = Array(); - #pragma omp parallel for + #pragma omp parallel for schedule(static) for(int i=0; i(pointsx, pointsy, pointsz, gammas[i], dgamma_by_dphis[i], v, res_gamma[i], res_dgamma_by_dphi[i], vgrad, res_grad_gamma[i], res_grad_dgamma_by_dphi[i]); + res_grad_gamma[i] *= fak; + res_grad_dgamma_by_dphi[i] *= fak; + } else biot_savart_vjp_kernel(pointsx, pointsy, pointsz, gammas[i], dgamma_by_dphis[i], v, res_gamma[i], res_dgamma_by_dphi[i], dummy, dummy, dummy); - double fak = (currents[i] * 1e-7/gammas[i].shape(0)); res_gamma[i] *= fak; res_dgamma_by_dphi[i] *= fak; - if(compute_dB) { - res_grad_gamma[i] *= fak; - res_grad_dgamma_by_dphi[i] *= fak; - } } } diff --git a/src/simsoptpp/curveplanarfourier.cpp b/src/simsoptpp/curveplanarfourier.cpp index 3f2e3192d..e15284158 100644 --- a/src/simsoptpp/curveplanarfourier.cpp +++ b/src/simsoptpp/curveplanarfourier.cpp @@ -18,7 +18,7 @@ void CurvePlanarFourier::gamma_impl(Array& data, Array& quadpoints) { data *= 0; Array q_norm = q * inv_magnitude(); -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { double phi = 2 * M_PI * quadpoints[k]; double cosphi = cos(phi); @@ -32,7 +32,7 @@ void CurvePlanarFourier::gamma_impl(Array& data, Array& quadpoints) { data(k, 1) += (rc[i] * cosiphi + rs[i-1] * siniphi) * sinphi; } } -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double i = data(m, 0); double j = data(m, 1); @@ -53,7 +53,7 @@ void CurvePlanarFourier::gammadash_impl(Array& data) { Array q_norm = q * inv_sqrt_s; double cosiphi, siniphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { double phi = 2 * M_PI * quadpoints[k]; double cosphi = cos(phi); @@ -71,7 +71,7 @@ void CurvePlanarFourier::gammadash_impl(Array& data) { } data *= (2*M_PI); -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double i = data(m, 0); double j = data(m, 1); @@ -92,7 +92,7 @@ void CurvePlanarFourier::gammadashdash_impl(Array& data) { Array q_norm = q * inv_magnitude(); double cosiphi, siniphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { double phi = 2 * M_PI * quadpoints[k]; double cosphi = cos(phi); @@ -109,7 +109,7 @@ void CurvePlanarFourier::gammadashdash_impl(Array& data) { } } data *= 2*M_PI*2*M_PI; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double i = data(m, 0); double j = data(m, 1); @@ -130,7 +130,7 @@ void CurvePlanarFourier::gammadashdashdash_impl(Array& data) { Array q_norm = q * inv_magnitude(); double cosiphi, siniphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { double phi = 2 * M_PI * quadpoints[k]; double cosphi = cos(phi); @@ -158,7 +158,7 @@ void CurvePlanarFourier::gammadashdashdash_impl(Array& data) { } } data *= 2*M_PI*2*M_PI*2*M_PI; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double i = data(m, 0); double j = data(m, 1); @@ -179,7 +179,7 @@ void CurvePlanarFourier::dgamma_by_dcoeff_impl(Array& data) { Array q_norm = q * inv_magnitude(); double cosnphi, sinnphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double phi = 2 * M_PI * quadpoints[m]; int counter = 0; @@ -303,7 +303,7 @@ void CurvePlanarFourier::dgammadash_by_dcoeff_impl(Array& data) { Array q_norm = q * inv_magnitude(); double cosnphi, sinnphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double phi = 2 * M_PI * quadpoints[m]; double cosphi = cos(phi); @@ -435,7 +435,7 @@ void CurvePlanarFourier::dgammadashdash_by_dcoeff_impl(Array& data) { Array q_norm = q * inv_magnitude(); double cosnphi, sinnphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double phi = 2 * M_PI * quadpoints[m]; double cosphi = cos(phi); @@ -567,7 +567,7 @@ void CurvePlanarFourier::dgammadashdashdash_by_dcoeff_impl(Array& data) { Array q_norm = q * inv_magnitude(); double cosnphi, sinnphi; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int m = 0; m < numquadpoints; ++m) { double phi = 2 * M_PI * quadpoints[m]; double cosphi = cos(phi); diff --git a/src/simsoptpp/curvexyzfourier.cpp b/src/simsoptpp/curvexyzfourier.cpp index bae6886ef..39eb9b032 100644 --- a/src/simsoptpp/curvexyzfourier.cpp +++ b/src/simsoptpp/curvexyzfourier.cpp @@ -4,7 +4,7 @@ template void CurveXYZFourier::gamma_impl(Array& data, Array& quadpoints) { int numquadpoints = quadpoints.size(); data *= 0; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { data(k, i) += dofs[i][0]; @@ -19,7 +19,7 @@ void CurveXYZFourier::gamma_impl(Array& data, Array& quadpoints) { template void CurveXYZFourier::gammadash_impl(Array& data) { data *= 0; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { for (int j = 1; j < order+1; ++j) { @@ -33,7 +33,7 @@ void CurveXYZFourier::gammadash_impl(Array& data) { template void CurveXYZFourier::gammadashdash_impl(Array& data) { data *= 0; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { for (int j = 1; j < order+1; ++j) { @@ -47,7 +47,7 @@ void CurveXYZFourier::gammadashdash_impl(Array& data) { template void CurveXYZFourier::gammadashdashdash_impl(Array& data) { data *= 0; -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { for (int j = 1; j < order+1; ++j) { @@ -60,7 +60,7 @@ void CurveXYZFourier::gammadashdashdash_impl(Array& data) { template void CurveXYZFourier::dgamma_by_dcoeff_impl(Array& data) { -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { data(k, i, i*(2*order+1)) = 1.; @@ -74,7 +74,7 @@ void CurveXYZFourier::dgamma_by_dcoeff_impl(Array& data) { template void CurveXYZFourier::dgammadash_by_dcoeff_impl(Array& data) { -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { for (int j = 1; j < order+1; ++j) { @@ -87,7 +87,7 @@ void CurveXYZFourier::dgammadash_by_dcoeff_impl(Array& data) { template void CurveXYZFourier::dgammadashdash_by_dcoeff_impl(Array& data) { -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { for (int j = 1; j < order+1; ++j) { @@ -100,7 +100,7 @@ void CurveXYZFourier::dgammadashdash_by_dcoeff_impl(Array& data) { template void CurveXYZFourier::dgammadashdashdash_by_dcoeff_impl(Array& data) { -#pragma omp parallel for schedule(static) +// #pragma omp parallel for schedule(static) for (int k = 0; k < numquadpoints; ++k) { for (int i = 0; i < 3; ++i) { for (int j = 1; j < order+1; ++j) {