From 5ae14c0d0c5ffade4969231a83be9bfdb2adccae Mon Sep 17 00:00:00 2001 From: Alan Kaptanoglu Date: Fri, 1 Nov 2024 17:00:55 -0400 Subject: [PATCH] Made some more useful changes. Got the MixedLpCurve class running faster by downsampling the calculation, but for some reason the initial compilation is very slow. Still trying to understand how to avoid spawning all the weak references to child processes during optimization with the normal LpCurveForce object, and added downsampling to this too. This is probably worth trying to figure out definitively. --- examples/3_Advanced/QA_single_TFcoil.py | 207 ++++-- examples/3_Advanced/QH_reactorscale_DA.py | 39 +- .../3_Advanced/QH_reactorscale_nodipoles.py | 469 +++++++++++++ .../3_Advanced/QH_reactorscale_notfixed.py | 643 ++++++++++++++++++ src/simsopt/field/biotsavart.py | 2 +- src/simsopt/field/force.py | 193 ++++-- src/simsopt/field/selffield.py | 9 +- src/simsopt/geo/curve.py | 9 + src/simsopt/geo/curveplanarfourier.py | 2 + src/simsopt/geo/jit.py | 4 +- tests/field/test_selffieldforces.py | 27 +- 11 files changed, 1452 insertions(+), 152 deletions(-) create mode 100644 examples/3_Advanced/QH_reactorscale_nodipoles.py create mode 100644 examples/3_Advanced/QH_reactorscale_notfixed.py diff --git a/examples/3_Advanced/QA_single_TFcoil.py b/examples/3_Advanced/QA_single_TFcoil.py index bfad035fc..783aa5164 100644 --- a/examples/3_Advanced/QA_single_TFcoil.py +++ b/examples/3_Advanced/QA_single_TFcoil.py @@ -7,6 +7,7 @@ from pathlib import Path import time import numpy as np +import warnings from scipy.optimize import minimize from simsopt.field import BiotSavart, Current, coils_via_symmetries # from simsopt.field import CoilCoilNetForces, CoilCoilNetTorques, \ @@ -41,8 +42,8 @@ range_param = "half period" nphi = 32 ntheta = 32 -poff = 1.5 -coff = 2.0 +poff = 1.75 +coff = 2.5 s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta) s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) @@ -88,7 +89,7 @@ def initialize_coils_QA(TEST_DIR, s): ncoils = 1 R0 = s.get_rc(0, 0) * 2 R1 = s.get_rc(1, 0) * 10 - order = 4 + order = 10 from simsopt.mhd.vmec import Vmec vmec_file = 'wout_LandremanPaul2021_QA_reactorScale_lowres_reference.nc' @@ -141,7 +142,7 @@ def initialize_coils_QA(TEST_DIR, s): aa = 0.05 bb = 0.05 -Nx = 3 +Nx = 6 Ny = Nx Nz = Nx # Create the initial coils: @@ -149,25 +150,98 @@ def initialize_coils_QA(TEST_DIR, s): s, s_inner, s_outer, Nx, Ny, Nz, order=order, coil_coil_flag=True, jax_flag=True, # numquadpoints=10 # Defaults is (order + 1) * 40 so this halves it ) +# ncoils = len(base_curves) +# print('Ncoils = ', ncoils) +# for i in range(len(base_curves)): +# # base_curves[i].set('x' + str(2 * order + 1), np.random.rand(1) - 0.5) +# # base_curves[i].set('x' + str(2 * order + 2), np.random.rand(1) - 0.5) +# # base_curves[i].set('x' + str(2 * order + 3), np.random.rand(1) - 0.5) +# # base_curves[i].set('x' + str(2 * order + 4), np.random.rand(1) - 0.5) + +# # Fix shape of each coil +# for j in range(2 * order + 1): +# base_curves[i].fix('x' + str(j)) +# # Fix center points of each coil +# # base_curves[i].fix('x' + str(2 * order + 5)) +# # base_curves[i].fix('x' + str(2 * order + 6)) +# # base_curves[i].fix('x' + str(2 * order + 7)) +# base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] +# # Fix currents in each coil +# # for i in range(ncoils): +# # base_currents[i].fix_all() + +# coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) +# base_coils = coils[:ncoils] + + +keep_inds = [] +for ii in range(len(base_curves)): + counter = 0 + for i in range(base_curves[0].gamma().shape[0]): + eps = 0.05 + for j in range(len(base_curves_TF)): + for k in range(base_curves_TF[j].gamma().shape[0]): + dij = np.sqrt(np.sum((base_curves[ii].gamma()[i, :] - base_curves_TF[j].gamma()[k, :]) ** 2)) + conflict_bool = (dij < (1.0 + eps) * base_curves[0].x[0]) + if conflict_bool: + print('bad indices = ', i, j, dij, base_curves[0].x[0]) + warnings.warn( + 'There is a PSC coil initialized such that it is within a radius' + 'of a TF coil. Deleting these PSCs now.') + counter += 1 + break + if counter == 0: + keep_inds.append(ii) + +print(keep_inds) +base_curves = np.array(base_curves)[keep_inds] + ncoils = len(base_curves) print('Ncoils = ', ncoils) +coil_normals = np.zeros((ncoils, 3)) +plasma_points = s.gamma().reshape(-1, 3) +plasma_unitnormals = s.unitnormal().reshape(-1, 3) +for i in range(ncoils): + point = (base_curves[i].get_dofs()[-3:]) + dists = np.sum((point - plasma_points) ** 2, axis=-1) + min_ind = np.argmin(dists) + coil_normals[i, :] = plasma_unitnormals[min_ind, :] + # coil_normals[i, :] = (plasma_points[min_ind, :] - point) +coil_normals = coil_normals / np.linalg.norm(coil_normals, axis=-1)[:, None] +# alphas = np.arctan2( +# -coil_normals[:, 1], +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +# deltas = np.arcsin(coil_normals[:, 0] / \ +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +alphas = np.arcsin( + -coil_normals[:, 1], + ) +deltas = np.arctan2(coil_normals[:, 0], coil_normals[:, 2]) for i in range(len(base_curves)): - # base_curves[i].set('x' + str(2 * order + 1), np.random.rand(1) - 0.5) - # base_curves[i].set('x' + str(2 * order + 2), np.random.rand(1) - 0.5) - # base_curves[i].set('x' + str(2 * order + 3), np.random.rand(1) - 0.5) - # base_curves[i].set('x' + str(2 * order + 4), np.random.rand(1) - 0.5) + alpha2 = alphas[i] / 2.0 + delta2 = deltas[i] / 2.0 + calpha2 = np.cos(alpha2) + salpha2 = np.sin(alpha2) + cdelta2 = np.cos(delta2) + sdelta2 = np.sin(delta2) + base_curves[i].set('x' + str(2 * order + 1), calpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 2), salpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 3), calpha2 * sdelta2) + base_curves[i].set('x' + str(2 * order + 4), -salpha2 * sdelta2) + # Fix orientations of each coil + base_curves[i].fix('x' + str(2 * order + 1)) + base_curves[i].fix('x' + str(2 * order + 2)) + base_curves[i].fix('x' + str(2 * order + 3)) + base_curves[i].fix('x' + str(2 * order + 4)) # Fix shape of each coil for j in range(2 * order + 1): base_curves[i].fix('x' + str(j)) # Fix center points of each coil - # base_curves[i].fix('x' + str(2 * order + 5)) - # base_curves[i].fix('x' + str(2 * order + 6)) - # base_curves[i].fix('x' + str(2 * order + 7)) + base_curves[i].fix('x' + str(2 * order + 5)) + base_curves[i].fix('x' + str(2 * order + 6)) + base_curves[i].fix('x' + str(2 * order + 7)) base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] -# Fix currents in each coil -# for i in range(ncoils): -# base_currents[i].fix_all() coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) base_coils = coils[:ncoils] @@ -210,24 +284,31 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): base_a_list = np.hstack((np.ones(len(base_coils)) * aa, np.ones(len(base_coils_TF)) * a)) base_b_list = np.hstack((np.ones(len(base_coils)) * bb, np.ones(len(base_coils_TF)) * b)) -LENGTH_WEIGHT = Weight(0.01) -LENGTH_TARGET = 90 # 90 works very well... can we get it down? -LINK_WEIGHT = 1e3 +LENGTH_WEIGHT = Weight(0.001) +LENGTH_TARGET = 80 # 90 works very well... can we get it down? +LINK_WEIGHT = 1e2 CC_THRESHOLD = 0.8 -CC_WEIGHT = 1e2 +CC_WEIGHT = 1e1 CS_THRESHOLD = 1.5 -CS_WEIGHT = 1e2 +CS_WEIGHT = 1e1 # Weight for the Coil Coil forces term -FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT = Weight(1e-37) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -TORQUE_WEIGHT = Weight(1e-18) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons + +# FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons # Directory for output OUT_DIR = ("./QA_singleTF_n{:d}_p{:.2e}_c{:.2e}_lw{:.2e}_lt{:.2e}_lkw{:.2e}" + \ "_cct{:.2e}_ccw{:.2e}_cst{:.2e}_csw{:.2e}_fw{:.2e}_fww{:2e}_tw{:.2e}/").format( ncoils, poff, coff, LENGTH_WEIGHT.value, LENGTH_TARGET, LINK_WEIGHT, CC_THRESHOLD, CC_WEIGHT, CS_THRESHOLD, CS_WEIGHT, FORCE_WEIGHT.value, FORCE_WEIGHT2.value, - TORQUE_WEIGHT.value) + TORQUE_WEIGHT.value, + TORQUE_WEIGHT2.value) if os.path.exists(OUT_DIR): shutil.rmtree(OUT_DIR) os.makedirs(OUT_DIR, exist_ok=True) @@ -281,7 +362,8 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # coil-coil and coil-plasma distances should be between all coils Jccdist = CurveCurveDistance(curves + curves_TF, CC_THRESHOLD, num_basecurves=len(coils + coils_TF)) -Jcsdist = CurveSurfaceDistance(curves + curves_TF, s, CS_THRESHOLD) +Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) +Jcsdist2 = CurveSurfaceDistance(curves + curves_TF, s, CS_THRESHOLD) # While the coil array is not moving around, they cannot # interlink. @@ -298,9 +380,13 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): regularization_list2 = np.zeros(len(coils_TF)) * regularization_rect(a, b) # Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jforce = MixedSquaredMeanForce(coils, coils_TF) -Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=2e5 * 40) for i, c in enumerate(base_coils + base_coils_TF)]) + +###### NOTE JFORCE BELOW ONLY DOING THE DIPOLE COILS!!!! +Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100) for i, c in enumerate(base_coils)]) Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for i, c in enumerate(base_coils + base_coils_TF)]) -Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i])) for i, c in enumerate(base_coils + base_coils_TF)]) +Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) +Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) # Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] @@ -317,17 +403,11 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): JF += FORCE_WEIGHT2.value * Jforce2 #\ if TORQUE_WEIGHT.value > 0.0: - JF += TORQUE_WEIGHT.value * Jtorque #\ - # + FORCE_WEIGHT2.value * Jforce2 - # + TORQUE_WEIGHT * Jtorque - # + TVE_WEIGHT * Jtve - # + SF_WEIGHT * Jsf - # + CURRENTS_WEIGHT * DipoleCurrentsObj - # + CURVATURE_WEIGHT * sum(Jcs_TF) \ - # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs_TF) \ -# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \ - # + CURVATURE_WEIGHT * sum(Jcs) \ + JF += TORQUE_WEIGHT * Jtorque +if TORQUE_WEIGHT2.value > 0.0: + JF += TORQUE_WEIGHT2 * Jtorque2 + # We don't have a general interface in SIMSOPT for optimisation problems that # are not in least-squares form, so we write a little wrapper function that we # pass directly to scipy.optimize.minimize @@ -359,6 +439,7 @@ def fun(dofs): forces_val = FORCE_WEIGHT.value * Jforce.J() forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() torques_val = TORQUE_WEIGHT.value * Jtorque.J() + torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) ) / np.mean(btot.AbsB()) @@ -373,10 +454,12 @@ def fun(dofs): valuestr += f", forceObj={forces_val:.2e}" valuestr += f", forceObj2={forces_val2:.2e}" valuestr += f", torqueObj={torques_val:.2e}" + valuestr += f", torqueObj2={torques_val2:.2e}" outstr += f", F={Jforce.J():.2e}" - outstr += f", Fpointwise={Jforce2.J():.2e}" + outstr += f", Fnet={Jforce2.J():.2e}" outstr += f", T={Jtorque.J():.2e}" - outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}" + outstr += f", Tnet={Jtorque2.J():.2e}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" outstr += f", Link Number = {linkNum.J()}" outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" print(outstr) @@ -438,33 +521,33 @@ def fun(dofs): Jlength.dJ() t2 = time.time() print('sum(Jls_TF) time = ', t2 - t1, ' s') -t1 = time.time() -Jforce.J() -t2 = time.time() -print('Jforces time = ', t2 - t1, ' s') -t1 = time.time() -Jforce.dJ() -t2 = time.time() -print('dJforces time = ', t2 - t1, ' s') -t1 = time.time() -Jforce2.J() -t2 = time.time() -print('Jforces2 time = ', t2 - t1, ' s') -t1 = time.time() -Jforce2.dJ() -t2 = time.time() -print('dJforces2 time = ', t2 - t1, ' s') -t1 = time.time() -Jtorque.J() -t2 = time.time() -print('Jtorques time = ', t2 - t1, ' s') -t1 = time.time() -Jtorque.dJ() -t2 = time.time() -print('dJtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.J() +# t2 = time.time() +# print('Jforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.dJ() +# t2 = time.time() +# print('dJforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.J() +# t2 = time.time() +# print('Jforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.dJ() +# t2 = time.time() +# print('dJforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.J() +# t2 = time.time() +# print('Jtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.dJ() +# t2 = time.time() +# print('dJtorques time = ', t2 - t1, ' s') n_saves = 1 -MAXITER = 400 +MAXITER = 200 for i in range(1, n_saves + 1): print('Iteration ' + str(i) + ' / ' + str(n_saves)) res = minimize(fun, dofs, jac=True, method='L-BFGS-B', diff --git a/examples/3_Advanced/QH_reactorscale_DA.py b/examples/3_Advanced/QH_reactorscale_DA.py index 2bf0f8eb1..5e4950ada 100644 --- a/examples/3_Advanced/QH_reactorscale_DA.py +++ b/examples/3_Advanced/QH_reactorscale_DA.py @@ -88,7 +88,7 @@ def initialize_coils_QH(TEST_DIR, s): ncoils = 2 R0 = s.get_rc(0, 0) * 1 R1 = s.get_rc(1, 0) * 4 - order = 7 + order = 8 from simsopt.mhd.vmec import Vmec vmec_file = 'wout_LandremanPaul2021_QH_reactorScale_lowres_reference.nc' @@ -207,18 +207,18 @@ def initialize_coils_QH(TEST_DIR, s): base_curves[i].set('x' + str(2 * order + 3), calpha2 * sdelta2) base_curves[i].set('x' + str(2 * order + 4), -salpha2 * sdelta2) # Fix orientations of each coil - base_curves[i].fix('x' + str(2 * order + 1)) - base_curves[i].fix('x' + str(2 * order + 2)) - base_curves[i].fix('x' + str(2 * order + 3)) - base_curves[i].fix('x' + str(2 * order + 4)) + # base_curves[i].fix('x' + str(2 * order + 1)) + # base_curves[i].fix('x' + str(2 * order + 2)) + # base_curves[i].fix('x' + str(2 * order + 3)) + # base_curves[i].fix('x' + str(2 * order + 4)) # Fix shape of each coil for j in range(2 * order + 1): base_curves[i].fix('x' + str(j)) # Fix center points of each coil - base_curves[i].fix('x' + str(2 * order + 5)) - base_curves[i].fix('x' + str(2 * order + 6)) - base_curves[i].fix('x' + str(2 * order + 7)) + # base_curves[i].fix('x' + str(2 * order + 5)) + # base_curves[i].fix('x' + str(2 * order + 6)) + # base_curves[i].fix('x' + str(2 * order + 7)) base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) @@ -270,11 +270,11 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): CS_THRESHOLD = 1.5 CS_WEIGHT = 1e2 # Weight for the Coil Coil forces term -# FORCE_WEIGHT = Weight(1e-22) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons # FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(4e-27) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons @@ -355,15 +355,20 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) # Jforce = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) -regularization_list = np.zeros(len(coils)) * regularization_rect(aa, bb) -regularization_list2 = np.zeros(len(coils_TF)) * regularization_rect(a, b) +regularization_list = np.ones(len(coils)) * regularization_rect(aa, bb) +regularization_list2 = np.ones(len(coils_TF)) * regularization_rect(a, b) # Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jforce = MixedSquaredMeanForce(coils, coils_TF) Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) # Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) -Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + + +Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils_TF)]) + +# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + # Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] @@ -371,8 +376,8 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): JF = Jf \ + CC_WEIGHT * Jccdist \ + CS_WEIGHT * Jcsdist \ - + LINK_WEIGHT * linkNum \ - + LENGTH_WEIGHT * Jlength + + LENGTH_WEIGHT * Jlength \ + + LINK_WEIGHT * linkNum if FORCE_WEIGHT.value > 0.0: JF += FORCE_WEIGHT.value * Jforce #\ diff --git a/examples/3_Advanced/QH_reactorscale_nodipoles.py b/examples/3_Advanced/QH_reactorscale_nodipoles.py new file mode 100644 index 000000000..53d98e2b2 --- /dev/null +++ b/examples/3_Advanced/QH_reactorscale_nodipoles.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python +r""" +""" + +import os +import shutil +from pathlib import Path +import time +import numpy as np +from scipy.optimize import minimize +from simsopt.field import BiotSavart, Current, coils_via_symmetries +# from simsopt.field import CoilCoilNetForces, CoilCoilNetTorques, \ +# TotalVacuumEnergy, CoilSelfNetForces, CoilCoilNetForces12, CoilCoilNetTorques12 +from simsopt.field import regularization_rect +from simsopt.field.force import MeanSquaredForce, coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \ + SquaredMeanForce, \ + MeanSquaredTorque, SquaredMeanTorque, LpCurveTorque, MixedSquaredMeanForce, MixedLpCurveForce +from simsopt.util import calculate_on_axis_B +from simsopt.geo import ( + CurveLength, CurveCurveDistance, + MeanSquaredCurvature, LpCurveCurvature, CurveSurfaceDistance, LinkingNumber, + SurfaceRZFourier, curves_to_vtk, create_equally_spaced_planar_curves, + create_planar_curves_between_two_toroidal_surfaces +) +from simsopt.objectives import Weight, SquaredFlux, QuadraticPenalty +from simsopt.util import in_github_actions +import cProfile +import re + +t1 = time.time() + +# Number of Fourier modes describing each Cartesian component of each coil: +order = 0 + +# File for the desired boundary magnetic surface: +TEST_DIR = (Path(__file__).parent / ".." / ".." / "tests" / "test_files").resolve() +input_name = 'input.LandremanPaul2021_QH_reactorScale_lowres' +filename = TEST_DIR / input_name + +# Initialize the boundary magnetic surface: +range_param = "half period" +nphi = 32 +ntheta = 32 +poff = 2.0 +coff = 3.0 +s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta) +s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) +s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) + +# Make the inner and outer surfaces by extending the plasma surface +s_inner.extend_via_normal(poff) +s_outer.extend_via_normal(poff + coff) + +qphi = nphi * 2 +qtheta = ntheta * 2 +quadpoints_phi = np.linspace(0, 1, qphi, endpoint=True) +quadpoints_theta = np.linspace(0, 1, qtheta, endpoint=True) + +# Make high resolution, full torus version of the plasma boundary for plotting +s_plot = SurfaceRZFourier.from_vmec_input( + filename, + quadpoints_phi=quadpoints_phi, + quadpoints_theta=quadpoints_theta +) + +### Initialize some TF coils +def initialize_coils_QH(TEST_DIR, s): + """ + Initializes coils for each of the target configurations that are + used for permanent magnet optimization. + + Args: + config_flag: String denoting the stellarator configuration + being initialized. + TEST_DIR: String denoting where to find the input files. + out_dir: Path or string for the output directory for saved files. + s: plasma boundary surface. + Returns: + base_curves: List of CurveXYZ class objects. + curves: List of Curve class objects. + coils: List of Coil class objects. + """ + from simsopt.geo import create_equally_spaced_curves + from simsopt.field import Current, Coil, coils_via_symmetries + from simsopt.geo import curves_to_vtk + + # generate planar TF coils + ncoils = 2 + R0 = s.get_rc(0, 0) * 1 + R1 = s.get_rc(1, 0) * 4 + order = 8 + + from simsopt.mhd.vmec import Vmec + vmec_file = 'wout_LandremanPaul2021_QH_reactorScale_lowres_reference.nc' + total_current = Vmec(TEST_DIR / vmec_file).external_current() / (2 * s.nfp) / 1.4 + print('Total current = ', total_current) + + # Only need Jax flag for CurvePlanarFourier class + base_curves = create_equally_spaced_curves( + ncoils, s.nfp, stellsym=True, + R0=R0, R1=R1, order=order, numquadpoints=256, + jax_flag=True, + ) + + base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)] + # base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils)] + # base_currents[0].fix_all() + + total_current = Current(total_current) + total_current.fix_all() + base_currents += [total_current - sum(base_currents)] + coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) + # for c in coils: + # c.current.fix_all() + # c.curve.fix_all() + + # Initialize the coil curves and save the data to vtk + curves = [c.curve for c in coils] + currents = [c.current.get_value() for c in coils] + return base_curves, curves, coils, base_currents + +# initialize the coils +base_curves_TF, curves_TF, coils_TF, currents_TF = initialize_coils_QH(TEST_DIR, s) +num_TF_unique_coils = len(coils_TF) // 4 +base_coils_TF = coils_TF[:num_TF_unique_coils] +currents_TF = np.array([coil.current.get_value() for coil in coils_TF]) + +# # Set up BiotSavart fields +bs_TF = BiotSavart(coils_TF) + +# # Calculate average, approximate on-axis B field strength +calculate_on_axis_B(bs_TF, s) + +# wire cross section for the TF coils is a square 20 cm x 20 cm +# Only need this if make self forces and TVE nonzero in the objective! +a = 0.2 +b = 0.2 +nturns = 100 +nturns_TF = 200 + +def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): + contig = np.ascontiguousarray + forces = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + torques = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + for i, c in enumerate(coils): + aprime = aprimes[i] + bprime = bprimes[i] + # print(np.shape(bs._coils), np.shape(coils)) + # B_other = BiotSavart([cc for j, cc in enumerate(bs._coils) if i != j]).set_points(c.curve.gamma()).B() + # print(np.shape(bs._coils)) + # B_other = bs.set_points(c.curve.gamma()).B() + # print(B_other) + # exit() + forces[i, :-1, :] = coil_force(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + # print(i, forces[i, :-1, :]) + # bs._coils = coils + torques[i, :-1, :] = coil_torque(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + + forces[:, -1, :] = forces[:, 0, :] + torques[:, -1, :] = torques[:, 0, :] + forces = forces.reshape(-1, 3) + torques = torques.reshape(-1, 3) + point_data = {"Pointwise_Forces": (contig(forces[:, 0]), contig(forces[:, 1]), contig(forces[:, 2])), + "Pointwise_Torques": (contig(torques[:, 0]), contig(torques[:, 1]), contig(torques[:, 2]))} + return point_data + +btot = bs_TF +calculate_on_axis_B(btot, s) +btot.set_points(s.gamma().reshape((-1, 3))) +# a_list = np.hstack((np.ones(len(coils)) * aa, np.ones(len(coils_TF)) * a)) +# b_list = np.hstack((np.ones(len(coils)) * bb, np.ones(len(coils_TF)) * b)) +# base_a_list = np.hstack((np.ones(len(base_coils)) * aa, np.ones(len(base_coils_TF)) * a)) +# base_b_list = np.hstack((np.ones(len(base_coils)) * bb, np.ones(len(base_coils_TF)) * b)) + +LENGTH_WEIGHT = Weight(0.001) +LENGTH_TARGET = 80 +LINK_WEIGHT = 1e3 +CC_THRESHOLD = 0.8 +CC_WEIGHT = 1e1 +CS_THRESHOLD = 1.5 +CS_WEIGHT = 1e2 +# Weight for the Coil Coil forces term +FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(1e-22) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# Directory for output +OUT_DIR = ("./QH_nodipoles/") +if os.path.exists(OUT_DIR): + shutil.rmtree(OUT_DIR) +os.makedirs(OUT_DIR, exist_ok=True) + +curves_to_vtk( + curves_TF, + OUT_DIR + "curves_TF_0", + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=currents_TF, + NetForces=coil_net_forces(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF) +) +# Force and Torque calculations spawn a bunch of spurious BiotSavart child objects -- erase them! +for c in (coils_TF): + c._children = set() + +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init_DA", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Repeat for whole B field +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Define the individual terms objective function: +Jf = SquaredFlux(s, btot) +# Separate length penalties on the dipole coils and the TF coils +# since they have very different sizes +# Jls = [CurveLength(c) for c in base_curves] +Jls_TF = [CurveLength(c) for c in base_curves_TF] +Jlength = QuadraticPenalty(sum(Jls_TF), LENGTH_TARGET, "max") + +# coil-coil and coil-plasma distances should be between all coils + +### Jcc below removed the dipoles! +Jccdist = CurveCurveDistance(curves_TF, CC_THRESHOLD, num_basecurves=len(coils_TF)) +Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) +Jcsdist2 = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) + +# While the coil array is not moving around, they cannot +# interlink. +linkNum = LinkingNumber(curves_TF) + +##### Note need coils_TF + coils below!!!!!!! +# Jforce2 = sum([LpCurveForce(c, coils_TF, +# regularization=regularization_rect(base_a_list[i], base_b_list[i]), +# p=2, threshold=1e8) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) +# Jforce = sum([SquaredMeanForce(c, coils_TF) for c in (base_coils + base_coils_TF)]) + +# regularization_list = np.zeros(len(coils)) * regularization_rect(aa, bb) +# regularization_list2 = np.zeros(len(coils_TF)) * regularization_rect(a, b) +# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jforce = MixedSquaredMeanForce(coils, coils_TF) +Jforce = sum([LpCurveForce(c, coils_TF, regularization_rect(a, b), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils_TF)]) +# Jforce2 = sum([SquaredMeanForce(c, coils_TF) for c in (base_coils + base_coils_TF)]) +# Jtorque = sum([LpCurveTorque(c, coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 40) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jtorque2 = sum([SquaredMeanTorque(c, coils_TF) for c in (base_coils + base_coils_TF)]) + +# Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jtorque = [SquaredMeanTorque(c, coils_TF) for c in (base_coils + base_coils_TF)] + +JF = Jf \ + + CC_WEIGHT * Jccdist \ + + CS_WEIGHT * Jcsdist \ + + LINK_WEIGHT * linkNum \ + + LENGTH_WEIGHT * Jlength + +if FORCE_WEIGHT.value > 0.0: + JF += FORCE_WEIGHT.value * Jforce #\ + +# if FORCE_WEIGHT2.value > 0.0: +# JF += FORCE_WEIGHT2.value * Jforce2 #\ + +# if TORQUE_WEIGHT.value > 0.0: +# JF += TORQUE_WEIGHT * Jtorque + +# if TORQUE_WEIGHT2.value > 0.0: +# JF += TORQUE_WEIGHT2 * Jtorque2 + # + TVE_WEIGHT * Jtve + # + SF_WEIGHT * Jsf + # + CURRENTS_WEIGHT * DipoleCurrentsObj + # + CURVATURE_WEIGHT * sum(Jcs_TF) \ + # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs_TF) \ +# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \ + # + CURVATURE_WEIGHT * sum(Jcs) \ + +# We don't have a general interface in SIMSOPT for optimisation problems that +# are not in least-squares form, so we write a little wrapper function that we +# pass directly to scipy.optimize.minimize + +import pstats, io +from pstats import SortKey +# print(btot.ancestors,len(btot.ancestors)) +# print(JF.ancestors,len(JF.ancestors)) + + +def fun(dofs): + JF.x = dofs + # pr = cProfile.Profile() + # pr.enable() + J = JF.J() + grad = JF.dJ() + # pr.disable() + # sio = io.StringIO() + # sortby = SortKey.CUMULATIVE + # ps = pstats.Stats(pr, stream=sio).sort_stats(sortby) + # ps.print_stats(20) + # print(sio.getvalue()) + # exit() + jf = Jf.J() + length_val = LENGTH_WEIGHT.value * Jlength.J() + cc_val = CC_WEIGHT * Jccdist.J() + cs_val = CS_WEIGHT * Jcsdist.J() + link_val = LINK_WEIGHT * linkNum.J() + forces_val = FORCE_WEIGHT.value * Jforce.J() + # forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() + # torques_val = TORQUE_WEIGHT.value * Jtorque.J() + # torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() + BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) + BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + ) / np.mean(btot.AbsB()) + outstr = f"J={J:.1e}, Jf={jf:.1e}, ⟨B·n⟩={BdotN:.1e}, ⟨B·n⟩/⟨B⟩={BdotN_over_B:.1e}" + valuestr = f"J={J:.2e}, Jf={jf:.2e}" + cl_string = ", ".join([f"{J.J():.1f}" for J in Jls_TF]) + outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls_TF):.2f}" + valuestr += f", LenObj={length_val:.2e}" + valuestr += f", ccObj={cc_val:.2e}" + valuestr += f", csObj={cs_val:.2e}" + valuestr += f", Lk1Obj={link_val:.2e}" + valuestr += f", forceObj={forces_val:.2e}" + # valuestr += f", forceObj2={forces_val2:.2e}" + # valuestr += f", torqueObj={torques_val:.2e}" + # valuestr += f", torqueObj2={torques_val2:.2e}" + outstr += f", F={Jforce.J():.2e}" + # outstr += f", Fnet={Jforce2.J():.2e}" + # outstr += f", T={Jtorque.J():.2e}" + # outstr += f", Tnet={Jtorque2.J():.2e}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" + outstr += f", Link Number = {linkNum.J()}" + outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" + print(outstr) + print(valuestr) + return J, grad + + +print(""" +################################################################################ +### Perform a Taylor test ###################################################### +################################################################################ +""") +f = fun +dofs = JF.x +np.random.seed(1) +h = np.random.uniform(size=dofs.shape) +J0, dJ0 = f(dofs) +dJh = sum(dJ0 * h) +for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: + t1 = time.time() + J1, _ = f(dofs + eps*h) + J2, _ = f(dofs - eps*h) + t2 = time.time() + print("err", (J1-J2)/(2*eps) - dJh) + +print(""" +################################################################################ +### Run the optimisation ####################################################### +################################################################################ +""") + + +print('Timing calls: ') +t1 = time.time() +Jf.J() +t2 = time.time() +print('Jf time = ', t2 - t1, ' s') +t1 = time.time() +Jf.dJ() +t2 = time.time() +print('dJf time = ', t2 - t1, ' s') +t1 = time.time() +Jccdist.J() +Jccdist.dJ() +t2 = time.time() +print('Jcc time = ', t2 - t1, ' s') +t1 = time.time() +Jcsdist.J() +Jcsdist.dJ() +t2 = time.time() +print('Jcs time = ', t2 - t1, ' s') +t1 = time.time() +linkNum.J() +linkNum.dJ() +t2 = time.time() +print('linkNum time = ', t2 - t1, ' s') +t1 = time.time() +Jlength.J() +Jlength.dJ() +t2 = time.time() +print('sum(Jls_TF) time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.J() +# t2 = time.time() +# print('Jforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.dJ() +# t2 = time.time() +# print('dJforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.J() +# t2 = time.time() +# print('Jforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.dJ() +# t2 = time.time() +# print('dJforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.J() +# t2 = time.time() +# print('Jtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.dJ() +# t2 = time.time() +# print('dJtorques time = ', t2 - t1, ' s') + +n_saves = 1 +MAXITER = 400 +for i in range(1, n_saves + 1): + print('Iteration ' + str(i) + ' / ' + str(n_saves)) + res = minimize(fun, dofs, jac=True, method='L-BFGS-B', + options={'maxiter': MAXITER, 'maxcor': 400}, tol=1e-15) + # dofs = res.x + + curves_to_vtk( + [c.curve for c in bs_TF.coils], + OUT_DIR + "curves_TF_{0:d}".format(i), + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=[c.current.get_value() for c in bs_TF.coils], + NetForces=coil_net_forces(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + ) + + btot.set_points(s_plot.gamma().reshape((-1, 3))) + pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_{0:d}".format(i), extra_data=pointData) + + pointData = {"B_N / B": (np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2 + ) / np.linalg.norm(btot.B().reshape(qphi, qtheta, 3), axis=-1))[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_normalizedBn_{0:d}".format(i), extra_data=pointData) + + btot.set_points(s.gamma().reshape((-1, 3))) + calculate_on_axis_B(btot, s) + # LENGTH_WEIGHT *= 0.01 + # JF = Jf \ + # + CC_WEIGHT * Jccdist \ + # + CS_WEIGHT * Jcsdist \ + # + LINK_WEIGHT * linkNum \ + # + LINK_WEIGHT2 * linkNum2 \ + # + LENGTH_WEIGHT * sum(Jls_TF) + + +t2 = time.time() +print('Total time = ', t2 - t1) +btot.save(OUT_DIR + "biot_savart_optimized_QH.json") +print(OUT_DIR) + diff --git a/examples/3_Advanced/QH_reactorscale_notfixed.py b/examples/3_Advanced/QH_reactorscale_notfixed.py new file mode 100644 index 000000000..05429ba7e --- /dev/null +++ b/examples/3_Advanced/QH_reactorscale_notfixed.py @@ -0,0 +1,643 @@ +#!/usr/bin/env python +r""" +""" + +import os +import shutil +from pathlib import Path +import time +import numpy as np +from scipy.optimize import minimize +from simsopt.field import BiotSavart, Current, coils_via_symmetries +# from simsopt.field import CoilCoilNetForces, CoilCoilNetTorques, \ +# TotalVacuumEnergy, CoilSelfNetForces, CoilCoilNetForces12, CoilCoilNetTorques12 +from simsopt.field import regularization_rect +from simsopt.field.force import MeanSquaredForce, coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \ + SquaredMeanForce, \ + MeanSquaredTorque, SquaredMeanTorque, LpCurveTorque, MixedSquaredMeanForce, MixedLpCurveForce +from simsopt.util import calculate_on_axis_B +from simsopt.geo import ( + CurveLength, CurveCurveDistance, + MeanSquaredCurvature, LpCurveCurvature, CurveSurfaceDistance, LinkingNumber, + SurfaceRZFourier, curves_to_vtk, create_equally_spaced_planar_curves, + create_planar_curves_between_two_toroidal_surfaces +) +from simsopt.objectives import Weight, SquaredFlux, QuadraticPenalty +from simsopt.util import in_github_actions +import cProfile +import re + +t1 = time.time() + +# Number of Fourier modes describing each Cartesian component of each coil: +order = 0 + +# File for the desired boundary magnetic surface: +TEST_DIR = (Path(__file__).parent / ".." / ".." / "tests" / "test_files").resolve() +input_name = 'input.LandremanPaul2021_QH_reactorScale_lowres' +filename = TEST_DIR / input_name + +# Initialize the boundary magnetic surface: +range_param = "half period" +nphi = 32 +ntheta = 32 +poff = 2.0 +coff = 3.0 +s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta) +s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) +s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) + +# Make the inner and outer surfaces by extending the plasma surface +s_inner.extend_via_normal(poff) +s_outer.extend_via_normal(poff + coff) + +qphi = nphi * 2 +qtheta = ntheta * 2 +quadpoints_phi = np.linspace(0, 1, qphi, endpoint=True) +quadpoints_theta = np.linspace(0, 1, qtheta, endpoint=True) + +# Make high resolution, full torus version of the plasma boundary for plotting +s_plot = SurfaceRZFourier.from_vmec_input( + filename, + quadpoints_phi=quadpoints_phi, + quadpoints_theta=quadpoints_theta +) + +### Initialize some TF coils +def initialize_coils_QH(TEST_DIR, s): + """ + Initializes coils for each of the target configurations that are + used for permanent magnet optimization. + + Args: + config_flag: String denoting the stellarator configuration + being initialized. + TEST_DIR: String denoting where to find the input files. + out_dir: Path or string for the output directory for saved files. + s: plasma boundary surface. + Returns: + base_curves: List of CurveXYZ class objects. + curves: List of Curve class objects. + coils: List of Coil class objects. + """ + from simsopt.geo import create_equally_spaced_curves + from simsopt.field import Current, Coil, coils_via_symmetries + from simsopt.geo import curves_to_vtk + + # generate planar TF coils + ncoils = 2 + R0 = s.get_rc(0, 0) * 1 + R1 = s.get_rc(1, 0) * 4 + order = 8 + + from simsopt.mhd.vmec import Vmec + vmec_file = 'wout_LandremanPaul2021_QH_reactorScale_lowres_reference.nc' + total_current = Vmec(TEST_DIR / vmec_file).external_current() / (2 * s.nfp) / 1.4 + print('Total current = ', total_current) + + # Only need Jax flag for CurvePlanarFourier class + base_curves = create_equally_spaced_curves( + ncoils, s.nfp, stellsym=True, + R0=R0, R1=R1, order=order, numquadpoints=256, + jax_flag=True, + ) + + base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)] + # base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils)] + # base_currents[0].fix_all() + + total_current = Current(total_current) + total_current.fix_all() + base_currents += [total_current - sum(base_currents)] + coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) + # for c in coils: + # c.current.fix_all() + # c.curve.fix_all() + + # Initialize the coil curves and save the data to vtk + curves = [c.curve for c in coils] + currents = [c.current.get_value() for c in coils] + return base_curves, curves, coils, base_currents + +# initialize the coils +base_curves_TF, curves_TF, coils_TF, currents_TF = initialize_coils_QH(TEST_DIR, s) +num_TF_unique_coils = len(coils_TF) // 4 +base_coils_TF = coils_TF[:num_TF_unique_coils] +currents_TF = np.array([coil.current.get_value() for coil in coils_TF]) + +# # Set up BiotSavart fields +bs_TF = BiotSavart(coils_TF) + +# # Calculate average, approximate on-axis B field strength +calculate_on_axis_B(bs_TF, s) + +# wire cross section for the TF coils is a square 20 cm x 20 cm +# Only need this if make self forces and TVE nonzero in the objective! +a = 0.2 +b = 0.2 +nturns = 100 +nturns_TF = 200 + +# wire cross section for the dipole coils should be more like 5 cm x 5 cm +aa = 0.05 +bb = 0.05 + +Nx = 6 +Ny = Nx +Nz = Nx +# Create the initial coils: +base_curves, all_curves = create_planar_curves_between_two_toroidal_surfaces( + s, s_inner, s_outer, Nx, Ny, Nz, order=order, coil_coil_flag=True, jax_flag=True, + # numquadpoints=10 # Defaults is (order + 1) * 40 so this halves it +) +import warnings + +keep_inds = [] +for ii in range(len(base_curves)): + counter = 0 + for i in range(base_curves[0].gamma().shape[0]): + eps = 0.05 + for j in range(len(base_curves_TF)): + for k in range(base_curves_TF[j].gamma().shape[0]): + dij = np.sqrt(np.sum((base_curves[ii].gamma()[i, :] - base_curves_TF[j].gamma()[k, :]) ** 2)) + conflict_bool = (dij < (1.0 + eps) * base_curves[0].x[0]) + if conflict_bool: + print('bad indices = ', i, j, dij, base_curves[0].x[0]) + warnings.warn( + 'There is a PSC coil initialized such that it is within a radius' + 'of a TF coil. Deleting these PSCs now.') + counter += 1 + break + if counter == 0: + keep_inds.append(ii) + +print(keep_inds) +base_curves = np.array(base_curves)[keep_inds] + +ncoils = len(base_curves) +print('Ncoils = ', ncoils) +coil_normals = np.zeros((ncoils, 3)) +plasma_points = s.gamma().reshape(-1, 3) +plasma_unitnormals = s.unitnormal().reshape(-1, 3) +for i in range(ncoils): + point = (base_curves[i].get_dofs()[-3:]) + dists = np.sum((point - plasma_points) ** 2, axis=-1) + min_ind = np.argmin(dists) + coil_normals[i, :] = plasma_unitnormals[min_ind, :] + # coil_normals[i, :] = (plasma_points[min_ind, :] - point) +coil_normals = coil_normals / np.linalg.norm(coil_normals, axis=-1)[:, None] +# alphas = np.arctan2( +# -coil_normals[:, 1], +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +# deltas = np.arcsin(coil_normals[:, 0] / \ +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +alphas = np.arcsin( + -coil_normals[:, 1], + ) +deltas = np.arctan2(coil_normals[:, 0], coil_normals[:, 2]) +for i in range(len(base_curves)): + alpha2 = alphas[i] / 2.0 + delta2 = deltas[i] / 2.0 + calpha2 = np.cos(alpha2) + salpha2 = np.sin(alpha2) + cdelta2 = np.cos(delta2) + sdelta2 = np.sin(delta2) + base_curves[i].set('x' + str(2 * order + 1), calpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 2), salpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 3), calpha2 * sdelta2) + base_curves[i].set('x' + str(2 * order + 4), -salpha2 * sdelta2) + # Fix orientations of each coil + base_curves[i].fix('x' + str(2 * order + 1)) + base_curves[i].fix('x' + str(2 * order + 2)) + base_curves[i].fix('x' + str(2 * order + 3)) + base_curves[i].fix('x' + str(2 * order + 4)) + + # Fix shape of each coil + for j in range(2 * order + 1): + base_curves[i].fix('x' + str(j)) + # Fix center points of each coil + base_curves[i].fix('x' + str(2 * order + 5)) + base_curves[i].fix('x' + str(2 * order + 6)) + base_curves[i].fix('x' + str(2 * order + 7)) +base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] + +coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) +base_coils = coils[:ncoils] + +def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): + contig = np.ascontiguousarray + forces = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + torques = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + for i, c in enumerate(coils): + aprime = aprimes[i] + bprime = bprimes[i] + # print(np.shape(bs._coils), np.shape(coils)) + # B_other = BiotSavart([cc for j, cc in enumerate(bs._coils) if i != j]).set_points(c.curve.gamma()).B() + # print(np.shape(bs._coils)) + # B_other = bs.set_points(c.curve.gamma()).B() + # print(B_other) + # exit() + forces[i, :-1, :] = coil_force(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + # print(i, forces[i, :-1, :]) + # bs._coils = coils + torques[i, :-1, :] = coil_torque(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + + forces[:, -1, :] = forces[:, 0, :] + torques[:, -1, :] = torques[:, 0, :] + forces = forces.reshape(-1, 3) + torques = torques.reshape(-1, 3) + point_data = {"Pointwise_Forces": (contig(forces[:, 0]), contig(forces[:, 1]), contig(forces[:, 2])), + "Pointwise_Torques": (contig(torques[:, 0]), contig(torques[:, 1]), contig(torques[:, 2]))} + return point_data + +bs = BiotSavart(coils) # + coils_TF) +btot = bs + bs_TF +calculate_on_axis_B(btot, s) +btot.set_points(s.gamma().reshape((-1, 3))) +bs.set_points(s.gamma().reshape((-1, 3))) +curves = [c.curve for c in coils] +currents = [c.current.get_value() for c in coils] +a_list = np.hstack((np.ones(len(coils)) * aa, np.ones(len(coils_TF)) * a)) +b_list = np.hstack((np.ones(len(coils)) * bb, np.ones(len(coils_TF)) * b)) +base_a_list = np.hstack((np.ones(len(base_coils)) * aa, np.ones(len(base_coils_TF)) * a)) +base_b_list = np.hstack((np.ones(len(base_coils)) * bb, np.ones(len(base_coils_TF)) * b)) + +LENGTH_WEIGHT = Weight(0.001) +LENGTH_TARGET = 80 +LINK_WEIGHT = 1e3 +CC_THRESHOLD = 0.8 +CC_WEIGHT = 1e1 +CS_THRESHOLD = 1.5 +CS_WEIGHT = 1e2 +# Weight for the Coil Coil forces term +FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(4e-27) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# Directory for output +OUT_DIR = ("./QH_debug/") +if os.path.exists(OUT_DIR): + shutil.rmtree(OUT_DIR) +os.makedirs(OUT_DIR, exist_ok=True) + +curves_to_vtk( + curves_TF, + OUT_DIR + "curves_TF_0", + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils + coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=currents_TF, + NetForces=coil_net_forces(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF) +) +curves_to_vtk( + curves, + OUT_DIR + "curves_0", + close=True, + extra_point_data=pointData_forces_torques(coils, coils + coils_TF, np.ones(len(coils)) * aa, np.ones(len(coils)) * bb, np.ones(len(coils)) * nturns), + I=currents, + NetForces=coil_net_forces(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns), + NetTorques=coil_net_torques(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns) +) +# Force and Torque calculations spawn a bunch of spurious BiotSavart child objects -- erase them! +for c in (coils + coils_TF): + c._children = set() + +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init_DA", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Repeat for whole B field +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Define the individual terms objective function: +Jf = SquaredFlux(s, btot) +# Separate length penalties on the dipole coils and the TF coils +# since they have very different sizes +# Jls = [CurveLength(c) for c in base_curves] +Jls_TF = [CurveLength(c) for c in base_curves_TF] +Jlength = QuadraticPenalty(sum(Jls_TF), LENGTH_TARGET, "max") + +# coil-coil and coil-plasma distances should be between all coils + +### Jcc below removed the dipoles! +Jccdist = CurveCurveDistance(curves_TF, CC_THRESHOLD, num_basecurves=len(coils_TF)) +Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) +Jcsdist2 = CurveSurfaceDistance(curves + curves_TF, s, CS_THRESHOLD) + +# While the coil array is not moving around, they cannot +# interlink. +linkNum = LinkingNumber(curves + curves_TF, downsample=2) + +##### Note need coils_TF + coils below!!!!!!! +# Jforce2 = sum([LpCurveForce(c, coils + coils_TF, +# regularization=regularization_rect(base_a_list[i], base_b_list[i]), +# p=2, threshold=1e8) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) +# Jforce = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + +regularization_list = np.ones(len(coils)) * regularization_rect(aa, bb) +regularization_list2 = np.ones(len(coils_TF)) * regularization_rect(a, b) +# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jforce = MixedSquaredMeanForce(coils, coils_TF) +# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2, p=4, threshold=4e5 * 100, downsample=4) +Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100, downsample=4) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) +# Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) + + +# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils_TF)]) + +# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + + +# Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] + +JF = Jf \ + + CC_WEIGHT * Jccdist \ + + CS_WEIGHT * Jcsdist \ + + LENGTH_WEIGHT * Jlength \ + + LINK_WEIGHT * linkNum + +if FORCE_WEIGHT.value > 0.0: + JF += FORCE_WEIGHT.value * Jforce #\ + +# if FORCE_WEIGHT2.value > 0.0: +# JF += FORCE_WEIGHT2.value * Jforce2 #\ + +# if TORQUE_WEIGHT.value > 0.0: +# JF += TORQUE_WEIGHT * Jtorque + +# if TORQUE_WEIGHT2.value > 0.0: +# JF += TORQUE_WEIGHT2 * Jtorque2 + # + TVE_WEIGHT * Jtve + # + SF_WEIGHT * Jsf + # + CURRENTS_WEIGHT * DipoleCurrentsObj + # + CURVATURE_WEIGHT * sum(Jcs_TF) \ + # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs_TF) \ +# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \ + # + CURVATURE_WEIGHT * sum(Jcs) \ + +# We don't have a general interface in SIMSOPT for optimisation problems that +# are not in least-squares form, so we write a little wrapper function that we +# pass directly to scipy.optimize.minimize + + +print('Timing calls: ') +t1 = time.time() +Jf.J() +t2 = time.time() +print('Jf time = ', t2 - t1, ' s') +t1 = time.time() +Jf.dJ() +t2 = time.time() +print('dJf time = ', t2 - t1, ' s') +t1 = time.time() +Jccdist.J() +Jccdist.dJ() +t2 = time.time() +print('Jcc time = ', t2 - t1, ' s') +t1 = time.time() +Jcsdist.J() +Jcsdist.dJ() +t2 = time.time() +print('Jcs time = ', t2 - t1, ' s') +t1 = time.time() +linkNum.J() +linkNum.dJ() +t2 = time.time() +print('linkNum time = ', t2 - t1, ' s') +t1 = time.time() +Jlength.J() +Jlength.dJ() +t2 = time.time() +print('sum(Jls_TF) time = ', t2 - t1, ' s') +t1 = time.time() +print(Jforce.J()) +print(Jforce.dJ()) +t2 = time.time() +print('Jforce time = ', t2 - t1, ' s') +t1 = time.time() +JF.J() +JF.dJ() +t2 = time.time() +print('JF time = ', t2 - t1, ' s') +import pstats, io +from pstats import SortKey +# print(btot.ancestors,len(btot.ancestors)) +# print(JF.ancestors,len(JF.ancestors)) + + +def fun(dofs): + pr = cProfile.Profile() + pr.enable() + JF.x = dofs + J = JF.J() + grad = JF.dJ() + jf = Jf.J() + length_val = LENGTH_WEIGHT.value * Jlength.J() + cc_val = CC_WEIGHT * Jccdist.J() + cs_val = CS_WEIGHT * Jcsdist.J() + link_val = LINK_WEIGHT * linkNum.J() + forces_val = FORCE_WEIGHT.value * Jforce.J() + # Jforce3.dJ() + # forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() + # torques_val = TORQUE_WEIGHT.value * Jtorque.J() + # torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() + BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) + BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + ) / np.mean(btot.AbsB()) + outstr = f"J={J:.1e}, Jf={jf:.1e}, ⟨B·n⟩={BdotN:.1e}, ⟨B·n⟩/⟨B⟩={BdotN_over_B:.1e}" + valuestr = f"J={J:.2e}, Jf={jf:.2e}" + cl_string = ", ".join([f"{J.J():.1f}" for J in Jls_TF]) + outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls_TF):.2f}" + valuestr += f", LenObj={length_val:.2e}" + valuestr += f", ccObj={cc_val:.2e}" + valuestr += f", csObj={cs_val:.2e}" + valuestr += f", Lk1Obj={link_val:.2e}" + valuestr += f", forceObj={forces_val:.2e}" + # valuestr += f", forceObj2={forces_val2:.2e}" + # valuestr += f", torqueObj={torques_val:.2e}" + # valuestr += f", torqueObj2={torques_val2:.2e}" + outstr += f", F={Jforce.J():.2e}" + # outstr += f", Fnet={Jforce2.J():.2e}" + # outstr += f", T={Jtorque.J():.2e}" + # outstr += f", Tnet={Jtorque2.J():.2e}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" + outstr += f", Link Number = {linkNum.J()}" + outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" + print(coils[0]._children, coils_TF[0]._children, JF._children, Jforce._children) + print(outstr) + print(valuestr) + pr.disable() + sio = io.StringIO() + sortby = SortKey.CUMULATIVE + ps = pstats.Stats(pr, stream=sio).sort_stats(sortby) + ps.print_stats(20) + print(sio.getvalue()) + # exit() + return J, grad + + +print(""" +################################################################################ +### Perform a Taylor test ###################################################### +################################################################################ +""") +f = fun +dofs = JF.x +np.random.seed(1) +h = np.random.uniform(size=dofs.shape) + +print(""" +Calling f now +""") +J0, dJ0 = f(dofs) +dJh = sum(dJ0 * h) +for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: + t1 = time.time() + J1, _ = f(dofs + eps*h) + J2, _ = f(dofs - eps*h) + t2 = time.time() + print("err", (J1-J2)/(2*eps) - dJh) + +print(""" +################################################################################ +### Run the optimisation ####################################################### +################################################################################ +""") + + +print('Timing calls: ') +t1 = time.time() +Jf.J() +t2 = time.time() +print('Jf time = ', t2 - t1, ' s') +t1 = time.time() +Jf.dJ() +t2 = time.time() +print('dJf time = ', t2 - t1, ' s') +t1 = time.time() +Jccdist.J() +Jccdist.dJ() +t2 = time.time() +print('Jcc time = ', t2 - t1, ' s') +t1 = time.time() +Jcsdist.J() +Jcsdist.dJ() +t2 = time.time() +print('Jcs time = ', t2 - t1, ' s') +t1 = time.time() +linkNum.J() +linkNum.dJ() +t2 = time.time() +print('linkNum time = ', t2 - t1, ' s') +t1 = time.time() +Jlength.J() +Jlength.dJ() +t2 = time.time() +print('sum(Jls_TF) time = ', t2 - t1, ' s') +t1 = time.time() +print(Jforce.J()) +print(Jforce.dJ()) +t2 = time.time() +print('Jforce time = ', t2 - t1, ' s') +t1 = time.time() +JF.J() +JF.dJ() +t2 = time.time() +print('JF time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.J() +# t2 = time.time() +# print('Jforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.dJ() +# t2 = time.time() +# print('dJforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.J() +# t2 = time.time() +# print('Jforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.dJ() +# t2 = time.time() +# print('dJforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.J() +# t2 = time.time() +# print('Jtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.dJ() +# t2 = time.time() +# print('dJtorques time = ', t2 - t1, ' s') + +n_saves = 1 +MAXITER = 200 +for i in range(1, n_saves + 1): + print('Iteration ' + str(i) + ' / ' + str(n_saves)) + res = minimize(fun, dofs, jac=True, method='L-BFGS-B', + options={'maxiter': MAXITER, 'maxcor': 200}, tol=1e-15) + # dofs = res.x + + dipole_currents = [c.current.get_value() for c in bs.coils] + curves_to_vtk( + [c.curve for c in bs.coils], + OUT_DIR + "curves_{0:d}".format(i), + close=True, + extra_point_data=pointData_forces_torques(coils, coils + coils_TF, np.ones(len(coils)) * aa, np.ones(len(coils)) * bb, np.ones(len(coils)) * nturns), + I=dipole_currents, + NetForces=coil_net_forces(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns), + NetTorques=coil_net_torques(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns), + ) + curves_to_vtk( + [c.curve for c in bs_TF.coils], + OUT_DIR + "curves_TF_{0:d}".format(i), + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils + coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=[c.current.get_value() for c in bs_TF.coils], + NetForces=coil_net_forces(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + ) + + btot.set_points(s_plot.gamma().reshape((-1, 3))) + pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_{0:d}".format(i), extra_data=pointData) + + pointData = {"B_N / B": (np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2 + ) / np.linalg.norm(btot.B().reshape(qphi, qtheta, 3), axis=-1))[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_normalizedBn_{0:d}".format(i), extra_data=pointData) + + btot.set_points(s.gamma().reshape((-1, 3))) + print('Max I = ', np.max(np.abs(dipole_currents))) + print('Min I = ', np.min(np.abs(dipole_currents))) + calculate_on_axis_B(btot, s) + # LENGTH_WEIGHT *= 0.01 + # JF = Jf \ + # + CC_WEIGHT * Jccdist \ + # + CS_WEIGHT * Jcsdist \ + # + LINK_WEIGHT * linkNum \ + # + LINK_WEIGHT2 * linkNum2 \ + # + LENGTH_WEIGHT * sum(Jls_TF) + + +t2 = time.time() +print('Total time = ', t2 - t1) +btot.save(OUT_DIR + "biot_savart_optimized_QH.json") +print(OUT_DIR) + diff --git a/src/simsopt/field/biotsavart.py b/src/simsopt/field/biotsavart.py index 8cfe5571d..0cf2001ec 100644 --- a/src/simsopt/field/biotsavart.py +++ b/src/simsopt/field/biotsavart.py @@ -34,7 +34,7 @@ def __init__(self, coils): self._coils = coils sopp.BiotSavart.__init__(self, coils) MagneticField.__init__(self, depends_on=coils) - self.B_vjp_jax = jit(lambda v: self.B_vjp_pure(v)) + # self.B_vjp_jax = jit(lambda v: self.B_vjp_pure(v)) def dB_by_dcoilcurrents(self, compute_derivatives=0): points = self.get_points_cart_ref() diff --git a/src/simsopt/field/force.py b/src/simsopt/field/force.py index f544484f9..b8bc6c286 100644 --- a/src/simsopt/field/force.py +++ b/src/simsopt/field/force.py @@ -39,9 +39,13 @@ def coil_force(coil, allcoils, regularization, nturns=1): ### Line below seems to be the issue -- all these BiotSavart objects seem to stick ### around and not to go out of scope after these calls! - mutual_field = BiotSavart(mutual_coils).set_points(coil.curve.gamma()).B() - mutualforce = np.cross(coil.current.get_value() * tangent, mutual_field) + mutual_field = BiotSavart(mutual_coils).set_points(coil.curve.gamma()) + B_mutual = mutual_field.B() + mutualforce = np.cross(coil.current.get_value() * tangent, B_mutual) selfforce = self_force(coil, regularization) + mutual_field._children = set() + for c in mutual_coils: + c._children = set() return (selfforce + mutualforce) / nturns def coil_net_forces(coils, allcoils, regularization, nturns=None): @@ -96,8 +100,8 @@ def self_force_rect(coil, a, b): return self_force(coil, regularization_rect(a, b)) -@jit -def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold): +# @jit +def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold, downsample): r"""Pure function for minimizing the Lorentz force on a coil. The function is @@ -108,13 +112,17 @@ def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regulari where :math:`\vec{F}` is the Lorentz force, :math:`F_0` is a threshold force, and :math:`\ell` is arclength along the coil. """ - + B_mutual = B_mutual[::downsample, :] + gamma = gamma[::downsample, :] + gammadash = gammadash[::downsample, :] + gammadashdash = gammadashdash[::downsample, :] + quadpoints = quadpoints[::downsample] B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] tangent = gammadash / gammadash_norm force = jnp.cross(current * tangent, B_self + B_mutual) force_norm = jnp.linalg.norm(force, axis=1)[:, None] - return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm)) * (1. / p) + return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm)) * (1. / p) / jnp.shape(gamma)[0] class LpCurveForce(Optimizable): @@ -129,81 +137,107 @@ class LpCurveForce(Optimizable): and :math:`\ell` is arclength along the coil. """ - def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0): + def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0, downsample=1): self.coil = coil self.allcoils = allcoils self.othercoils = [c for c in allcoils if c is not coil] - self.biotsavart = BiotSavart(self.othercoils) + # self.biotsavart = BiotSavart(self.othercoils) quadpoints = self.coil.curve.quadpoints + self.downsample = downsample + args = {"static_argnums": (5,)} self.J_jax = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold, downsample), + **args ) self.dJ_dgamma = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dgammadash = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dgammadashdash = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dcurrent = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dB_mutual = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) super().__init__(depends_on=allcoils) def J(self): - self.biotsavart.set_points(self.coil.curve.gamma()) + biotsavart = BiotSavart(self.othercoils) + biotsavart.set_points(self.coil.curve.gamma()) + # biotsavart._children = set() args = [ self.coil.curve.gamma(), self.coil.curve.gammadash(), self.coil.curve.gammadashdash(), self.coil.current.get_value(), - self.biotsavart.B() + biotsavart.B(), + self.downsample ] + # biotsavart._children = set() + for c in self.othercoils: + c._children = set() + c.curve._children = set() + c.current._children = set() return self.J_jax(*args) @derivative_dec def dJ(self): - self.biotsavart.set_points(self.coil.curve.gamma()) + biotsavart = BiotSavart(self.othercoils) + biotsavart.set_points(self.coil.curve.gamma()) + # biotsavart._children = set() args = [ self.coil.curve.gamma(), self.coil.curve.gammadash(), self.coil.curve.gammadashdash(), self.coil.current.get_value(), - self.biotsavart.B() + biotsavart.B(), + self.downsample ] dJ_dB = self.dJ_dB_mutual(*args) - dB_dX = self.biotsavart.dB_by_dX() + dB_dX = biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) - return ( + dJ = ( self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + self.biotsavart.B_vjp(dJ_dB) + + biotsavart.B_vjp(dJ_dB) ) + # biotsavart._children = set() + for c in self.othercoils: + c._children = set() + c.curve._children = set() + c.current._children = set() + + return dJ return_fn_map = {'J': J, 'dJ': dJ} @@ -543,21 +577,38 @@ def dJ(self): return_fn_map = {'J': J, 'dJ': dJ} - -@jit +# @jit def mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, quadpoints, quadpoints2, - currents, currents2, regularizations, regularizations2, p, threshold): + currents, currents2, regularizations, regularizations2, p, threshold, + downsample=1, + ): r""" """ - B_self = [B_regularized_pure(gammas[i], gammadashs[i], gammadashdashs[i], quadpoints, currents[i], regularizations[i]) for i in range(jnp.shape(gammas)[0])] - B_self2 = [B_regularized_pure(gammas2[i], gammadashs2[i], gammadashdashs2[i], quadpoints2, currents2[i], regularizations2[i]) for i in range(jnp.shape(gammas2)[0])] + # print(jnp.shape(quadpoints), jnp.shape(gammas), jnp.shape(gammadashs), jnp.shape(gammadashdashs)) + # print(jnp.shape(quadpoints2), jnp.shape(gammas2), jnp.shape(gammadashs2), jnp.shape(gammadashdashs2)) + quadpoints = quadpoints[::downsample] + gammas = gammas[:, ::downsample, :] + gammadashs = gammadashs[:, ::downsample, :] + gammadashdashs = gammadashdashs[:, ::downsample, :] + + quadpoints2 = quadpoints2[::downsample] + gammas2 = gammas2[:, ::downsample, :] + gammadashs2 = gammadashs2[:, ::downsample, :] + gammadashdashs2 = gammadashdashs2[:, ::downsample, :] + # print(jnp.shape(quadpoints), jnp.shape(gammas), jnp.shape(gammadashs), jnp.shape(gammadashdashs)) + # print(jnp.shape(quadpoints2), jnp.shape(gammas2), jnp.shape(gammadashs2), jnp.shape(gammadashdashs2)) + + B_self = jnp.array([B_regularized_pure(gammas[i], gammadashs[i], gammadashdashs[i], quadpoints, + currents[i], regularizations[i]) for i in range(jnp.shape(gammas)[0])]) + B_self2 = jnp.array([B_regularized_pure(gammas2[i], gammadashs2[i], gammadashdashs2[i], quadpoints2, + currents2[i], regularizations2[i]) for i in range(jnp.shape(gammas2)[0])]) gammadash_norms = jnp.linalg.norm(gammadashs, axis=-1)[:, :, None] tangents = gammadashs / gammadash_norms gammadash_norms2 = jnp.linalg.norm(gammadashs2, axis=-1)[:, :, None] tangents2 = gammadashs2 / gammadash_norms2 - selfforce = jnp.array([jnp.cross(currents[i] * tangents[i], B_self[i]) for i in range(jnp.shape(gammas)[0])]) - selfforce2 = jnp.array([jnp.cross(currents2[i] * tangents2[i], B_self2[i]) for i in range(jnp.shape(gammas2)[0])]) + # selfforce = jnp.array([jnp.cross(currents[i] * tangents[i], B_self[i]) for i in range(jnp.shape(gammas)[0])]) + # selfforce2 = jnp.array([jnp.cross(currents2[i] * tangents2[i], B_self2[i]) for i in range(jnp.shape(gammas2)[0])]) eps = 1e-10 # small number to avoid blow up in the denominator when i = j r_ij = gammas[:, None, :, None, :] - gammas[None, :, None, :, :] # Note, do not use the i = j indices @@ -565,7 +616,8 @@ def mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs ### Note that need to do dl1 x dl2 x r12 here instead of just (dl1 * dl2)r12 # because these are not equivalent expressions if we are squaring the pointwise forces # before integration over coil i! - cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) + # cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs[None, :, None, :, :], r_ij) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents[:, None] * currents[None, :] Ii_Ij = Ii_Ij.at[:, :].add(-jnp.diag(jnp.diag(Ii_Ij))) @@ -573,29 +625,32 @@ def mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs # repeat with gamma, gamma2 r_ij = gammas[:, None, :, None, :] - gammas2[None, :, None, :, :] # Note, do not use the i = j indices - cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) + # cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs2[None, :, None, :, :], r_ij) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents[:, None] * currents2[None, :] F += jnp.sum(Ii_Ij[:, :, None, None] * jnp.sum(cross_prod / rij_norm3[:, :, :, :, None], axis=3), axis=1) / jnp.shape(gammas2)[1] - force_norm = jnp.linalg.norm(F * 1e-7 + selfforce, axis=-1) - summ = jnp.sum(jnp.maximum(force_norm[:, :, None] - threshold, 0) ** p * gammadash_norms) + force_norm = jnp.linalg.norm(jnp.cross(tangents, F * 1e-7 + currents[:, None, None] * B_self), axis=-1) + summ = jnp.sum(jnp.maximum(force_norm[:, :, None] - threshold, 0) ** p * gammadash_norms) / jnp.shape(gammas)[1] # repeat with gamma2, gamma r_ij = gammas2[:, None, :, None, :] - gammas[None, :, None, :, :] # Note, do not use the i = j indices - cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs[None, :, None, :, :], r_ij) + # cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents2[:, None] * currents[None, :] F = jnp.sum(Ii_Ij[:, :, None, None] * jnp.sum(cross_prod / rij_norm3[:, :, :, :, None], axis=3), axis=1) / jnp.shape(gammas)[1] # repeat with gamma2, gamma2 r_ij = gammas2[:, None, :, None, :] - gammas2[None, :, None, :, :] # Note, do not use the i = j indices - cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs2[None, :, None, :, :], r_ij) + # cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents2[:, None] * currents2[None, :] Ii_Ij = Ii_Ij.at[:, :].add(-jnp.diag(jnp.diag(Ii_Ij))) F += jnp.sum(Ii_Ij[:, :, None, None] * jnp.sum(cross_prod / rij_norm3[:, :, :, :, None], axis=3), axis=1) / jnp.shape(gammas2)[1] - force_norm2 = jnp.linalg.norm(F * 1e-7 + selfforce2, axis=-1) - summ += jnp.sum(jnp.maximum(force_norm2[:, :, None] - threshold, 0) ** p * gammadash_norms2) + force_norm2 = jnp.linalg.norm(jnp.cross(tangents2, F * 1e-7 + currents2[:, None, None] * B_self2), axis=-1) + summ += jnp.sum(jnp.maximum(force_norm2[:, :, None] - threshold, 0) ** p * gammadash_norms2) / jnp.shape(gammas2)[1] return summ * (1 / p) class MixedLpCurveForce(Optimizable): @@ -610,56 +665,66 @@ class MixedLpCurveForce(Optimizable): along the coil. """ - def __init__(self, allcoils, allcoils2, regularizations, regularizations2, p=2.0, threshold=0.0): + def __init__(self, allcoils, allcoils2, regularizations, regularizations2, p=2.0, threshold=0.0, downsample=1): self.allcoils = allcoils self.allcoils2 = allcoils2 quadpoints = self.allcoils[0].curve.quadpoints quadpoints2 = self.allcoils2[0].curve.quadpoints + self.downsample = downsample + args = {"static_argnums": (8,)} self.J_jax = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, quadpoints, quadpoints2, - currents, currents2, regularizations, regularizations2, p, threshold) + currents, currents2, regularizations, regularizations2, p, threshold, downsample), + **args ) self.dJ_dgamma = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=0)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=0)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgamma2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=1)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=1)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadash = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=2)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=2)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadash2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=3)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=3)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadashdash = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=4)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=4)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadashdash2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=5)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=5)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dcurrent = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=6)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=6)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) - self.dJ_dcurrent2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=7)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=7)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) super().__init__(depends_on=(allcoils + allcoils2)) @@ -676,6 +741,7 @@ def J(self): jnp.array([c.curve.gammadashdash() for c in self.allcoils2]), jnp.array([c.current.get_value() for c in self.allcoils]), jnp.array([c.current.get_value() for c in self.allcoils2]), + self.downsample ] return self.J_jax(*args) @@ -692,6 +758,7 @@ def dJ(self): jnp.array([c.curve.gammadashdash() for c in self.allcoils2]), jnp.array([c.current.get_value() for c in self.allcoils]), jnp.array([c.current.get_value() for c in self.allcoils2]), + self.downsample ] dJ_dgamma = self.dJ_dgamma(*args) dJ_dgammadash = self.dJ_dgammadash(*args) @@ -709,7 +776,7 @@ def dJ(self): + sum([c.current.vjp(jnp.asarray([dJ_dcurrent[i]])) for i, c in enumerate(self.allcoils)]) + sum([c.curve.dgamma_by_dcoeff_vjp(dJ_dgamma2[i]) for i, c in enumerate(self.allcoils2)]) + sum([c.curve.dgammadash_by_dcoeff_vjp(dJ_dgammadash2[i]) for i, c in enumerate(self.allcoils2)]) - + sum([c.curve.dgammadashdash_by_dcoeff_vjp(dJ_dgammadashdash[i]) for i, c in enumerate(self.allcoils2)]) + + sum([c.curve.dgammadashdash_by_dcoeff_vjp(dJ_dgammadashdash2[i]) for i, c in enumerate(self.allcoils2)]) + sum([c.current.vjp(jnp.asarray([dJ_dcurrent2[i]])) for i, c in enumerate(self.allcoils2)]) ) diff --git a/src/simsopt/field/selffield.py b/src/simsopt/field/selffield.py index 46b9a003c..ab6fff0c0 100755 --- a/src/simsopt/field/selffield.py +++ b/src/simsopt/field/selffield.py @@ -55,8 +55,8 @@ def B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization): def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization): # The factors of 2π in the next few lines come from the fact that simsopt # uses a curve parameter that goes up to 1 rather than 2π. - phi = quadpoints * 2 * jnp.pi - rc = gamma + phi = quadpoints * 2 * jnp.pi + rc = gamma rc_prime = gammadash / 2 / jnp.pi rc_prime_prime = gammadashdash / 4 / jnp.pi**2 n_quad = phi.shape[0] @@ -69,6 +69,11 @@ def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, reg second_term = jnp.cross(rc_prime_prime, rc_prime)[:, None, :] * ( 0.5 * cos_fac / (cos_fac * jnp.sum(rc_prime * rc_prime, axis=1)[:, None] + regularization)**1.5)[:, :, None] integral_term = dphi * jnp.sum(first_term + second_term, 1) + # print(jnp.any(jnp.isnan(first_term))) + # print(jnp.any(jnp.isnan(second_term))) + # print(jnp.any(jnp.isnan(integral_term))) + # print(jnp.any(jnp.isnan(analytic_term))) + # print(jnp.max(jnp.abs(first_term))) # print(jnp.max(jnp.abs(second_term))) # print(jnp.max(jnp.abs(integral_term))) diff --git a/src/simsopt/geo/curve.py b/src/simsopt/geo/curve.py index 6fa97b6c4..9bc79341e 100644 --- a/src/simsopt/geo/curve.py +++ b/src/simsopt/geo/curve.py @@ -470,6 +470,15 @@ def gamma_impl(self, gamma, quadpoints): def incremental_arclength_pure(self, dofs): gammadash = self.gammadash_jax(dofs) return jnp.linalg.norm(gammadash, axis=1) + + @property + def qps(self): + return self.quadpoints + + @qps.setter + def qps(self, new_quadpoints): + self.quadpoints = new_quadpoints + self.numquadpoints = len(new_quadpoints) def incremental_arclength(self): return self.incremental_arclength_jax(self.get_dofs()) diff --git a/src/simsopt/geo/curveplanarfourier.py b/src/simsopt/geo/curveplanarfourier.py index 3c71c28f2..7c5d16143 100644 --- a/src/simsopt/geo/curveplanarfourier.py +++ b/src/simsopt/geo/curveplanarfourier.py @@ -159,6 +159,8 @@ def center(self, gamma, gammadash): barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi return barycenter + def set_quadpoints(self, quadpoints): + self.quadpoints = quadpoints # class JaxCurvePlanarFourier(JaxCurve): diff --git a/src/simsopt/geo/jit.py b/src/simsopt/geo/jit.py index 3f888f2f9..796fc150e 100644 --- a/src/simsopt/geo/jit.py +++ b/src/simsopt/geo/jit.py @@ -3,8 +3,8 @@ from .config import parameters -def jit(fun): +def jit(fun, **args): if parameters['jit']: - return jaxjit(fun) + return jaxjit(fun, **args) else: return fun diff --git a/tests/field/test_selffieldforces.py b/tests/field/test_selffieldforces.py index 342f18c7f..49265558e 100644 --- a/tests/field/test_selffieldforces.py +++ b/tests/field/test_selffieldforces.py @@ -211,7 +211,7 @@ def test_force_objectives(self): gammadash_norm = np.linalg.norm(coils[0].curve.gammadash(), axis=1) force_norm = np.linalg.norm(coil_force(coils[0], coils, regularization), axis=1) print("force_norm mean:", np.mean(force_norm), "max:", np.max(force_norm)) - objective_alt = (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) + objective_alt = (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) / np.shape(gammadash_norm)[0] print("objective:", objective, "objective_alt:", objective_alt, "diff:", objective - objective_alt) np.testing.assert_allclose(objective, objective_alt) @@ -260,22 +260,36 @@ def test_force_objectives(self): # # Test MixedLpCurveForce objective = 0.0 + objective2 = 0.0 objective_alt = 0.0 for i in range(len(coils)): objective += float(LpCurveForce(coils[i], coils, regularization, p=p, threshold=threshold).J()) + objective2 += float(LpCurveForce(coils[i], coils, regularization, p=p, threshold=threshold, downsample=2).J()) force_norm = np.linalg.norm(coil_force(coils[i], coils, regularization), axis=1) gammadash_norm = np.linalg.norm(coils[i].curve.gammadash(), axis=1) - objective_alt += (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) + objective_alt += (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) / np.shape(gammadash_norm)[0] regularization_list = np.ones(len(coils)) * regularization - objective_mixed = float(MixedLpCurveForce(coils[0:1], coils[1:], regularization_list[0:1], regularization_list[1:], p=p, threshold=threshold).J()) + objective_mixed = float(MixedLpCurveForce(coils[0:1], coils[1:], regularization_list[0:1], + regularization_list[1:], p=p, threshold=threshold).J()) print("objective:", objective, "objective_alt:", objective_alt, "diff:", objective - objective_alt) np.testing.assert_allclose(objective, objective_alt) + + print("objective:", objective, "objective2:", objective2, "diff:", objective - objective2) + np.testing.assert_allclose(objective, objective2, rtol=1e-2) + print("objective:", objective, "objective_mixed:", objective_mixed, "diff:", objective - objective_mixed) np.testing.assert_allclose(objective, objective_mixed) + objective_mixed = float(MixedLpCurveForce(coils[0:1], coils[1:], + regularization_list[0:1], regularization_list[1:], p=p, + threshold=threshold, downsample=2).J()) + + print("objective:", objective, "objective_mixed:", objective_mixed, "diff:", objective - objective_mixed) + np.testing.assert_allclose(objective, objective_mixed, rtol=1e-2) + # Test SquaredMeanTorque # Scramble the orientations so the torques are nonzero @@ -459,6 +473,7 @@ def objectives_time_test(self): def test_update_points(self): """Confirm that Biot-Savart evaluation points are updated when the curve shapes change.""" + from simsopt.field import BiotSavart nfp = 4 ncoils = 3 I = 1.7e4 @@ -472,7 +487,8 @@ def test_update_points(self): objective = objective_class(coils[0], coils, regularization) old_objective_value = objective.J() - old_biot_savart_points = objective.biotsavart.get_points_cart() + biotsavart = BiotSavart(objective.othercoils) + old_biot_savart_points = biotsavart.get_points_cart() # A deterministic random shift to the coil dofs: shift = np.array([-0.06797948, -0.0808704 , -0.02680599, -0.02775893, -0.0325402 , @@ -488,7 +504,8 @@ def test_update_points(self): objective.x = objective.x + shift assert abs(objective.J() - old_objective_value) > 1e-6 - new_biot_savart_points = objective.biotsavart.get_points_cart() + biotsavart = BiotSavart(objective.othercoils) + new_biot_savart_points = biotsavart.get_points_cart() assert not np.allclose(old_biot_savart_points, new_biot_savart_points) # Objective2 is created directly at the new points after they are moved: objective2 = objective_class(coils[0], coils, regularization)