diff --git a/examples/3_Advanced/QA_single_TFcoil.py b/examples/3_Advanced/QA_single_TFcoil.py index bfad035fc..783aa5164 100644 --- a/examples/3_Advanced/QA_single_TFcoil.py +++ b/examples/3_Advanced/QA_single_TFcoil.py @@ -7,6 +7,7 @@ from pathlib import Path import time import numpy as np +import warnings from scipy.optimize import minimize from simsopt.field import BiotSavart, Current, coils_via_symmetries # from simsopt.field import CoilCoilNetForces, CoilCoilNetTorques, \ @@ -41,8 +42,8 @@ range_param = "half period" nphi = 32 ntheta = 32 -poff = 1.5 -coff = 2.0 +poff = 1.75 +coff = 2.5 s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta) s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) @@ -88,7 +89,7 @@ def initialize_coils_QA(TEST_DIR, s): ncoils = 1 R0 = s.get_rc(0, 0) * 2 R1 = s.get_rc(1, 0) * 10 - order = 4 + order = 10 from simsopt.mhd.vmec import Vmec vmec_file = 'wout_LandremanPaul2021_QA_reactorScale_lowres_reference.nc' @@ -141,7 +142,7 @@ def initialize_coils_QA(TEST_DIR, s): aa = 0.05 bb = 0.05 -Nx = 3 +Nx = 6 Ny = Nx Nz = Nx # Create the initial coils: @@ -149,25 +150,98 @@ def initialize_coils_QA(TEST_DIR, s): s, s_inner, s_outer, Nx, Ny, Nz, order=order, coil_coil_flag=True, jax_flag=True, # numquadpoints=10 # Defaults is (order + 1) * 40 so this halves it ) +# ncoils = len(base_curves) +# print('Ncoils = ', ncoils) +# for i in range(len(base_curves)): +# # base_curves[i].set('x' + str(2 * order + 1), np.random.rand(1) - 0.5) +# # base_curves[i].set('x' + str(2 * order + 2), np.random.rand(1) - 0.5) +# # base_curves[i].set('x' + str(2 * order + 3), np.random.rand(1) - 0.5) +# # base_curves[i].set('x' + str(2 * order + 4), np.random.rand(1) - 0.5) + +# # Fix shape of each coil +# for j in range(2 * order + 1): +# base_curves[i].fix('x' + str(j)) +# # Fix center points of each coil +# # base_curves[i].fix('x' + str(2 * order + 5)) +# # base_curves[i].fix('x' + str(2 * order + 6)) +# # base_curves[i].fix('x' + str(2 * order + 7)) +# base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] +# # Fix currents in each coil +# # for i in range(ncoils): +# # base_currents[i].fix_all() + +# coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) +# base_coils = coils[:ncoils] + + +keep_inds = [] +for ii in range(len(base_curves)): + counter = 0 + for i in range(base_curves[0].gamma().shape[0]): + eps = 0.05 + for j in range(len(base_curves_TF)): + for k in range(base_curves_TF[j].gamma().shape[0]): + dij = np.sqrt(np.sum((base_curves[ii].gamma()[i, :] - base_curves_TF[j].gamma()[k, :]) ** 2)) + conflict_bool = (dij < (1.0 + eps) * base_curves[0].x[0]) + if conflict_bool: + print('bad indices = ', i, j, dij, base_curves[0].x[0]) + warnings.warn( + 'There is a PSC coil initialized such that it is within a radius' + 'of a TF coil. Deleting these PSCs now.') + counter += 1 + break + if counter == 0: + keep_inds.append(ii) + +print(keep_inds) +base_curves = np.array(base_curves)[keep_inds] + ncoils = len(base_curves) print('Ncoils = ', ncoils) +coil_normals = np.zeros((ncoils, 3)) +plasma_points = s.gamma().reshape(-1, 3) +plasma_unitnormals = s.unitnormal().reshape(-1, 3) +for i in range(ncoils): + point = (base_curves[i].get_dofs()[-3:]) + dists = np.sum((point - plasma_points) ** 2, axis=-1) + min_ind = np.argmin(dists) + coil_normals[i, :] = plasma_unitnormals[min_ind, :] + # coil_normals[i, :] = (plasma_points[min_ind, :] - point) +coil_normals = coil_normals / np.linalg.norm(coil_normals, axis=-1)[:, None] +# alphas = np.arctan2( +# -coil_normals[:, 1], +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +# deltas = np.arcsin(coil_normals[:, 0] / \ +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +alphas = np.arcsin( + -coil_normals[:, 1], + ) +deltas = np.arctan2(coil_normals[:, 0], coil_normals[:, 2]) for i in range(len(base_curves)): - # base_curves[i].set('x' + str(2 * order + 1), np.random.rand(1) - 0.5) - # base_curves[i].set('x' + str(2 * order + 2), np.random.rand(1) - 0.5) - # base_curves[i].set('x' + str(2 * order + 3), np.random.rand(1) - 0.5) - # base_curves[i].set('x' + str(2 * order + 4), np.random.rand(1) - 0.5) + alpha2 = alphas[i] / 2.0 + delta2 = deltas[i] / 2.0 + calpha2 = np.cos(alpha2) + salpha2 = np.sin(alpha2) + cdelta2 = np.cos(delta2) + sdelta2 = np.sin(delta2) + base_curves[i].set('x' + str(2 * order + 1), calpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 2), salpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 3), calpha2 * sdelta2) + base_curves[i].set('x' + str(2 * order + 4), -salpha2 * sdelta2) + # Fix orientations of each coil + base_curves[i].fix('x' + str(2 * order + 1)) + base_curves[i].fix('x' + str(2 * order + 2)) + base_curves[i].fix('x' + str(2 * order + 3)) + base_curves[i].fix('x' + str(2 * order + 4)) # Fix shape of each coil for j in range(2 * order + 1): base_curves[i].fix('x' + str(j)) # Fix center points of each coil - # base_curves[i].fix('x' + str(2 * order + 5)) - # base_curves[i].fix('x' + str(2 * order + 6)) - # base_curves[i].fix('x' + str(2 * order + 7)) + base_curves[i].fix('x' + str(2 * order + 5)) + base_curves[i].fix('x' + str(2 * order + 6)) + base_curves[i].fix('x' + str(2 * order + 7)) base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] -# Fix currents in each coil -# for i in range(ncoils): -# base_currents[i].fix_all() coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) base_coils = coils[:ncoils] @@ -210,24 +284,31 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): base_a_list = np.hstack((np.ones(len(base_coils)) * aa, np.ones(len(base_coils_TF)) * a)) base_b_list = np.hstack((np.ones(len(base_coils)) * bb, np.ones(len(base_coils_TF)) * b)) -LENGTH_WEIGHT = Weight(0.01) -LENGTH_TARGET = 90 # 90 works very well... can we get it down? -LINK_WEIGHT = 1e3 +LENGTH_WEIGHT = Weight(0.001) +LENGTH_TARGET = 80 # 90 works very well... can we get it down? +LINK_WEIGHT = 1e2 CC_THRESHOLD = 0.8 -CC_WEIGHT = 1e2 +CC_WEIGHT = 1e1 CS_THRESHOLD = 1.5 -CS_WEIGHT = 1e2 +CS_WEIGHT = 1e1 # Weight for the Coil Coil forces term -FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT = Weight(1e-37) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -TORQUE_WEIGHT = Weight(1e-18) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons + +# FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons # Directory for output OUT_DIR = ("./QA_singleTF_n{:d}_p{:.2e}_c{:.2e}_lw{:.2e}_lt{:.2e}_lkw{:.2e}" + \ "_cct{:.2e}_ccw{:.2e}_cst{:.2e}_csw{:.2e}_fw{:.2e}_fww{:2e}_tw{:.2e}/").format( ncoils, poff, coff, LENGTH_WEIGHT.value, LENGTH_TARGET, LINK_WEIGHT, CC_THRESHOLD, CC_WEIGHT, CS_THRESHOLD, CS_WEIGHT, FORCE_WEIGHT.value, FORCE_WEIGHT2.value, - TORQUE_WEIGHT.value) + TORQUE_WEIGHT.value, + TORQUE_WEIGHT2.value) if os.path.exists(OUT_DIR): shutil.rmtree(OUT_DIR) os.makedirs(OUT_DIR, exist_ok=True) @@ -281,7 +362,8 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # coil-coil and coil-plasma distances should be between all coils Jccdist = CurveCurveDistance(curves + curves_TF, CC_THRESHOLD, num_basecurves=len(coils + coils_TF)) -Jcsdist = CurveSurfaceDistance(curves + curves_TF, s, CS_THRESHOLD) +Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) +Jcsdist2 = CurveSurfaceDistance(curves + curves_TF, s, CS_THRESHOLD) # While the coil array is not moving around, they cannot # interlink. @@ -298,9 +380,13 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): regularization_list2 = np.zeros(len(coils_TF)) * regularization_rect(a, b) # Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jforce = MixedSquaredMeanForce(coils, coils_TF) -Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=2e5 * 40) for i, c in enumerate(base_coils + base_coils_TF)]) + +###### NOTE JFORCE BELOW ONLY DOING THE DIPOLE COILS!!!! +Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100) for i, c in enumerate(base_coils)]) Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for i, c in enumerate(base_coils + base_coils_TF)]) -Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i])) for i, c in enumerate(base_coils + base_coils_TF)]) +Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) +Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) # Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] @@ -317,17 +403,11 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): JF += FORCE_WEIGHT2.value * Jforce2 #\ if TORQUE_WEIGHT.value > 0.0: - JF += TORQUE_WEIGHT.value * Jtorque #\ - # + FORCE_WEIGHT2.value * Jforce2 - # + TORQUE_WEIGHT * Jtorque - # + TVE_WEIGHT * Jtve - # + SF_WEIGHT * Jsf - # + CURRENTS_WEIGHT * DipoleCurrentsObj - # + CURVATURE_WEIGHT * sum(Jcs_TF) \ - # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs_TF) \ -# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \ - # + CURVATURE_WEIGHT * sum(Jcs) \ + JF += TORQUE_WEIGHT * Jtorque +if TORQUE_WEIGHT2.value > 0.0: + JF += TORQUE_WEIGHT2 * Jtorque2 + # We don't have a general interface in SIMSOPT for optimisation problems that # are not in least-squares form, so we write a little wrapper function that we # pass directly to scipy.optimize.minimize @@ -359,6 +439,7 @@ def fun(dofs): forces_val = FORCE_WEIGHT.value * Jforce.J() forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() torques_val = TORQUE_WEIGHT.value * Jtorque.J() + torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) ) / np.mean(btot.AbsB()) @@ -373,10 +454,12 @@ def fun(dofs): valuestr += f", forceObj={forces_val:.2e}" valuestr += f", forceObj2={forces_val2:.2e}" valuestr += f", torqueObj={torques_val:.2e}" + valuestr += f", torqueObj2={torques_val2:.2e}" outstr += f", F={Jforce.J():.2e}" - outstr += f", Fpointwise={Jforce2.J():.2e}" + outstr += f", Fnet={Jforce2.J():.2e}" outstr += f", T={Jtorque.J():.2e}" - outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}" + outstr += f", Tnet={Jtorque2.J():.2e}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" outstr += f", Link Number = {linkNum.J()}" outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" print(outstr) @@ -438,33 +521,33 @@ def fun(dofs): Jlength.dJ() t2 = time.time() print('sum(Jls_TF) time = ', t2 - t1, ' s') -t1 = time.time() -Jforce.J() -t2 = time.time() -print('Jforces time = ', t2 - t1, ' s') -t1 = time.time() -Jforce.dJ() -t2 = time.time() -print('dJforces time = ', t2 - t1, ' s') -t1 = time.time() -Jforce2.J() -t2 = time.time() -print('Jforces2 time = ', t2 - t1, ' s') -t1 = time.time() -Jforce2.dJ() -t2 = time.time() -print('dJforces2 time = ', t2 - t1, ' s') -t1 = time.time() -Jtorque.J() -t2 = time.time() -print('Jtorques time = ', t2 - t1, ' s') -t1 = time.time() -Jtorque.dJ() -t2 = time.time() -print('dJtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.J() +# t2 = time.time() +# print('Jforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.dJ() +# t2 = time.time() +# print('dJforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.J() +# t2 = time.time() +# print('Jforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.dJ() +# t2 = time.time() +# print('dJforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.J() +# t2 = time.time() +# print('Jtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.dJ() +# t2 = time.time() +# print('dJtorques time = ', t2 - t1, ' s') n_saves = 1 -MAXITER = 400 +MAXITER = 200 for i in range(1, n_saves + 1): print('Iteration ' + str(i) + ' / ' + str(n_saves)) res = minimize(fun, dofs, jac=True, method='L-BFGS-B', diff --git a/examples/3_Advanced/QH_reactorscale_DA.py b/examples/3_Advanced/QH_reactorscale_DA.py index 2bf0f8eb1..5e4950ada 100644 --- a/examples/3_Advanced/QH_reactorscale_DA.py +++ b/examples/3_Advanced/QH_reactorscale_DA.py @@ -88,7 +88,7 @@ def initialize_coils_QH(TEST_DIR, s): ncoils = 2 R0 = s.get_rc(0, 0) * 1 R1 = s.get_rc(1, 0) * 4 - order = 7 + order = 8 from simsopt.mhd.vmec import Vmec vmec_file = 'wout_LandremanPaul2021_QH_reactorScale_lowres_reference.nc' @@ -207,18 +207,18 @@ def initialize_coils_QH(TEST_DIR, s): base_curves[i].set('x' + str(2 * order + 3), calpha2 * sdelta2) base_curves[i].set('x' + str(2 * order + 4), -salpha2 * sdelta2) # Fix orientations of each coil - base_curves[i].fix('x' + str(2 * order + 1)) - base_curves[i].fix('x' + str(2 * order + 2)) - base_curves[i].fix('x' + str(2 * order + 3)) - base_curves[i].fix('x' + str(2 * order + 4)) + # base_curves[i].fix('x' + str(2 * order + 1)) + # base_curves[i].fix('x' + str(2 * order + 2)) + # base_curves[i].fix('x' + str(2 * order + 3)) + # base_curves[i].fix('x' + str(2 * order + 4)) # Fix shape of each coil for j in range(2 * order + 1): base_curves[i].fix('x' + str(j)) # Fix center points of each coil - base_curves[i].fix('x' + str(2 * order + 5)) - base_curves[i].fix('x' + str(2 * order + 6)) - base_curves[i].fix('x' + str(2 * order + 7)) + # base_curves[i].fix('x' + str(2 * order + 5)) + # base_curves[i].fix('x' + str(2 * order + 6)) + # base_curves[i].fix('x' + str(2 * order + 7)) base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) @@ -270,11 +270,11 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): CS_THRESHOLD = 1.5 CS_WEIGHT = 1e2 # Weight for the Coil Coil forces term -# FORCE_WEIGHT = Weight(1e-22) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons # FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -# TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons -FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(4e-27) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons @@ -355,15 +355,20 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): # Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) # Jforce = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) -regularization_list = np.zeros(len(coils)) * regularization_rect(aa, bb) -regularization_list2 = np.zeros(len(coils_TF)) * regularization_rect(a, b) +regularization_list = np.ones(len(coils)) * regularization_rect(aa, bb) +regularization_list2 = np.ones(len(coils_TF)) * regularization_rect(a, b) # Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jforce = MixedSquaredMeanForce(coils, coils_TF) Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) # Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) -Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + + +Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils_TF)]) + +# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + # Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] # Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] @@ -371,8 +376,8 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): JF = Jf \ + CC_WEIGHT * Jccdist \ + CS_WEIGHT * Jcsdist \ - + LINK_WEIGHT * linkNum \ - + LENGTH_WEIGHT * Jlength + + LENGTH_WEIGHT * Jlength \ + + LINK_WEIGHT * linkNum if FORCE_WEIGHT.value > 0.0: JF += FORCE_WEIGHT.value * Jforce #\ diff --git a/examples/3_Advanced/QH_reactorscale_nodipoles.py b/examples/3_Advanced/QH_reactorscale_nodipoles.py new file mode 100644 index 000000000..53d98e2b2 --- /dev/null +++ b/examples/3_Advanced/QH_reactorscale_nodipoles.py @@ -0,0 +1,469 @@ +#!/usr/bin/env python +r""" +""" + +import os +import shutil +from pathlib import Path +import time +import numpy as np +from scipy.optimize import minimize +from simsopt.field import BiotSavart, Current, coils_via_symmetries +# from simsopt.field import CoilCoilNetForces, CoilCoilNetTorques, \ +# TotalVacuumEnergy, CoilSelfNetForces, CoilCoilNetForces12, CoilCoilNetTorques12 +from simsopt.field import regularization_rect +from simsopt.field.force import MeanSquaredForce, coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \ + SquaredMeanForce, \ + MeanSquaredTorque, SquaredMeanTorque, LpCurveTorque, MixedSquaredMeanForce, MixedLpCurveForce +from simsopt.util import calculate_on_axis_B +from simsopt.geo import ( + CurveLength, CurveCurveDistance, + MeanSquaredCurvature, LpCurveCurvature, CurveSurfaceDistance, LinkingNumber, + SurfaceRZFourier, curves_to_vtk, create_equally_spaced_planar_curves, + create_planar_curves_between_two_toroidal_surfaces +) +from simsopt.objectives import Weight, SquaredFlux, QuadraticPenalty +from simsopt.util import in_github_actions +import cProfile +import re + +t1 = time.time() + +# Number of Fourier modes describing each Cartesian component of each coil: +order = 0 + +# File for the desired boundary magnetic surface: +TEST_DIR = (Path(__file__).parent / ".." / ".." / "tests" / "test_files").resolve() +input_name = 'input.LandremanPaul2021_QH_reactorScale_lowres' +filename = TEST_DIR / input_name + +# Initialize the boundary magnetic surface: +range_param = "half period" +nphi = 32 +ntheta = 32 +poff = 2.0 +coff = 3.0 +s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta) +s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) +s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) + +# Make the inner and outer surfaces by extending the plasma surface +s_inner.extend_via_normal(poff) +s_outer.extend_via_normal(poff + coff) + +qphi = nphi * 2 +qtheta = ntheta * 2 +quadpoints_phi = np.linspace(0, 1, qphi, endpoint=True) +quadpoints_theta = np.linspace(0, 1, qtheta, endpoint=True) + +# Make high resolution, full torus version of the plasma boundary for plotting +s_plot = SurfaceRZFourier.from_vmec_input( + filename, + quadpoints_phi=quadpoints_phi, + quadpoints_theta=quadpoints_theta +) + +### Initialize some TF coils +def initialize_coils_QH(TEST_DIR, s): + """ + Initializes coils for each of the target configurations that are + used for permanent magnet optimization. + + Args: + config_flag: String denoting the stellarator configuration + being initialized. + TEST_DIR: String denoting where to find the input files. + out_dir: Path or string for the output directory for saved files. + s: plasma boundary surface. + Returns: + base_curves: List of CurveXYZ class objects. + curves: List of Curve class objects. + coils: List of Coil class objects. + """ + from simsopt.geo import create_equally_spaced_curves + from simsopt.field import Current, Coil, coils_via_symmetries + from simsopt.geo import curves_to_vtk + + # generate planar TF coils + ncoils = 2 + R0 = s.get_rc(0, 0) * 1 + R1 = s.get_rc(1, 0) * 4 + order = 8 + + from simsopt.mhd.vmec import Vmec + vmec_file = 'wout_LandremanPaul2021_QH_reactorScale_lowres_reference.nc' + total_current = Vmec(TEST_DIR / vmec_file).external_current() / (2 * s.nfp) / 1.4 + print('Total current = ', total_current) + + # Only need Jax flag for CurvePlanarFourier class + base_curves = create_equally_spaced_curves( + ncoils, s.nfp, stellsym=True, + R0=R0, R1=R1, order=order, numquadpoints=256, + jax_flag=True, + ) + + base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)] + # base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils)] + # base_currents[0].fix_all() + + total_current = Current(total_current) + total_current.fix_all() + base_currents += [total_current - sum(base_currents)] + coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) + # for c in coils: + # c.current.fix_all() + # c.curve.fix_all() + + # Initialize the coil curves and save the data to vtk + curves = [c.curve for c in coils] + currents = [c.current.get_value() for c in coils] + return base_curves, curves, coils, base_currents + +# initialize the coils +base_curves_TF, curves_TF, coils_TF, currents_TF = initialize_coils_QH(TEST_DIR, s) +num_TF_unique_coils = len(coils_TF) // 4 +base_coils_TF = coils_TF[:num_TF_unique_coils] +currents_TF = np.array([coil.current.get_value() for coil in coils_TF]) + +# # Set up BiotSavart fields +bs_TF = BiotSavart(coils_TF) + +# # Calculate average, approximate on-axis B field strength +calculate_on_axis_B(bs_TF, s) + +# wire cross section for the TF coils is a square 20 cm x 20 cm +# Only need this if make self forces and TVE nonzero in the objective! +a = 0.2 +b = 0.2 +nturns = 100 +nturns_TF = 200 + +def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): + contig = np.ascontiguousarray + forces = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + torques = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + for i, c in enumerate(coils): + aprime = aprimes[i] + bprime = bprimes[i] + # print(np.shape(bs._coils), np.shape(coils)) + # B_other = BiotSavart([cc for j, cc in enumerate(bs._coils) if i != j]).set_points(c.curve.gamma()).B() + # print(np.shape(bs._coils)) + # B_other = bs.set_points(c.curve.gamma()).B() + # print(B_other) + # exit() + forces[i, :-1, :] = coil_force(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + # print(i, forces[i, :-1, :]) + # bs._coils = coils + torques[i, :-1, :] = coil_torque(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + + forces[:, -1, :] = forces[:, 0, :] + torques[:, -1, :] = torques[:, 0, :] + forces = forces.reshape(-1, 3) + torques = torques.reshape(-1, 3) + point_data = {"Pointwise_Forces": (contig(forces[:, 0]), contig(forces[:, 1]), contig(forces[:, 2])), + "Pointwise_Torques": (contig(torques[:, 0]), contig(torques[:, 1]), contig(torques[:, 2]))} + return point_data + +btot = bs_TF +calculate_on_axis_B(btot, s) +btot.set_points(s.gamma().reshape((-1, 3))) +# a_list = np.hstack((np.ones(len(coils)) * aa, np.ones(len(coils_TF)) * a)) +# b_list = np.hstack((np.ones(len(coils)) * bb, np.ones(len(coils_TF)) * b)) +# base_a_list = np.hstack((np.ones(len(base_coils)) * aa, np.ones(len(base_coils_TF)) * a)) +# base_b_list = np.hstack((np.ones(len(base_coils)) * bb, np.ones(len(base_coils_TF)) * b)) + +LENGTH_WEIGHT = Weight(0.001) +LENGTH_TARGET = 80 +LINK_WEIGHT = 1e3 +CC_THRESHOLD = 0.8 +CC_WEIGHT = 1e1 +CS_THRESHOLD = 1.5 +CS_WEIGHT = 1e2 +# Weight for the Coil Coil forces term +FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(1e-22) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# Directory for output +OUT_DIR = ("./QH_nodipoles/") +if os.path.exists(OUT_DIR): + shutil.rmtree(OUT_DIR) +os.makedirs(OUT_DIR, exist_ok=True) + +curves_to_vtk( + curves_TF, + OUT_DIR + "curves_TF_0", + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=currents_TF, + NetForces=coil_net_forces(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF) +) +# Force and Torque calculations spawn a bunch of spurious BiotSavart child objects -- erase them! +for c in (coils_TF): + c._children = set() + +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init_DA", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Repeat for whole B field +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Define the individual terms objective function: +Jf = SquaredFlux(s, btot) +# Separate length penalties on the dipole coils and the TF coils +# since they have very different sizes +# Jls = [CurveLength(c) for c in base_curves] +Jls_TF = [CurveLength(c) for c in base_curves_TF] +Jlength = QuadraticPenalty(sum(Jls_TF), LENGTH_TARGET, "max") + +# coil-coil and coil-plasma distances should be between all coils + +### Jcc below removed the dipoles! +Jccdist = CurveCurveDistance(curves_TF, CC_THRESHOLD, num_basecurves=len(coils_TF)) +Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) +Jcsdist2 = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) + +# While the coil array is not moving around, they cannot +# interlink. +linkNum = LinkingNumber(curves_TF) + +##### Note need coils_TF + coils below!!!!!!! +# Jforce2 = sum([LpCurveForce(c, coils_TF, +# regularization=regularization_rect(base_a_list[i], base_b_list[i]), +# p=2, threshold=1e8) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) +# Jforce = sum([SquaredMeanForce(c, coils_TF) for c in (base_coils + base_coils_TF)]) + +# regularization_list = np.zeros(len(coils)) * regularization_rect(aa, bb) +# regularization_list2 = np.zeros(len(coils_TF)) * regularization_rect(a, b) +# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jforce = MixedSquaredMeanForce(coils, coils_TF) +Jforce = sum([LpCurveForce(c, coils_TF, regularization_rect(a, b), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils_TF)]) +# Jforce2 = sum([SquaredMeanForce(c, coils_TF) for c in (base_coils + base_coils_TF)]) +# Jtorque = sum([LpCurveTorque(c, coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 40) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jtorque2 = sum([SquaredMeanTorque(c, coils_TF) for c in (base_coils + base_coils_TF)]) + +# Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jtorque = [SquaredMeanTorque(c, coils_TF) for c in (base_coils + base_coils_TF)] + +JF = Jf \ + + CC_WEIGHT * Jccdist \ + + CS_WEIGHT * Jcsdist \ + + LINK_WEIGHT * linkNum \ + + LENGTH_WEIGHT * Jlength + +if FORCE_WEIGHT.value > 0.0: + JF += FORCE_WEIGHT.value * Jforce #\ + +# if FORCE_WEIGHT2.value > 0.0: +# JF += FORCE_WEIGHT2.value * Jforce2 #\ + +# if TORQUE_WEIGHT.value > 0.0: +# JF += TORQUE_WEIGHT * Jtorque + +# if TORQUE_WEIGHT2.value > 0.0: +# JF += TORQUE_WEIGHT2 * Jtorque2 + # + TVE_WEIGHT * Jtve + # + SF_WEIGHT * Jsf + # + CURRENTS_WEIGHT * DipoleCurrentsObj + # + CURVATURE_WEIGHT * sum(Jcs_TF) \ + # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs_TF) \ +# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \ + # + CURVATURE_WEIGHT * sum(Jcs) \ + +# We don't have a general interface in SIMSOPT for optimisation problems that +# are not in least-squares form, so we write a little wrapper function that we +# pass directly to scipy.optimize.minimize + +import pstats, io +from pstats import SortKey +# print(btot.ancestors,len(btot.ancestors)) +# print(JF.ancestors,len(JF.ancestors)) + + +def fun(dofs): + JF.x = dofs + # pr = cProfile.Profile() + # pr.enable() + J = JF.J() + grad = JF.dJ() + # pr.disable() + # sio = io.StringIO() + # sortby = SortKey.CUMULATIVE + # ps = pstats.Stats(pr, stream=sio).sort_stats(sortby) + # ps.print_stats(20) + # print(sio.getvalue()) + # exit() + jf = Jf.J() + length_val = LENGTH_WEIGHT.value * Jlength.J() + cc_val = CC_WEIGHT * Jccdist.J() + cs_val = CS_WEIGHT * Jcsdist.J() + link_val = LINK_WEIGHT * linkNum.J() + forces_val = FORCE_WEIGHT.value * Jforce.J() + # forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() + # torques_val = TORQUE_WEIGHT.value * Jtorque.J() + # torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() + BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) + BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + ) / np.mean(btot.AbsB()) + outstr = f"J={J:.1e}, Jf={jf:.1e}, ⟨B·n⟩={BdotN:.1e}, ⟨B·n⟩/⟨B⟩={BdotN_over_B:.1e}" + valuestr = f"J={J:.2e}, Jf={jf:.2e}" + cl_string = ", ".join([f"{J.J():.1f}" for J in Jls_TF]) + outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls_TF):.2f}" + valuestr += f", LenObj={length_val:.2e}" + valuestr += f", ccObj={cc_val:.2e}" + valuestr += f", csObj={cs_val:.2e}" + valuestr += f", Lk1Obj={link_val:.2e}" + valuestr += f", forceObj={forces_val:.2e}" + # valuestr += f", forceObj2={forces_val2:.2e}" + # valuestr += f", torqueObj={torques_val:.2e}" + # valuestr += f", torqueObj2={torques_val2:.2e}" + outstr += f", F={Jforce.J():.2e}" + # outstr += f", Fnet={Jforce2.J():.2e}" + # outstr += f", T={Jtorque.J():.2e}" + # outstr += f", Tnet={Jtorque2.J():.2e}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" + outstr += f", Link Number = {linkNum.J()}" + outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" + print(outstr) + print(valuestr) + return J, grad + + +print(""" +################################################################################ +### Perform a Taylor test ###################################################### +################################################################################ +""") +f = fun +dofs = JF.x +np.random.seed(1) +h = np.random.uniform(size=dofs.shape) +J0, dJ0 = f(dofs) +dJh = sum(dJ0 * h) +for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: + t1 = time.time() + J1, _ = f(dofs + eps*h) + J2, _ = f(dofs - eps*h) + t2 = time.time() + print("err", (J1-J2)/(2*eps) - dJh) + +print(""" +################################################################################ +### Run the optimisation ####################################################### +################################################################################ +""") + + +print('Timing calls: ') +t1 = time.time() +Jf.J() +t2 = time.time() +print('Jf time = ', t2 - t1, ' s') +t1 = time.time() +Jf.dJ() +t2 = time.time() +print('dJf time = ', t2 - t1, ' s') +t1 = time.time() +Jccdist.J() +Jccdist.dJ() +t2 = time.time() +print('Jcc time = ', t2 - t1, ' s') +t1 = time.time() +Jcsdist.J() +Jcsdist.dJ() +t2 = time.time() +print('Jcs time = ', t2 - t1, ' s') +t1 = time.time() +linkNum.J() +linkNum.dJ() +t2 = time.time() +print('linkNum time = ', t2 - t1, ' s') +t1 = time.time() +Jlength.J() +Jlength.dJ() +t2 = time.time() +print('sum(Jls_TF) time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.J() +# t2 = time.time() +# print('Jforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.dJ() +# t2 = time.time() +# print('dJforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.J() +# t2 = time.time() +# print('Jforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.dJ() +# t2 = time.time() +# print('dJforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.J() +# t2 = time.time() +# print('Jtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.dJ() +# t2 = time.time() +# print('dJtorques time = ', t2 - t1, ' s') + +n_saves = 1 +MAXITER = 400 +for i in range(1, n_saves + 1): + print('Iteration ' + str(i) + ' / ' + str(n_saves)) + res = minimize(fun, dofs, jac=True, method='L-BFGS-B', + options={'maxiter': MAXITER, 'maxcor': 400}, tol=1e-15) + # dofs = res.x + + curves_to_vtk( + [c.curve for c in bs_TF.coils], + OUT_DIR + "curves_TF_{0:d}".format(i), + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=[c.current.get_value() for c in bs_TF.coils], + NetForces=coil_net_forces(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + ) + + btot.set_points(s_plot.gamma().reshape((-1, 3))) + pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_{0:d}".format(i), extra_data=pointData) + + pointData = {"B_N / B": (np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2 + ) / np.linalg.norm(btot.B().reshape(qphi, qtheta, 3), axis=-1))[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_normalizedBn_{0:d}".format(i), extra_data=pointData) + + btot.set_points(s.gamma().reshape((-1, 3))) + calculate_on_axis_B(btot, s) + # LENGTH_WEIGHT *= 0.01 + # JF = Jf \ + # + CC_WEIGHT * Jccdist \ + # + CS_WEIGHT * Jcsdist \ + # + LINK_WEIGHT * linkNum \ + # + LINK_WEIGHT2 * linkNum2 \ + # + LENGTH_WEIGHT * sum(Jls_TF) + + +t2 = time.time() +print('Total time = ', t2 - t1) +btot.save(OUT_DIR + "biot_savart_optimized_QH.json") +print(OUT_DIR) + diff --git a/examples/3_Advanced/QH_reactorscale_notfixed.py b/examples/3_Advanced/QH_reactorscale_notfixed.py new file mode 100644 index 000000000..05429ba7e --- /dev/null +++ b/examples/3_Advanced/QH_reactorscale_notfixed.py @@ -0,0 +1,643 @@ +#!/usr/bin/env python +r""" +""" + +import os +import shutil +from pathlib import Path +import time +import numpy as np +from scipy.optimize import minimize +from simsopt.field import BiotSavart, Current, coils_via_symmetries +# from simsopt.field import CoilCoilNetForces, CoilCoilNetTorques, \ +# TotalVacuumEnergy, CoilSelfNetForces, CoilCoilNetForces12, CoilCoilNetTorques12 +from simsopt.field import regularization_rect +from simsopt.field.force import MeanSquaredForce, coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \ + SquaredMeanForce, \ + MeanSquaredTorque, SquaredMeanTorque, LpCurveTorque, MixedSquaredMeanForce, MixedLpCurveForce +from simsopt.util import calculate_on_axis_B +from simsopt.geo import ( + CurveLength, CurveCurveDistance, + MeanSquaredCurvature, LpCurveCurvature, CurveSurfaceDistance, LinkingNumber, + SurfaceRZFourier, curves_to_vtk, create_equally_spaced_planar_curves, + create_planar_curves_between_two_toroidal_surfaces +) +from simsopt.objectives import Weight, SquaredFlux, QuadraticPenalty +from simsopt.util import in_github_actions +import cProfile +import re + +t1 = time.time() + +# Number of Fourier modes describing each Cartesian component of each coil: +order = 0 + +# File for the desired boundary magnetic surface: +TEST_DIR = (Path(__file__).parent / ".." / ".." / "tests" / "test_files").resolve() +input_name = 'input.LandremanPaul2021_QH_reactorScale_lowres' +filename = TEST_DIR / input_name + +# Initialize the boundary magnetic surface: +range_param = "half period" +nphi = 32 +ntheta = 32 +poff = 2.0 +coff = 3.0 +s = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi, ntheta=ntheta) +s_inner = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) +s_outer = SurfaceRZFourier.from_vmec_input(filename, range=range_param, nphi=nphi * 4, ntheta=ntheta * 4) + +# Make the inner and outer surfaces by extending the plasma surface +s_inner.extend_via_normal(poff) +s_outer.extend_via_normal(poff + coff) + +qphi = nphi * 2 +qtheta = ntheta * 2 +quadpoints_phi = np.linspace(0, 1, qphi, endpoint=True) +quadpoints_theta = np.linspace(0, 1, qtheta, endpoint=True) + +# Make high resolution, full torus version of the plasma boundary for plotting +s_plot = SurfaceRZFourier.from_vmec_input( + filename, + quadpoints_phi=quadpoints_phi, + quadpoints_theta=quadpoints_theta +) + +### Initialize some TF coils +def initialize_coils_QH(TEST_DIR, s): + """ + Initializes coils for each of the target configurations that are + used for permanent magnet optimization. + + Args: + config_flag: String denoting the stellarator configuration + being initialized. + TEST_DIR: String denoting where to find the input files. + out_dir: Path or string for the output directory for saved files. + s: plasma boundary surface. + Returns: + base_curves: List of CurveXYZ class objects. + curves: List of Curve class objects. + coils: List of Coil class objects. + """ + from simsopt.geo import create_equally_spaced_curves + from simsopt.field import Current, Coil, coils_via_symmetries + from simsopt.geo import curves_to_vtk + + # generate planar TF coils + ncoils = 2 + R0 = s.get_rc(0, 0) * 1 + R1 = s.get_rc(1, 0) * 4 + order = 8 + + from simsopt.mhd.vmec import Vmec + vmec_file = 'wout_LandremanPaul2021_QH_reactorScale_lowres_reference.nc' + total_current = Vmec(TEST_DIR / vmec_file).external_current() / (2 * s.nfp) / 1.4 + print('Total current = ', total_current) + + # Only need Jax flag for CurvePlanarFourier class + base_curves = create_equally_spaced_curves( + ncoils, s.nfp, stellsym=True, + R0=R0, R1=R1, order=order, numquadpoints=256, + jax_flag=True, + ) + + base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)] + # base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils)] + # base_currents[0].fix_all() + + total_current = Current(total_current) + total_current.fix_all() + base_currents += [total_current - sum(base_currents)] + coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) + # for c in coils: + # c.current.fix_all() + # c.curve.fix_all() + + # Initialize the coil curves and save the data to vtk + curves = [c.curve for c in coils] + currents = [c.current.get_value() for c in coils] + return base_curves, curves, coils, base_currents + +# initialize the coils +base_curves_TF, curves_TF, coils_TF, currents_TF = initialize_coils_QH(TEST_DIR, s) +num_TF_unique_coils = len(coils_TF) // 4 +base_coils_TF = coils_TF[:num_TF_unique_coils] +currents_TF = np.array([coil.current.get_value() for coil in coils_TF]) + +# # Set up BiotSavart fields +bs_TF = BiotSavart(coils_TF) + +# # Calculate average, approximate on-axis B field strength +calculate_on_axis_B(bs_TF, s) + +# wire cross section for the TF coils is a square 20 cm x 20 cm +# Only need this if make self forces and TVE nonzero in the objective! +a = 0.2 +b = 0.2 +nturns = 100 +nturns_TF = 200 + +# wire cross section for the dipole coils should be more like 5 cm x 5 cm +aa = 0.05 +bb = 0.05 + +Nx = 6 +Ny = Nx +Nz = Nx +# Create the initial coils: +base_curves, all_curves = create_planar_curves_between_two_toroidal_surfaces( + s, s_inner, s_outer, Nx, Ny, Nz, order=order, coil_coil_flag=True, jax_flag=True, + # numquadpoints=10 # Defaults is (order + 1) * 40 so this halves it +) +import warnings + +keep_inds = [] +for ii in range(len(base_curves)): + counter = 0 + for i in range(base_curves[0].gamma().shape[0]): + eps = 0.05 + for j in range(len(base_curves_TF)): + for k in range(base_curves_TF[j].gamma().shape[0]): + dij = np.sqrt(np.sum((base_curves[ii].gamma()[i, :] - base_curves_TF[j].gamma()[k, :]) ** 2)) + conflict_bool = (dij < (1.0 + eps) * base_curves[0].x[0]) + if conflict_bool: + print('bad indices = ', i, j, dij, base_curves[0].x[0]) + warnings.warn( + 'There is a PSC coil initialized such that it is within a radius' + 'of a TF coil. Deleting these PSCs now.') + counter += 1 + break + if counter == 0: + keep_inds.append(ii) + +print(keep_inds) +base_curves = np.array(base_curves)[keep_inds] + +ncoils = len(base_curves) +print('Ncoils = ', ncoils) +coil_normals = np.zeros((ncoils, 3)) +plasma_points = s.gamma().reshape(-1, 3) +plasma_unitnormals = s.unitnormal().reshape(-1, 3) +for i in range(ncoils): + point = (base_curves[i].get_dofs()[-3:]) + dists = np.sum((point - plasma_points) ** 2, axis=-1) + min_ind = np.argmin(dists) + coil_normals[i, :] = plasma_unitnormals[min_ind, :] + # coil_normals[i, :] = (plasma_points[min_ind, :] - point) +coil_normals = coil_normals / np.linalg.norm(coil_normals, axis=-1)[:, None] +# alphas = np.arctan2( +# -coil_normals[:, 1], +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +# deltas = np.arcsin(coil_normals[:, 0] / \ +# np.sqrt(coil_normals[:, 0] ** 2 + coil_normals[:, 2] ** 2)) +alphas = np.arcsin( + -coil_normals[:, 1], + ) +deltas = np.arctan2(coil_normals[:, 0], coil_normals[:, 2]) +for i in range(len(base_curves)): + alpha2 = alphas[i] / 2.0 + delta2 = deltas[i] / 2.0 + calpha2 = np.cos(alpha2) + salpha2 = np.sin(alpha2) + cdelta2 = np.cos(delta2) + sdelta2 = np.sin(delta2) + base_curves[i].set('x' + str(2 * order + 1), calpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 2), salpha2 * cdelta2) + base_curves[i].set('x' + str(2 * order + 3), calpha2 * sdelta2) + base_curves[i].set('x' + str(2 * order + 4), -salpha2 * sdelta2) + # Fix orientations of each coil + base_curves[i].fix('x' + str(2 * order + 1)) + base_curves[i].fix('x' + str(2 * order + 2)) + base_curves[i].fix('x' + str(2 * order + 3)) + base_curves[i].fix('x' + str(2 * order + 4)) + + # Fix shape of each coil + for j in range(2 * order + 1): + base_curves[i].fix('x' + str(j)) + # Fix center points of each coil + base_curves[i].fix('x' + str(2 * order + 5)) + base_curves[i].fix('x' + str(2 * order + 6)) + base_curves[i].fix('x' + str(2 * order + 7)) +base_currents = [Current(1e-1) * 2e7 for i in range(ncoils)] + +coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) +base_coils = coils[:ncoils] + +def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list): + contig = np.ascontiguousarray + forces = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + torques = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + for i, c in enumerate(coils): + aprime = aprimes[i] + bprime = bprimes[i] + # print(np.shape(bs._coils), np.shape(coils)) + # B_other = BiotSavart([cc for j, cc in enumerate(bs._coils) if i != j]).set_points(c.curve.gamma()).B() + # print(np.shape(bs._coils)) + # B_other = bs.set_points(c.curve.gamma()).B() + # print(B_other) + # exit() + forces[i, :-1, :] = coil_force(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + # print(i, forces[i, :-1, :]) + # bs._coils = coils + torques[i, :-1, :] = coil_torque(c, allcoils, regularization_rect(aprime, bprime), nturns_list[i]) + + forces[:, -1, :] = forces[:, 0, :] + torques[:, -1, :] = torques[:, 0, :] + forces = forces.reshape(-1, 3) + torques = torques.reshape(-1, 3) + point_data = {"Pointwise_Forces": (contig(forces[:, 0]), contig(forces[:, 1]), contig(forces[:, 2])), + "Pointwise_Torques": (contig(torques[:, 0]), contig(torques[:, 1]), contig(torques[:, 2]))} + return point_data + +bs = BiotSavart(coils) # + coils_TF) +btot = bs + bs_TF +calculate_on_axis_B(btot, s) +btot.set_points(s.gamma().reshape((-1, 3))) +bs.set_points(s.gamma().reshape((-1, 3))) +curves = [c.curve for c in coils] +currents = [c.current.get_value() for c in coils] +a_list = np.hstack((np.ones(len(coils)) * aa, np.ones(len(coils_TF)) * a)) +b_list = np.hstack((np.ones(len(coils)) * bb, np.ones(len(coils_TF)) * b)) +base_a_list = np.hstack((np.ones(len(base_coils)) * aa, np.ones(len(base_coils_TF)) * a)) +base_b_list = np.hstack((np.ones(len(base_coils)) * bb, np.ones(len(base_coils_TF)) * b)) + +LENGTH_WEIGHT = Weight(0.001) +LENGTH_TARGET = 80 +LINK_WEIGHT = 1e3 +CC_THRESHOLD = 0.8 +CC_WEIGHT = 1e1 +CS_THRESHOLD = 1.5 +CS_WEIGHT = 1e2 +# Weight for the Coil Coil forces term +FORCE_WEIGHT = Weight(1e-34) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# TORQUE_WEIGHT2 = Weight(4e-27) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +TORQUE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons +# Directory for output +OUT_DIR = ("./QH_debug/") +if os.path.exists(OUT_DIR): + shutil.rmtree(OUT_DIR) +os.makedirs(OUT_DIR, exist_ok=True) + +curves_to_vtk( + curves_TF, + OUT_DIR + "curves_TF_0", + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils + coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=currents_TF, + NetForces=coil_net_forces(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF) +) +curves_to_vtk( + curves, + OUT_DIR + "curves_0", + close=True, + extra_point_data=pointData_forces_torques(coils, coils + coils_TF, np.ones(len(coils)) * aa, np.ones(len(coils)) * bb, np.ones(len(coils)) * nturns), + I=currents, + NetForces=coil_net_forces(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns), + NetTorques=coil_net_torques(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns) +) +# Force and Torque calculations spawn a bunch of spurious BiotSavart child objects -- erase them! +for c in (coils + coils_TF): + c._children = set() + +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init_DA", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Repeat for whole B field +pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init", extra_data=pointData) + +btot.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init", extra_data=pointData) +btot.set_points(s.gamma().reshape((-1, 3))) + +# Define the individual terms objective function: +Jf = SquaredFlux(s, btot) +# Separate length penalties on the dipole coils and the TF coils +# since they have very different sizes +# Jls = [CurveLength(c) for c in base_curves] +Jls_TF = [CurveLength(c) for c in base_curves_TF] +Jlength = QuadraticPenalty(sum(Jls_TF), LENGTH_TARGET, "max") + +# coil-coil and coil-plasma distances should be between all coils + +### Jcc below removed the dipoles! +Jccdist = CurveCurveDistance(curves_TF, CC_THRESHOLD, num_basecurves=len(coils_TF)) +Jcsdist = CurveSurfaceDistance(curves_TF, s, CS_THRESHOLD) +Jcsdist2 = CurveSurfaceDistance(curves + curves_TF, s, CS_THRESHOLD) + +# While the coil array is not moving around, they cannot +# interlink. +linkNum = LinkingNumber(curves + curves_TF, downsample=2) + +##### Note need coils_TF + coils below!!!!!!! +# Jforce2 = sum([LpCurveForce(c, coils + coils_TF, +# regularization=regularization_rect(base_a_list[i], base_b_list[i]), +# p=2, threshold=1e8) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce2 = LpCurveForce2(coils, coils_TF, p=2, threshold=1e8) +# Jforce = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + +regularization_list = np.ones(len(coils)) * regularization_rect(aa, bb) +regularization_list2 = np.ones(len(coils_TF)) * regularization_rect(a, b) +# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jforce = MixedSquaredMeanForce(coils, coils_TF) +# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2, p=4, threshold=4e5 * 100, downsample=4) +Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100, downsample=4) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) +# Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) +# Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)]) + + +# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils_TF)]) + +# Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]) + + +# Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)] +# Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)] + +JF = Jf \ + + CC_WEIGHT * Jccdist \ + + CS_WEIGHT * Jcsdist \ + + LENGTH_WEIGHT * Jlength \ + + LINK_WEIGHT * linkNum + +if FORCE_WEIGHT.value > 0.0: + JF += FORCE_WEIGHT.value * Jforce #\ + +# if FORCE_WEIGHT2.value > 0.0: +# JF += FORCE_WEIGHT2.value * Jforce2 #\ + +# if TORQUE_WEIGHT.value > 0.0: +# JF += TORQUE_WEIGHT * Jtorque + +# if TORQUE_WEIGHT2.value > 0.0: +# JF += TORQUE_WEIGHT2 * Jtorque2 + # + TVE_WEIGHT * Jtve + # + SF_WEIGHT * Jsf + # + CURRENTS_WEIGHT * DipoleCurrentsObj + # + CURVATURE_WEIGHT * sum(Jcs_TF) \ + # + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs_TF) \ +# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \ + # + CURVATURE_WEIGHT * sum(Jcs) \ + +# We don't have a general interface in SIMSOPT for optimisation problems that +# are not in least-squares form, so we write a little wrapper function that we +# pass directly to scipy.optimize.minimize + + +print('Timing calls: ') +t1 = time.time() +Jf.J() +t2 = time.time() +print('Jf time = ', t2 - t1, ' s') +t1 = time.time() +Jf.dJ() +t2 = time.time() +print('dJf time = ', t2 - t1, ' s') +t1 = time.time() +Jccdist.J() +Jccdist.dJ() +t2 = time.time() +print('Jcc time = ', t2 - t1, ' s') +t1 = time.time() +Jcsdist.J() +Jcsdist.dJ() +t2 = time.time() +print('Jcs time = ', t2 - t1, ' s') +t1 = time.time() +linkNum.J() +linkNum.dJ() +t2 = time.time() +print('linkNum time = ', t2 - t1, ' s') +t1 = time.time() +Jlength.J() +Jlength.dJ() +t2 = time.time() +print('sum(Jls_TF) time = ', t2 - t1, ' s') +t1 = time.time() +print(Jforce.J()) +print(Jforce.dJ()) +t2 = time.time() +print('Jforce time = ', t2 - t1, ' s') +t1 = time.time() +JF.J() +JF.dJ() +t2 = time.time() +print('JF time = ', t2 - t1, ' s') +import pstats, io +from pstats import SortKey +# print(btot.ancestors,len(btot.ancestors)) +# print(JF.ancestors,len(JF.ancestors)) + + +def fun(dofs): + pr = cProfile.Profile() + pr.enable() + JF.x = dofs + J = JF.J() + grad = JF.dJ() + jf = Jf.J() + length_val = LENGTH_WEIGHT.value * Jlength.J() + cc_val = CC_WEIGHT * Jccdist.J() + cs_val = CS_WEIGHT * Jcsdist.J() + link_val = LINK_WEIGHT * linkNum.J() + forces_val = FORCE_WEIGHT.value * Jforce.J() + # Jforce3.dJ() + # forces_val2 = FORCE_WEIGHT2.value * Jforce2.J() + # torques_val = TORQUE_WEIGHT.value * Jtorque.J() + # torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J() + BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) + BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + ) / np.mean(btot.AbsB()) + outstr = f"J={J:.1e}, Jf={jf:.1e}, ⟨B·n⟩={BdotN:.1e}, ⟨B·n⟩/⟨B⟩={BdotN_over_B:.1e}" + valuestr = f"J={J:.2e}, Jf={jf:.2e}" + cl_string = ", ".join([f"{J.J():.1f}" for J in Jls_TF]) + outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls_TF):.2f}" + valuestr += f", LenObj={length_val:.2e}" + valuestr += f", ccObj={cc_val:.2e}" + valuestr += f", csObj={cs_val:.2e}" + valuestr += f", Lk1Obj={link_val:.2e}" + valuestr += f", forceObj={forces_val:.2e}" + # valuestr += f", forceObj2={forces_val2:.2e}" + # valuestr += f", torqueObj={torques_val:.2e}" + # valuestr += f", torqueObj2={torques_val2:.2e}" + outstr += f", F={Jforce.J():.2e}" + # outstr += f", Fnet={Jforce2.J():.2e}" + # outstr += f", T={Jtorque.J():.2e}" + # outstr += f", Tnet={Jtorque2.J():.2e}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}" + outstr += f", Link Number = {linkNum.J()}" + outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" + print(coils[0]._children, coils_TF[0]._children, JF._children, Jforce._children) + print(outstr) + print(valuestr) + pr.disable() + sio = io.StringIO() + sortby = SortKey.CUMULATIVE + ps = pstats.Stats(pr, stream=sio).sort_stats(sortby) + ps.print_stats(20) + print(sio.getvalue()) + # exit() + return J, grad + + +print(""" +################################################################################ +### Perform a Taylor test ###################################################### +################################################################################ +""") +f = fun +dofs = JF.x +np.random.seed(1) +h = np.random.uniform(size=dofs.shape) + +print(""" +Calling f now +""") +J0, dJ0 = f(dofs) +dJh = sum(dJ0 * h) +for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: + t1 = time.time() + J1, _ = f(dofs + eps*h) + J2, _ = f(dofs - eps*h) + t2 = time.time() + print("err", (J1-J2)/(2*eps) - dJh) + +print(""" +################################################################################ +### Run the optimisation ####################################################### +################################################################################ +""") + + +print('Timing calls: ') +t1 = time.time() +Jf.J() +t2 = time.time() +print('Jf time = ', t2 - t1, ' s') +t1 = time.time() +Jf.dJ() +t2 = time.time() +print('dJf time = ', t2 - t1, ' s') +t1 = time.time() +Jccdist.J() +Jccdist.dJ() +t2 = time.time() +print('Jcc time = ', t2 - t1, ' s') +t1 = time.time() +Jcsdist.J() +Jcsdist.dJ() +t2 = time.time() +print('Jcs time = ', t2 - t1, ' s') +t1 = time.time() +linkNum.J() +linkNum.dJ() +t2 = time.time() +print('linkNum time = ', t2 - t1, ' s') +t1 = time.time() +Jlength.J() +Jlength.dJ() +t2 = time.time() +print('sum(Jls_TF) time = ', t2 - t1, ' s') +t1 = time.time() +print(Jforce.J()) +print(Jforce.dJ()) +t2 = time.time() +print('Jforce time = ', t2 - t1, ' s') +t1 = time.time() +JF.J() +JF.dJ() +t2 = time.time() +print('JF time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.J() +# t2 = time.time() +# print('Jforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce.dJ() +# t2 = time.time() +# print('dJforces time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.J() +# t2 = time.time() +# print('Jforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jforce2.dJ() +# t2 = time.time() +# print('dJforces2 time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.J() +# t2 = time.time() +# print('Jtorques time = ', t2 - t1, ' s') +# t1 = time.time() +# Jtorque.dJ() +# t2 = time.time() +# print('dJtorques time = ', t2 - t1, ' s') + +n_saves = 1 +MAXITER = 200 +for i in range(1, n_saves + 1): + print('Iteration ' + str(i) + ' / ' + str(n_saves)) + res = minimize(fun, dofs, jac=True, method='L-BFGS-B', + options={'maxiter': MAXITER, 'maxcor': 200}, tol=1e-15) + # dofs = res.x + + dipole_currents = [c.current.get_value() for c in bs.coils] + curves_to_vtk( + [c.curve for c in bs.coils], + OUT_DIR + "curves_{0:d}".format(i), + close=True, + extra_point_data=pointData_forces_torques(coils, coils + coils_TF, np.ones(len(coils)) * aa, np.ones(len(coils)) * bb, np.ones(len(coils)) * nturns), + I=dipole_currents, + NetForces=coil_net_forces(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns), + NetTorques=coil_net_torques(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns), + ) + curves_to_vtk( + [c.curve for c in bs_TF.coils], + OUT_DIR + "curves_TF_{0:d}".format(i), + close=True, + extra_point_data=pointData_forces_torques(coils_TF, coils + coils_TF, np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b, np.ones(len(coils_TF)) * nturns_TF), + I=[c.current.get_value() for c in bs_TF.coils], + NetForces=coil_net_forces(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + NetTorques=coil_net_torques(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF), + ) + + btot.set_points(s_plot.gamma().reshape((-1, 3))) + pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_{0:d}".format(i), extra_data=pointData) + + pointData = {"B_N / B": (np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2 + ) / np.linalg.norm(btot.B().reshape(qphi, qtheta, 3), axis=-1))[:, :, None]} + s_plot.to_vtk(OUT_DIR + "surf_full_normalizedBn_{0:d}".format(i), extra_data=pointData) + + btot.set_points(s.gamma().reshape((-1, 3))) + print('Max I = ', np.max(np.abs(dipole_currents))) + print('Min I = ', np.min(np.abs(dipole_currents))) + calculate_on_axis_B(btot, s) + # LENGTH_WEIGHT *= 0.01 + # JF = Jf \ + # + CC_WEIGHT * Jccdist \ + # + CS_WEIGHT * Jcsdist \ + # + LINK_WEIGHT * linkNum \ + # + LINK_WEIGHT2 * linkNum2 \ + # + LENGTH_WEIGHT * sum(Jls_TF) + + +t2 = time.time() +print('Total time = ', t2 - t1) +btot.save(OUT_DIR + "biot_savart_optimized_QH.json") +print(OUT_DIR) + diff --git a/src/simsopt/field/biotsavart.py b/src/simsopt/field/biotsavart.py index 8cfe5571d..0cf2001ec 100644 --- a/src/simsopt/field/biotsavart.py +++ b/src/simsopt/field/biotsavart.py @@ -34,7 +34,7 @@ def __init__(self, coils): self._coils = coils sopp.BiotSavart.__init__(self, coils) MagneticField.__init__(self, depends_on=coils) - self.B_vjp_jax = jit(lambda v: self.B_vjp_pure(v)) + # self.B_vjp_jax = jit(lambda v: self.B_vjp_pure(v)) def dB_by_dcoilcurrents(self, compute_derivatives=0): points = self.get_points_cart_ref() diff --git a/src/simsopt/field/force.py b/src/simsopt/field/force.py index f544484f9..b8bc6c286 100644 --- a/src/simsopt/field/force.py +++ b/src/simsopt/field/force.py @@ -39,9 +39,13 @@ def coil_force(coil, allcoils, regularization, nturns=1): ### Line below seems to be the issue -- all these BiotSavart objects seem to stick ### around and not to go out of scope after these calls! - mutual_field = BiotSavart(mutual_coils).set_points(coil.curve.gamma()).B() - mutualforce = np.cross(coil.current.get_value() * tangent, mutual_field) + mutual_field = BiotSavart(mutual_coils).set_points(coil.curve.gamma()) + B_mutual = mutual_field.B() + mutualforce = np.cross(coil.current.get_value() * tangent, B_mutual) selfforce = self_force(coil, regularization) + mutual_field._children = set() + for c in mutual_coils: + c._children = set() return (selfforce + mutualforce) / nturns def coil_net_forces(coils, allcoils, regularization, nturns=None): @@ -96,8 +100,8 @@ def self_force_rect(coil, a, b): return self_force(coil, regularization_rect(a, b)) -@jit -def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold): +# @jit +def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold, downsample): r"""Pure function for minimizing the Lorentz force on a coil. The function is @@ -108,13 +112,17 @@ def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regulari where :math:`\vec{F}` is the Lorentz force, :math:`F_0` is a threshold force, and :math:`\ell` is arclength along the coil. """ - + B_mutual = B_mutual[::downsample, :] + gamma = gamma[::downsample, :] + gammadash = gammadash[::downsample, :] + gammadashdash = gammadashdash[::downsample, :] + quadpoints = quadpoints[::downsample] B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] tangent = gammadash / gammadash_norm force = jnp.cross(current * tangent, B_self + B_mutual) force_norm = jnp.linalg.norm(force, axis=1)[:, None] - return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm)) * (1. / p) + return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm)) * (1. / p) / jnp.shape(gamma)[0] class LpCurveForce(Optimizable): @@ -129,81 +137,107 @@ class LpCurveForce(Optimizable): and :math:`\ell` is arclength along the coil. """ - def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0): + def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0, downsample=1): self.coil = coil self.allcoils = allcoils self.othercoils = [c for c in allcoils if c is not coil] - self.biotsavart = BiotSavart(self.othercoils) + # self.biotsavart = BiotSavart(self.othercoils) quadpoints = self.coil.curve.quadpoints + self.downsample = downsample + args = {"static_argnums": (5,)} self.J_jax = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold, downsample), + **args ) self.dJ_dgamma = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dgammadash = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dgammadashdash = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dcurrent = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) self.dJ_dB_mutual = jit( - lambda gamma, gammadash, gammadashdash, current, B_mutual: - grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + lambda gamma, gammadash, gammadashdash, current, B_mutual, downsample: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual, downsample), + **args ) super().__init__(depends_on=allcoils) def J(self): - self.biotsavart.set_points(self.coil.curve.gamma()) + biotsavart = BiotSavart(self.othercoils) + biotsavart.set_points(self.coil.curve.gamma()) + # biotsavart._children = set() args = [ self.coil.curve.gamma(), self.coil.curve.gammadash(), self.coil.curve.gammadashdash(), self.coil.current.get_value(), - self.biotsavart.B() + biotsavart.B(), + self.downsample ] + # biotsavart._children = set() + for c in self.othercoils: + c._children = set() + c.curve._children = set() + c.current._children = set() return self.J_jax(*args) @derivative_dec def dJ(self): - self.biotsavart.set_points(self.coil.curve.gamma()) + biotsavart = BiotSavart(self.othercoils) + biotsavart.set_points(self.coil.curve.gamma()) + # biotsavart._children = set() args = [ self.coil.curve.gamma(), self.coil.curve.gammadash(), self.coil.curve.gammadashdash(), self.coil.current.get_value(), - self.biotsavart.B() + biotsavart.B(), + self.downsample ] dJ_dB = self.dJ_dB_mutual(*args) - dB_dX = self.biotsavart.dB_by_dX() + dB_dX = biotsavart.dB_by_dX() dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) - return ( + dJ = ( self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) - + self.biotsavart.B_vjp(dJ_dB) + + biotsavart.B_vjp(dJ_dB) ) + # biotsavart._children = set() + for c in self.othercoils: + c._children = set() + c.curve._children = set() + c.current._children = set() + + return dJ return_fn_map = {'J': J, 'dJ': dJ} @@ -543,21 +577,38 @@ def dJ(self): return_fn_map = {'J': J, 'dJ': dJ} - -@jit +# @jit def mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, quadpoints, quadpoints2, - currents, currents2, regularizations, regularizations2, p, threshold): + currents, currents2, regularizations, regularizations2, p, threshold, + downsample=1, + ): r""" """ - B_self = [B_regularized_pure(gammas[i], gammadashs[i], gammadashdashs[i], quadpoints, currents[i], regularizations[i]) for i in range(jnp.shape(gammas)[0])] - B_self2 = [B_regularized_pure(gammas2[i], gammadashs2[i], gammadashdashs2[i], quadpoints2, currents2[i], regularizations2[i]) for i in range(jnp.shape(gammas2)[0])] + # print(jnp.shape(quadpoints), jnp.shape(gammas), jnp.shape(gammadashs), jnp.shape(gammadashdashs)) + # print(jnp.shape(quadpoints2), jnp.shape(gammas2), jnp.shape(gammadashs2), jnp.shape(gammadashdashs2)) + quadpoints = quadpoints[::downsample] + gammas = gammas[:, ::downsample, :] + gammadashs = gammadashs[:, ::downsample, :] + gammadashdashs = gammadashdashs[:, ::downsample, :] + + quadpoints2 = quadpoints2[::downsample] + gammas2 = gammas2[:, ::downsample, :] + gammadashs2 = gammadashs2[:, ::downsample, :] + gammadashdashs2 = gammadashdashs2[:, ::downsample, :] + # print(jnp.shape(quadpoints), jnp.shape(gammas), jnp.shape(gammadashs), jnp.shape(gammadashdashs)) + # print(jnp.shape(quadpoints2), jnp.shape(gammas2), jnp.shape(gammadashs2), jnp.shape(gammadashdashs2)) + + B_self = jnp.array([B_regularized_pure(gammas[i], gammadashs[i], gammadashdashs[i], quadpoints, + currents[i], regularizations[i]) for i in range(jnp.shape(gammas)[0])]) + B_self2 = jnp.array([B_regularized_pure(gammas2[i], gammadashs2[i], gammadashdashs2[i], quadpoints2, + currents2[i], regularizations2[i]) for i in range(jnp.shape(gammas2)[0])]) gammadash_norms = jnp.linalg.norm(gammadashs, axis=-1)[:, :, None] tangents = gammadashs / gammadash_norms gammadash_norms2 = jnp.linalg.norm(gammadashs2, axis=-1)[:, :, None] tangents2 = gammadashs2 / gammadash_norms2 - selfforce = jnp.array([jnp.cross(currents[i] * tangents[i], B_self[i]) for i in range(jnp.shape(gammas)[0])]) - selfforce2 = jnp.array([jnp.cross(currents2[i] * tangents2[i], B_self2[i]) for i in range(jnp.shape(gammas2)[0])]) + # selfforce = jnp.array([jnp.cross(currents[i] * tangents[i], B_self[i]) for i in range(jnp.shape(gammas)[0])]) + # selfforce2 = jnp.array([jnp.cross(currents2[i] * tangents2[i], B_self2[i]) for i in range(jnp.shape(gammas2)[0])]) eps = 1e-10 # small number to avoid blow up in the denominator when i = j r_ij = gammas[:, None, :, None, :] - gammas[None, :, None, :, :] # Note, do not use the i = j indices @@ -565,7 +616,8 @@ def mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs ### Note that need to do dl1 x dl2 x r12 here instead of just (dl1 * dl2)r12 # because these are not equivalent expressions if we are squaring the pointwise forces # before integration over coil i! - cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) + # cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs[None, :, None, :, :], r_ij) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents[:, None] * currents[None, :] Ii_Ij = Ii_Ij.at[:, :].add(-jnp.diag(jnp.diag(Ii_Ij))) @@ -573,29 +625,32 @@ def mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs # repeat with gamma, gamma2 r_ij = gammas[:, None, :, None, :] - gammas2[None, :, None, :, :] # Note, do not use the i = j indices - cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) + # cross_prod = jnp.cross(tangents[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs2[None, :, None, :, :], r_ij) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents[:, None] * currents2[None, :] F += jnp.sum(Ii_Ij[:, :, None, None] * jnp.sum(cross_prod / rij_norm3[:, :, :, :, None], axis=3), axis=1) / jnp.shape(gammas2)[1] - force_norm = jnp.linalg.norm(F * 1e-7 + selfforce, axis=-1) - summ = jnp.sum(jnp.maximum(force_norm[:, :, None] - threshold, 0) ** p * gammadash_norms) + force_norm = jnp.linalg.norm(jnp.cross(tangents, F * 1e-7 + currents[:, None, None] * B_self), axis=-1) + summ = jnp.sum(jnp.maximum(force_norm[:, :, None] - threshold, 0) ** p * gammadash_norms) / jnp.shape(gammas)[1] # repeat with gamma2, gamma r_ij = gammas2[:, None, :, None, :] - gammas[None, :, None, :, :] # Note, do not use the i = j indices - cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs[None, :, None, :, :], r_ij) + # cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs[None, :, None, :, :], r_ij)) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents2[:, None] * currents[None, :] F = jnp.sum(Ii_Ij[:, :, None, None] * jnp.sum(cross_prod / rij_norm3[:, :, :, :, None], axis=3), axis=1) / jnp.shape(gammas)[1] # repeat with gamma2, gamma2 r_ij = gammas2[:, None, :, None, :] - gammas2[None, :, None, :, :] # Note, do not use the i = j indices - cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) + cross_prod = jnp.cross(gammadashs2[None, :, None, :, :], r_ij) + # cross_prod = jnp.cross(tangents2[:, None, :, None, :], jnp.cross(gammadashs2[None, :, None, :, :], r_ij)) rij_norm3 = jnp.linalg.norm(r_ij + eps, axis=-1) ** 3 Ii_Ij = currents2[:, None] * currents2[None, :] Ii_Ij = Ii_Ij.at[:, :].add(-jnp.diag(jnp.diag(Ii_Ij))) F += jnp.sum(Ii_Ij[:, :, None, None] * jnp.sum(cross_prod / rij_norm3[:, :, :, :, None], axis=3), axis=1) / jnp.shape(gammas2)[1] - force_norm2 = jnp.linalg.norm(F * 1e-7 + selfforce2, axis=-1) - summ += jnp.sum(jnp.maximum(force_norm2[:, :, None] - threshold, 0) ** p * gammadash_norms2) + force_norm2 = jnp.linalg.norm(jnp.cross(tangents2, F * 1e-7 + currents2[:, None, None] * B_self2), axis=-1) + summ += jnp.sum(jnp.maximum(force_norm2[:, :, None] - threshold, 0) ** p * gammadash_norms2) / jnp.shape(gammas2)[1] return summ * (1 / p) class MixedLpCurveForce(Optimizable): @@ -610,56 +665,66 @@ class MixedLpCurveForce(Optimizable): along the coil. """ - def __init__(self, allcoils, allcoils2, regularizations, regularizations2, p=2.0, threshold=0.0): + def __init__(self, allcoils, allcoils2, regularizations, regularizations2, p=2.0, threshold=0.0, downsample=1): self.allcoils = allcoils self.allcoils2 = allcoils2 quadpoints = self.allcoils[0].curve.quadpoints quadpoints2 = self.allcoils2[0].curve.quadpoints + self.downsample = downsample + args = {"static_argnums": (8,)} self.J_jax = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: mixed_lp_force_pure(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, quadpoints, quadpoints2, - currents, currents2, regularizations, regularizations2, p, threshold) + currents, currents2, regularizations, regularizations2, p, threshold, downsample), + **args ) self.dJ_dgamma = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=0)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=0)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgamma2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=1)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=1)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadash = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=2)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=2)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadash2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=3)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=3)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadashdash = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=4)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=4)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dgammadashdash2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=5)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=5)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) self.dJ_dcurrent = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=6)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=6)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) - self.dJ_dcurrent2 = jit( - lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2: - grad(self.J_jax, argnums=7)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2) + lambda gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample: + grad(self.J_jax, argnums=7)(gammas, gammas2, gammadashs, gammadashs2, gammadashdashs, gammadashdashs2, currents, currents2, downsample), + **args ) super().__init__(depends_on=(allcoils + allcoils2)) @@ -676,6 +741,7 @@ def J(self): jnp.array([c.curve.gammadashdash() for c in self.allcoils2]), jnp.array([c.current.get_value() for c in self.allcoils]), jnp.array([c.current.get_value() for c in self.allcoils2]), + self.downsample ] return self.J_jax(*args) @@ -692,6 +758,7 @@ def dJ(self): jnp.array([c.curve.gammadashdash() for c in self.allcoils2]), jnp.array([c.current.get_value() for c in self.allcoils]), jnp.array([c.current.get_value() for c in self.allcoils2]), + self.downsample ] dJ_dgamma = self.dJ_dgamma(*args) dJ_dgammadash = self.dJ_dgammadash(*args) @@ -709,7 +776,7 @@ def dJ(self): + sum([c.current.vjp(jnp.asarray([dJ_dcurrent[i]])) for i, c in enumerate(self.allcoils)]) + sum([c.curve.dgamma_by_dcoeff_vjp(dJ_dgamma2[i]) for i, c in enumerate(self.allcoils2)]) + sum([c.curve.dgammadash_by_dcoeff_vjp(dJ_dgammadash2[i]) for i, c in enumerate(self.allcoils2)]) - + sum([c.curve.dgammadashdash_by_dcoeff_vjp(dJ_dgammadashdash[i]) for i, c in enumerate(self.allcoils2)]) + + sum([c.curve.dgammadashdash_by_dcoeff_vjp(dJ_dgammadashdash2[i]) for i, c in enumerate(self.allcoils2)]) + sum([c.current.vjp(jnp.asarray([dJ_dcurrent2[i]])) for i, c in enumerate(self.allcoils2)]) ) diff --git a/src/simsopt/field/selffield.py b/src/simsopt/field/selffield.py index 46b9a003c..ab6fff0c0 100755 --- a/src/simsopt/field/selffield.py +++ b/src/simsopt/field/selffield.py @@ -55,8 +55,8 @@ def B_regularized_singularity_term(rc_prime, rc_prime_prime, regularization): def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization): # The factors of 2π in the next few lines come from the fact that simsopt # uses a curve parameter that goes up to 1 rather than 2π. - phi = quadpoints * 2 * jnp.pi - rc = gamma + phi = quadpoints * 2 * jnp.pi + rc = gamma rc_prime = gammadash / 2 / jnp.pi rc_prime_prime = gammadashdash / 4 / jnp.pi**2 n_quad = phi.shape[0] @@ -69,6 +69,11 @@ def B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, reg second_term = jnp.cross(rc_prime_prime, rc_prime)[:, None, :] * ( 0.5 * cos_fac / (cos_fac * jnp.sum(rc_prime * rc_prime, axis=1)[:, None] + regularization)**1.5)[:, :, None] integral_term = dphi * jnp.sum(first_term + second_term, 1) + # print(jnp.any(jnp.isnan(first_term))) + # print(jnp.any(jnp.isnan(second_term))) + # print(jnp.any(jnp.isnan(integral_term))) + # print(jnp.any(jnp.isnan(analytic_term))) + # print(jnp.max(jnp.abs(first_term))) # print(jnp.max(jnp.abs(second_term))) # print(jnp.max(jnp.abs(integral_term))) diff --git a/src/simsopt/geo/curve.py b/src/simsopt/geo/curve.py index 6fa97b6c4..9bc79341e 100644 --- a/src/simsopt/geo/curve.py +++ b/src/simsopt/geo/curve.py @@ -470,6 +470,15 @@ def gamma_impl(self, gamma, quadpoints): def incremental_arclength_pure(self, dofs): gammadash = self.gammadash_jax(dofs) return jnp.linalg.norm(gammadash, axis=1) + + @property + def qps(self): + return self.quadpoints + + @qps.setter + def qps(self, new_quadpoints): + self.quadpoints = new_quadpoints + self.numquadpoints = len(new_quadpoints) def incremental_arclength(self): return self.incremental_arclength_jax(self.get_dofs()) diff --git a/src/simsopt/geo/curveplanarfourier.py b/src/simsopt/geo/curveplanarfourier.py index 3c71c28f2..7c5d16143 100644 --- a/src/simsopt/geo/curveplanarfourier.py +++ b/src/simsopt/geo/curveplanarfourier.py @@ -159,6 +159,8 @@ def center(self, gamma, gammadash): barycenter = jnp.sum(gamma * arclength[:, None], axis=0) / N / np.pi return barycenter + def set_quadpoints(self, quadpoints): + self.quadpoints = quadpoints # class JaxCurvePlanarFourier(JaxCurve): diff --git a/src/simsopt/geo/jit.py b/src/simsopt/geo/jit.py index 3f888f2f9..796fc150e 100644 --- a/src/simsopt/geo/jit.py +++ b/src/simsopt/geo/jit.py @@ -3,8 +3,8 @@ from .config import parameters -def jit(fun): +def jit(fun, **args): if parameters['jit']: - return jaxjit(fun) + return jaxjit(fun, **args) else: return fun diff --git a/tests/field/test_selffieldforces.py b/tests/field/test_selffieldforces.py index 342f18c7f..49265558e 100644 --- a/tests/field/test_selffieldforces.py +++ b/tests/field/test_selffieldforces.py @@ -211,7 +211,7 @@ def test_force_objectives(self): gammadash_norm = np.linalg.norm(coils[0].curve.gammadash(), axis=1) force_norm = np.linalg.norm(coil_force(coils[0], coils, regularization), axis=1) print("force_norm mean:", np.mean(force_norm), "max:", np.max(force_norm)) - objective_alt = (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) + objective_alt = (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) / np.shape(gammadash_norm)[0] print("objective:", objective, "objective_alt:", objective_alt, "diff:", objective - objective_alt) np.testing.assert_allclose(objective, objective_alt) @@ -260,22 +260,36 @@ def test_force_objectives(self): # # Test MixedLpCurveForce objective = 0.0 + objective2 = 0.0 objective_alt = 0.0 for i in range(len(coils)): objective += float(LpCurveForce(coils[i], coils, regularization, p=p, threshold=threshold).J()) + objective2 += float(LpCurveForce(coils[i], coils, regularization, p=p, threshold=threshold, downsample=2).J()) force_norm = np.linalg.norm(coil_force(coils[i], coils, regularization), axis=1) gammadash_norm = np.linalg.norm(coils[i].curve.gammadash(), axis=1) - objective_alt += (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) + objective_alt += (1 / p) * np.sum(np.maximum(force_norm - threshold, 0)**p * gammadash_norm) / np.shape(gammadash_norm)[0] regularization_list = np.ones(len(coils)) * regularization - objective_mixed = float(MixedLpCurveForce(coils[0:1], coils[1:], regularization_list[0:1], regularization_list[1:], p=p, threshold=threshold).J()) + objective_mixed = float(MixedLpCurveForce(coils[0:1], coils[1:], regularization_list[0:1], + regularization_list[1:], p=p, threshold=threshold).J()) print("objective:", objective, "objective_alt:", objective_alt, "diff:", objective - objective_alt) np.testing.assert_allclose(objective, objective_alt) + + print("objective:", objective, "objective2:", objective2, "diff:", objective - objective2) + np.testing.assert_allclose(objective, objective2, rtol=1e-2) + print("objective:", objective, "objective_mixed:", objective_mixed, "diff:", objective - objective_mixed) np.testing.assert_allclose(objective, objective_mixed) + objective_mixed = float(MixedLpCurveForce(coils[0:1], coils[1:], + regularization_list[0:1], regularization_list[1:], p=p, + threshold=threshold, downsample=2).J()) + + print("objective:", objective, "objective_mixed:", objective_mixed, "diff:", objective - objective_mixed) + np.testing.assert_allclose(objective, objective_mixed, rtol=1e-2) + # Test SquaredMeanTorque # Scramble the orientations so the torques are nonzero @@ -459,6 +473,7 @@ def objectives_time_test(self): def test_update_points(self): """Confirm that Biot-Savart evaluation points are updated when the curve shapes change.""" + from simsopt.field import BiotSavart nfp = 4 ncoils = 3 I = 1.7e4 @@ -472,7 +487,8 @@ def test_update_points(self): objective = objective_class(coils[0], coils, regularization) old_objective_value = objective.J() - old_biot_savart_points = objective.biotsavart.get_points_cart() + biotsavart = BiotSavart(objective.othercoils) + old_biot_savart_points = biotsavart.get_points_cart() # A deterministic random shift to the coil dofs: shift = np.array([-0.06797948, -0.0808704 , -0.02680599, -0.02775893, -0.0325402 , @@ -488,7 +504,8 @@ def test_update_points(self): objective.x = objective.x + shift assert abs(objective.J() - old_objective_value) > 1e-6 - new_biot_savart_points = objective.biotsavart.get_points_cart() + biotsavart = BiotSavart(objective.othercoils) + new_biot_savart_points = biotsavart.get_points_cart() assert not np.allclose(old_biot_savart_points, new_biot_savart_points) # Objective2 is created directly at the new points after they are moved: objective2 = objective_class(coils[0], coils, regularization)