diff --git a/examples/3_Advanced/coil_force_objectives_scan.py b/examples/3_Advanced/coil_force_objectives_scan.py new file mode 100644 index 000000000..ea2148709 --- /dev/null +++ b/examples/3_Advanced/coil_force_objectives_scan.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python + +""" +Example script for the force metric in a stage-two coil optimization +""" +import os +from pathlib import Path +import shutil +from scipy.optimize import minimize +import numpy as np +from simsopt.geo import curves_to_vtk, create_equally_spaced_curves +from simsopt.geo import SurfaceRZFourier +from simsopt.field import Current, coils_via_symmetries +from simsopt.objectives import SquaredFlux, Weight, QuadraticPenalty +from simsopt.geo import (CurveLength, CurveCurveDistance, CurveSurfaceDistance, + MeanSquaredCurvature, LpCurveCurvature) +from simsopt.field import BiotSavart +from simsopt.field.force import MeanSquaredForce, coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \ + SquaredMeanForce, MeanSquaredTorque, SquaredMeanTorque, LpCurveTorque # , TVE +from simsopt.field.selffield import regularization_circ +from simsopt.util import in_github_actions + + +############################################################################### +# INPUT PARAMETERS +############################################################################### + +# Number of unique coil shapes, i.e. the number of coils per half field period: +# (Since the configuration has nfp = 2, multiply by 4 to get the total number of coils.) +ncoils = 4 + +# Major radius for the initial circular coils: +R0 = 1.0 + +# Minor radius for the initial circular coils: +R1 = 0.5 + +# Number of Fourier modes describing each Cartesian component of each coil: +order = 5 + +# Weight on the curve lengths in the objective function. We use the `Weight` +# class here to later easily adjust the scalar value and rerun the optimization +# without having to rebuild the objective. +LENGTH_WEIGHT = Weight(1e-03) +LENGTH_TARGET = 17.4 + +# Threshold and weight for the coil-to-coil distance penalty in the objective function: +CC_THRESHOLD = 0.1 +CC_WEIGHT = 1000 + +# Threshold and weight for the coil-to-surface distance penalty in the objective function: +CS_THRESHOLD = 0.3 +CS_WEIGHT = 10 + +# Threshold and weight for the curvature penalty in the objective function: +CURVATURE_THRESHOLD = 5. +CURVATURE_WEIGHT = 1e-6 + +# Threshold and weight for the mean squared curvature penalty in the objective function: +MSC_THRESHOLD = 5 +MSC_WEIGHT = 1e-6 + +# Weight on the mean squared force penalty in the objective function +FORCE_WEIGHT = Weight(1e-14) + +# Number of iterations to perform: +MAXITER = 500 + +# File for the desired boundary magnetic surface: +TEST_DIR = (Path(__file__).parent / ".." / ".." / "tests" / "test_files").resolve() +filename = TEST_DIR / 'input.LandremanPaul2021_QA' + +# Directory for output +OUT_DIR = "./coil_forces/" +if os.path.exists(OUT_DIR): + shutil.rmtree(OUT_DIR) +os.makedirs(OUT_DIR, exist_ok=True) + +############################################################################### +# SET UP OBJECTIVE FUNCTION +############################################################################### + +# Initialize the boundary magnetic surface: +nphi = 32 +ntheta = 32 +s = SurfaceRZFourier.from_vmec_input(filename, range="half period", nphi=nphi, ntheta=ntheta) + +qphi = nphi * 2 +qtheta = ntheta * 2 +quadpoints_phi = np.linspace(0, 1, qphi, endpoint=True) +quadpoints_theta = np.linspace(0, 1, qtheta, endpoint=True) +# Make high resolution, full torus version of the plasma boundary for plotting +s_plot = SurfaceRZFourier.from_vmec_input( + filename, + quadpoints_phi=quadpoints_phi, + quadpoints_theta=quadpoints_theta +) + +# Create the initial coils: +base_curves = create_equally_spaced_curves(ncoils, s.nfp, stellsym=True, R0=R0, R1=R1, order=order) # , jax_flag=True) +base_currents = [Current(1e5) for i in range(ncoils)] +# Since the target field is zero, one possible solution is just to set all +# currents to 0. To avoid the minimizer finding that solution, we fix one +# of the currents: +base_currents[0].fix_all() + +coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) +base_coils = coils[:ncoils] +bs = BiotSavart(coils) +bs.set_points(s.gamma().reshape((-1, 3))) + +a = 0.05 + +def pointData_forces_torques(coils): + contig = np.ascontiguousarray + forces = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + torques = np.zeros((len(coils), len(coils[0].curve.gamma()) + 1, 3)) + for i, c in enumerate(coils): + forces[i, :-1, :] = coil_force(c, coils, regularization_circ(a)) + torques[i, :-1, :] = coil_torque(c, coils, regularization_circ(a)) + + forces[:, -1, :] = forces[:, 0, :] + torques[:, -1, :] = torques[:, 0, :] + forces = forces.reshape(-1, 3) + torques = torques.reshape(-1, 3) + point_data = {"Pointwise_Forces": (contig(forces[:, 0]), contig(forces[:, 1]), contig(forces[:, 2])), + "Pointwise_Torques": (contig(torques[:, 0]), contig(torques[:, 1]), contig(torques[:, 2]))} + return point_data + +curves = [c.curve for c in coils] +curves_to_vtk( + curves, OUT_DIR + "curves_init", close=True, extra_point_data=pointData_forces_torques(coils), + NetForces=coil_net_forces(coils, regularization_circ(a)), + NetTorques=coil_net_torques(coils, regularization_circ(a)) + ) +pointData = {"B_N": np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_init", extra_data=pointData) +bs.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(bs.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_init", extra_data=pointData) +bs.set_points(s.gamma().reshape((-1, 3))) + +# Jforce = [MeanSquaredForce(c, coils, regularization_circ(a)) for c in base_coils] + +# for ii, JforceObj in enumerate([MeanSquaredForce, SquaredMeanForce]): + # Form the total objective function. To do this, we can exploit the + # fact that Optimizable objects with J() and dJ() functions can be + # multiplied by scalars and added: +ii = 1 +# Define the individual terms objective function: +Jf = SquaredFlux(s, bs) +Jls = [CurveLength(c) for c in base_curves] +Jccdist = CurveCurveDistance(curves, CC_THRESHOLD, num_basecurves=ncoils) +Jcsdist = CurveSurfaceDistance(curves, s, CS_THRESHOLD) +Jcs = [LpCurveCurvature(c, 2, CURVATURE_THRESHOLD) for c in base_curves] +Jmscs = [MeanSquaredCurvature(c) for c in base_curves] +Jforce = [LpCurveForce(c, coils, regularization_circ(a), p=2, threshold=1e5) + SquaredMeanForce(c, coils, regularization_circ(a)) for c in base_coils] +Jlength = QuadraticPenalty(sum(Jls), LENGTH_TARGET, "max") +# Jforce = [LpCurveForce(c, coils, regularization_circ(a), p=2) for c in base_coils] +# Jforce1 = [SquaredMeanForce(c, coils, regularization_circ(a)) for c in base_coils] +# Jforce2 = [MeanSquaredForce(c, coils, regularization_circ(a)) for c in base_coils] +# Jtorque = [LpCurveTorque(c, coils, regularization_circ(a), p=2) for c in base_coils] +# Jtorque1 = [SquaredMeanTorque(c, coils, regularization_circ(a)) for c in base_coils] +# Jtorque2 = [MeanSquaredTorque(c, coils, regularization_circ(a)) for c in base_coils] +JF = Jf \ + + LENGTH_WEIGHT * Jlength \ + + CC_WEIGHT * Jccdist \ + + CS_WEIGHT * Jcsdist \ + + CURVATURE_WEIGHT * sum(Jcs) \ + + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD, "max") for J in Jmscs) \ + + FORCE_WEIGHT * sum(Jforce) +#### Add Torques in here + +# We don't have a general interface in SIMSOPT for optimisation problems that +# are not in least-squares form, so we write a little wrapper function that we +# pass directly to scipy.optimize.minimize + + +def fun(dofs): + JF.x = dofs + J = JF.J() + grad = JF.dJ() + BdotN = np.mean(np.abs(np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) + BdotN_over_B = np.mean(np.abs(np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + ) / np.mean(bs.AbsB()) + outstr = f"J={J:.1e}, Jf={Jf.J():.1e}, ⟨B·n⟩={BdotN:.1e}, ⟨B·n⟩/⟨B⟩={BdotN_over_B:.1e}" + cl_string = ", ".join([f"{J.J():.1f}" for J in Jls]) + outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls):.2f}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}" + length_val = LENGTH_WEIGHT.value * Jlength.J() + cc_val = CC_WEIGHT * Jccdist.J() + cs_val = CS_WEIGHT * Jcsdist.J() + forces_val = FORCE_WEIGHT.value * sum(J.J() for J in Jforce) + valuestr = f"J={J:.2e}, Jf={Jf.J():.2e}" + valuestr += f", LenObj={length_val:.2e}" + valuestr += f", ccObj={cc_val:.2e}" + valuestr += f", csObj={cs_val:.2e}" + valuestr += f", forceObj={forces_val:.2e}" + # outstr += f", Link Number = {linkNum.J()}" + # outstr += f", Link Number 2 = {linkNum2.J()}" + outstr += f", F={sum(J.J() for J in Jforce):.2e}" + # outstr += f", T={sum(J.J() for J in Jtorque):.2e}" + # outstr += f", TVE={Jtve.J():.1e}" + outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" + print(outstr) + print(valuestr) + return J, grad + + +print(""" +############################################################################### +# Perform a Taylor test +############################################################################### +""") +print("(It make take jax several minutes to compile the objective for the first evaluation.)") +f = fun +dofs = JF.x +np.random.seed(1) +h = np.random.uniform(size=dofs.shape) +J0, dJ0 = f(dofs) +dJh = sum(dJ0 * h) +for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: + J1, _ = f(dofs + eps*h) + J2, _ = f(dofs - eps*h) + print("err", (J1-J2)/(2*eps) - dJh) + +############################################################################### +# RUN THE OPTIMIZATION +############################################################################### + + +dofs = JF.x +print(f"Optimization with FORCE_WEIGHT={FORCE_WEIGHT.value} and LENGTH_WEIGHT={LENGTH_WEIGHT.value}") +# print("INITIAL OPTIMIZATION") +res = minimize(fun, dofs, jac=True, method='L-BFGS-B', options={'maxiter': MAXITER, 'maxcor': 300}, tol=1e-15) +curves_to_vtk(curves, OUT_DIR + "curves_opt"+str(ii), close=True, extra_point_data=pointData_forces_torques(coils), + NetForces=coil_net_forces(coils, regularization_circ(a)), + NetTorques=coil_net_torques(coils, regularization_circ(a)) + ) + +pointData_surf = {"B_N": np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} +s.to_vtk(OUT_DIR + "surf_opt"+str(ii), extra_data=pointData_surf) +bs.set_points(s_plot.gamma().reshape((-1, 3))) +pointData = {"B_N": np.sum(bs.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]} +s_plot.to_vtk(OUT_DIR + "surf_full_opt"+str(ii), extra_data=pointData) +bs.set_points(s.gamma().reshape((-1, 3))) +base_curves = create_equally_spaced_curves(ncoils, s.nfp, stellsym=True, R0=R0, R1=R1, order=order) #, jax_flag=True) +base_currents = [Current(1e5) for i in range(ncoils)] +base_currents[0].fix_all() +coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True) +base_coils = coils[:ncoils] +bs = BiotSavart(coils) +bs.set_points(s.gamma().reshape((-1, 3))) + +# Save the optimized coil shapes and currents so they can be loaded into other scripts for analysis: +# bs.save(OUT_DIR + "biot_savart_opt.json") + +#Print out final important info: +# JF.x = dofs +# J = JF.J() +# grad = JF.dJ() +# jf = Jf.J() +# BdotN = np.mean(np.abs(np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) +# force = [np.max(np.linalg.norm(coil_force(c, coils, regularization_circ(a)), axis=1)) for c in base_coils] +# outstr = f"J={J:.1e}, Jf={jf:.1e}, ⟨B·n⟩={BdotN:.1e}" +# cl_string = ", ".join([f"{J.J():.1f}" for J in Jls]) +# kap_string = ", ".join(f"{np.max(c.kappa()):.1f}" for c in base_curves) +# msc_string = ", ".join(f"{J.J():.1f}" for J in Jmscs) +# jforce_string = ", ".join(f"{J.J():.2e}" for J in Jforce) +# force_string = ", ".join(f"{f:.2e}" for f in force) +# outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls):.1f}, ϰ=[{kap_string}], ∫ϰ²/L=[{msc_string}], Jforce=[{jforce_string}], force=[{force_string}]" +# outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}" +# outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" +# print(outstr) diff --git a/examples/3_Advanced/coil_forces.py b/examples/3_Advanced/coil_forces.py index f2b917bfb..393d7028e 100755 --- a/examples/3_Advanced/coil_forces.py +++ b/examples/3_Advanced/coil_forces.py @@ -14,7 +14,8 @@ from simsopt.geo import (CurveLength, CurveCurveDistance, CurveSurfaceDistance, MeanSquaredCurvature, LpCurveCurvature) from simsopt.field import BiotSavart -from simsopt.field.force import MeanSquaredForce, coil_force, LpCurveForce +from simsopt.field.force import MeanSquaredForce, coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \ + SquaredMeanForce, MeanSquaredTorque, SquaredMeanTorque, LpCurveTorque # , TVE from simsopt.field.selffield import regularization_circ from simsopt.util import in_github_actions @@ -69,7 +70,7 @@ filename = TEST_DIR / 'input.LandremanPaul2021_QA' # Directory for output -OUT_DIR = "./output/" +OUT_DIR = "./coil_forces/" os.makedirs(OUT_DIR, exist_ok=True) @@ -83,7 +84,7 @@ s = SurfaceRZFourier.from_vmec_input(filename, range="half period", nphi=nphi, ntheta=ntheta) # Create the initial coils: -base_curves = create_equally_spaced_curves(ncoils, s.nfp, stellsym=True, R0=R0, R1=R1, order=order) +base_curves = create_equally_spaced_curves(ncoils, s.nfp, stellsym=True, R0=R0, R1=R1, order=order, jax_flag=True) base_currents = [Current(1e5) for i in range(ncoils)] # Since the target field is zero, one possible solution is just to set all # currents to 0. To avoid the minimizer finding that solution, we fix one @@ -95,8 +96,27 @@ bs = BiotSavart(coils) bs.set_points(s.gamma().reshape((-1, 3))) +a = 0.05 + +def pointData_forces_torques(coils): + forces = [] + torques = [] + for c in coils: + force = np.linalg.norm(coil_force(c, coils, regularization_circ(a)), axis=1) + torque = np.linalg.norm(coil_torque(c, coils, regularization_circ(a)), axis=1) + force = np.append(force, force[0]) + torque = np.append(torque, torque[0]) + torques = np.concatenate([torques, torque]) + forces = np.concatenate([forces, force]) + point_data = {"Pointwise_Forces": forces, "Pointwise_Torques": torques} + return point_data + curves = [c.curve for c in coils] -curves_to_vtk(curves, OUT_DIR + "curves_init", close=True) +curves_to_vtk( + curves, OUT_DIR + "curves_init", close=True, extra_point_data=pointData_forces_torques(coils), + NetForces=coil_net_forces(coils, regularization_circ(a)), + NetTorques=coil_net_torques(coils, regularization_circ(a)) + ) pointData = {"B_N": np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} s.to_vtk(OUT_DIR + "surf_init", extra_data=pointData) @@ -107,8 +127,14 @@ Jcsdist = CurveSurfaceDistance(curves, s, CS_THRESHOLD) Jcs = [LpCurveCurvature(c, 2, CURVATURE_THRESHOLD) for c in base_curves] Jmscs = [MeanSquaredCurvature(c) for c in base_curves] -Jforce = [LpCurveForce(c, coils, regularization_circ(0.05), p=4) for c in base_coils] -# Jforce = [MeanSquaredForce(c, coils, regularization_circ(0.05)) for c in base_coils] +Jforce = [LpCurveForce(c, coils, regularization_circ(a), p=4) for c in base_coils] +Jforce1 = [SquaredMeanForce(c, coils, regularization_circ(a)) for c in base_coils] +Jforce2 = [MeanSquaredForce(c, coils, regularization_circ(a)) for c in base_coils] +Jtorque = [MeanSquaredTorque(c, coils, regularization_circ(a)) for c in base_coils] +Jtorque1 = [SquaredMeanTorque(c, coils, regularization_circ(a)) for c in base_coils] +Jtorque2 = [LpCurveTorque(c, coils, regularization_circ(a), p=4) for c in base_coils] + +# Jforce = [MeanSquaredForce(c, coils, regularization_circ(a)) for c in base_coils] # Form the total objective function. To do this, we can exploit the @@ -121,6 +147,7 @@ + CURVATURE_WEIGHT * sum(Jcs) \ + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD, "max") for J in Jmscs) \ + FORCE_WEIGHT * sum(Jforce) +#### Add Torques in here # We don't have a general interface in SIMSOPT for optimisation problems that # are not in least-squares form, so we write a little wrapper function that we @@ -131,46 +158,54 @@ def fun(dofs): JF.x = dofs J = JF.J() grad = JF.dJ() + BdotN = np.mean(np.abs(np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) + BdotN_over_B = np.mean(np.abs(np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)) + ) / np.mean(bs.AbsB()) + outstr = f"J={J:.1e}, Jf={Jf.J():.1e}, ⟨B·n⟩={BdotN:.1e}, ⟨B·n⟩/⟨B⟩={BdotN_over_B:.1e}" + cl_string = ", ".join([f"{J.J():.1f}" for J in Jls]) + outstr += f", Len=sum([{cl_string}])={sum(J.J() for J in Jls):.2f}" + outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}" + # outstr += f", Link Number = {linkNum.J()}" + # outstr += f", Link Number 2 = {linkNum2.J()}" + outstr += f", F={sum(J.J() for J in Jforce):.2e}" + outstr += f", T={sum(J.J() for J in Jtorque):.2e}" + # outstr += f", TVE={Jtve.J():.1e}" + outstr += f", ║∇J║={np.linalg.norm(grad):.1e}" + print(outstr) return J, grad -# print(""" -# ############################################################################### -# # Perform a Taylor test -# ############################################################################### -# """) -# print("(It make take jax several minutes to compile the objective for the first evaluation.)") -# f = fun -# dofs = JF.x -# np.random.seed(1) -# h = np.random.uniform(size=dofs.shape) -# J0, dJ0 = f(dofs) -# dJh = sum(dJ0 * h) -# for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: -# J1, _ = f(dofs + eps*h) -# J2, _ = f(dofs - eps*h) -# print("err", (J1-J2)/(2*eps) - dJh) +print(""" +############################################################################### +# Perform a Taylor test +############################################################################### +""") +print("(It make take jax several minutes to compile the objective for the first evaluation.)") +f = fun +dofs = JF.x +np.random.seed(1) +h = np.random.uniform(size=dofs.shape) +J0, dJ0 = f(dofs) +dJh = sum(dJ0 * h) +for eps in [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]: + J1, _ = f(dofs + eps*h) + J2, _ = f(dofs - eps*h) + print("err", (J1-J2)/(2*eps) - dJh) ############################################################################### # RUN THE OPTIMIZATION ############################################################################### -def pointData_forces(coils): - forces = [] - for c in coils: - force = np.linalg.norm(coil_force(c, coils, regularization_circ(0.05)), axis=1) - force = np.append(force, force[0]) - forces = np.concatenate([forces, force]) - point_data = {"F": forces} - return point_data - - dofs = JF.x print(f"Optimization with FORCE_WEIGHT={FORCE_WEIGHT.value} and LENGTH_WEIGHT={LENGTH_WEIGHT.value}") # print("INITIAL OPTIMIZATION") res = minimize(fun, dofs, jac=True, method='L-BFGS-B', options={'maxiter': MAXITER, 'maxcor': 300}, tol=1e-15) -curves_to_vtk(curves, OUT_DIR + "curves_opt_short", close=True, extra_data=pointData_forces(coils)) +curves_to_vtk(curves, OUT_DIR + "curves_opt_short", close=True, extra_point_data=pointData_forces_torques(coils), + NetForces=coil_net_forces(coils, regularization_circ(a)), + NetTorques=coil_net_torques(coils, regularization_circ(a)) + ) + pointData_surf = {"B_N": np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} s.to_vtk(OUT_DIR + "surf_opt_short", extra_data=pointData_surf) @@ -181,7 +216,11 @@ def pointData_forces(coils): LENGTH_WEIGHT *= 0.1 # print("OPTIMIZATION WITH REDUCED LENGTH PENALTY\n") res = minimize(fun, dofs, jac=True, method='L-BFGS-B', options={'maxiter': MAXITER, 'maxcor': 300}, tol=1e-15) -curves_to_vtk(curves, OUT_DIR + f"curves_opt_force_FWEIGHT={FORCE_WEIGHT.value:e}_LWEIGHT={LENGTH_WEIGHT.value*10:e}", close=True, extra_data=pointData_forces(coils)) +curves_to_vtk(curves, OUT_DIR + f"curves_opt_force_FWEIGHT={FORCE_WEIGHT.value:e}_LWEIGHT={LENGTH_WEIGHT.value*10:e}", close=True, + extra_point_data=pointData_forces_torques(coils), + NetForces=coil_net_forces(coils, regularization_circ(a)), + NetTorques=coil_net_torques(coils, regularization_circ(a)) + ) pointData_surf = {"B_N": np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]} s.to_vtk(OUT_DIR + f"surf_opt_force_WEIGHT={FORCE_WEIGHT.value:e}_LWEIGHT={LENGTH_WEIGHT.value*10:e}", extra_data=pointData_surf) @@ -194,7 +233,7 @@ def pointData_forces(coils): grad = JF.dJ() jf = Jf.J() BdotN = np.mean(np.abs(np.sum(bs.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))) -force = [np.max(np.linalg.norm(coil_force(c, coils, regularization_circ(0.05)), axis=1)) for c in base_coils] +force = [np.max(np.linalg.norm(coil_force(c, coils, regularization_circ(a)), axis=1)) for c in base_coils] outstr = f"J={J:.1e}, Jf={jf:.1e}, ⟨B·n⟩={BdotN:.1e}" cl_string = ", ".join([f"{J.J():.1f}" for J in Jls]) kap_string = ", ".join(f"{np.max(c.kappa()):.1f}" for c in base_curves) diff --git a/src/simsopt/field/biotsavart.py b/src/simsopt/field/biotsavart.py index 46de8dcae..b6c9cb081 100644 --- a/src/simsopt/field/biotsavart.py +++ b/src/simsopt/field/biotsavart.py @@ -339,7 +339,7 @@ def B_vjp(self, v): t2 = time.time() print('Current dJ time = ', t2 - t1) t1 = time.time() - # dB_by_dcurvedofs = self.dB_by_dcurvedofs() + dB_by_dcurvedofs = self.dB_by_dcurvedofs() # t2 = time.time() # print('Curve dJ time = ', t2 - t1) # print(jnp.shape(dB_by_dcurvedofs)) @@ -350,9 +350,9 @@ def B_vjp(self, v): # jnp.shape(vjp(self.B_jax_reduced, self.get_curve_dofs())[1](v))) # res_curvedofs = self.dB_by_dcurvedofs_vjp_impl(v) # print('res_curvedofs size = ', jnp.shape(res_curvedofs)) - print(jnp.shape(v), jnp.shape(self.dB_by_dcurvedofs())) - res_curvedofs = jnp.sum(jnp.sum(jnp.sum(v[None, None, :, :] * self.dB_by_dcurvedofs(), axis=-1), axis=-1), axis=-1) - print(jnp.shape(res_curvedofs), len(res_curvedofs)) + # print(jnp.shape(v), jnp.shape(self.dB_by_dcurvedofs())) + res_curvedofs = [jnp.sum(jnp.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))] + # print(jnp.shape(res_curvedofs), len(res_curvedofs)) curve_derivs = [Derivative({coils[i].curve: res_curvedofs[i]}) for i in range(len(res_curvedofs))] current_derivs = [coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))] # t2 = time.time() @@ -368,18 +368,18 @@ def B_vjp_jax(self, v): t2 = time.time() print('Current dJ time = ', t2 - t1) t1 = time.time() - # dB_by_dcurvedofs = self.dB_by_dcurvedofs() + dB_by_dcurvedofs = self.dB_by_dcurvedofs() # t2 = time.time() # print('Curve dJ time = ', t2 - t1) # print(jnp.shape(dB_by_dcurvedofs)) # t1 = time.time() - print(jnp.shape(v), jnp.shape(self.get_curve_dofs())) + # print(jnp.shape(v), jnp.shape(self.get_curve_dofs())) # print(jnp.shape(vjp(self.B_pure_reduced, self.get_curve_dofs())[1])) # print(vjp(self.B_pure_reduced, self.get_curve_dofs())[1](v), # jnp.shape(vjp(self.B_jax_reduced, self.get_curve_dofs())[1](v))) - res_curvedofs = self.dB_by_dcurvedofs_vjp_impl(v) - print('res_curvedofs size = ', jnp.shape(res_curvedofs)) - #res_curvedofs = [np.sum(np.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))] + # res_curvedofs = self.dB_by_dcurvedofs_vjp_impl(v) + # print('res_curvedofs size = ', jnp.shape(res_curvedofs)) + res_curvedofs = [np.sum(np.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))] curve_derivs = [Derivative({coils[i].curve: res_curvedofs[i]}) for i in range(len(res_curvedofs))] current_derivs = [coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))] t2 = time.time() diff --git a/src/simsopt/field/force.py b/src/simsopt/field/force.py index caed1d643..b892d4d36 100644 --- a/src/simsopt/field/force.py +++ b/src/simsopt/field/force.py @@ -22,12 +22,32 @@ def coil_force(coil, allcoils, regularization): selfforce = self_force(coil, regularization) return selfforce + mutualforce +def coil_net_forces(allcoils, regularization): + net_forces = np.zeros((len(allcoils), 3)) + for i, coil in enumerate(allcoils): + Fi = coil_force(coil, allcoils, regularization) + gammadash = coil.curve.gammadash() + gammadash_norm = np.linalg.norm(gammadash, axis=1)[:, None] + net_forces[i, :] += np.sum(gammadash_norm * Fi, axis=0) / gammadash.shape[0] + return net_forces + +def coil_torque(coil, allcoils, regularization): + gamma = coil.curve.gamma() + return np.cross(gamma, coil_force(coil, allcoils, regularization)) + +def coil_net_torques(allcoils, regularization): + net_torques = np.zeros((len(allcoils), 3)) + for i, coil in enumerate(allcoils): + Ti = coil_torque(coil, allcoils, regularization) + gammadash = coil.curve.gammadash() + gammadash_norm = np.linalg.norm(gammadash, axis=1)[:, None] + net_torques[i, :] += np.sum(gammadash_norm * Ti, axis=0) / gammadash.shape[0] + return net_torques def coil_force_pure(B, I, t): """force on coil for optimization""" return jnp.cross(I * t, B) - def self_force(coil, regularization): """ Compute the self-force of a coil. @@ -67,7 +87,7 @@ def lp_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regulari tangent = gammadash / gammadash_norm force = jnp.cross(current * tangent, B_self + B_mutual) force_norm = jnp.linalg.norm(force, axis=1)[:, None] - return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm))*(1./p) + return (jnp.sum(jnp.maximum(force_norm - threshold, 0)**p * gammadash_norm)) / jnp.sum(gammadash_norm) class LpCurveForce(Optimizable): @@ -82,7 +102,7 @@ class LpCurveForce(Optimizable): and :math:`\ell` is arclength along the coil. """ - def __init__(self, coil, allcoils, regularization, p=1.0, threshold=0.0): + def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0): self.coil = coil self.allcoils = allcoils self.othercoils = [c for c in allcoils if c is not coil] @@ -270,4 +290,530 @@ def dJ(self): + self.biotsavart.B_vjp(dJ_dB) ) - return_fn_map = {'J': J, 'dJ': dJ} \ No newline at end of file + return_fn_map = {'J': J, 'dJ': dJ} + +@jit +def squared_mean_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual): + r""" + """ + B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) + gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] + tangent = gammadash / gammadash_norm + force = jnp.cross(current * tangent, B_self + B_mutual) + # force_norm = jnp.linalg.norm(force, axis=1)[:, None] + return jnp.linalg.norm(jnp.sum(force * gammadash_norm, axis=0)) ** 2 / jnp.sum(gammadash_norm) # factor for the integral + +class SquaredMeanForce(Optimizable): + r"""Optimizable class to minimize the net Lorentz force on a coil. + + The objective function is + + .. math: + J = (\frac{\int \vec{F}_i d\ell)^2 + + where :math:`\vec{F}` is the Lorentz force and :math:`\ell` is arclength + along the coil. + """ + + def __init__(self, coil, allcoils, regularization): + self.coil = coil + self.allcoils = allcoils + self.othercoils = [c for c in allcoils if c is not coil] + self.biotsavart = BiotSavart(self.othercoils) + quadpoints = self.coil.curve.quadpoints + + self.J_jax = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + squared_mean_force_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual) + ) + + self.dJ_dgamma = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadashdash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dcurrent = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dB_mutual = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + super().__init__(depends_on=allcoils) + + def J(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + return self.J_jax(*args) + + @derivative_dec + def dJ(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + dJ_dB = self.dJ_dB_mutual(*args) + dB_dX = self.biotsavart.dB_by_dX() + dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + + return ( + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) + + self.biotsavart.B_vjp(dJ_dB) + ) + + return_fn_map = {'J': J, 'dJ': dJ} + +@jit +def squared_mean_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual): + r""" + """ + B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) + gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] + tangent = gammadash / gammadash_norm + force = jnp.cross(current * tangent, B_self + B_mutual) + torque = jnp.cross(gamma, force) + return jnp.linalg.norm(jnp.sum(gammadash_norm * torque, axis=0)) ** 2 / jnp.sum(gammadash_norm) # factor for the integral + +class SquaredMeanTorque(Optimizable): + r"""Optimizable class to minimize the net Lorentz force on a coil. + + The objective function is + + .. math: + J = (\frac{\int \vec{F}_i d\ell)^2 + + where :math:`\vec{F}` is the Lorentz force and :math:`\ell` is arclength + along the coil. + """ + + def __init__(self, coil, allcoils, regularization): + self.coil = coil + self.allcoils = allcoils + self.othercoils = [c for c in allcoils if c is not coil] + self.biotsavart = BiotSavart(self.othercoils) + quadpoints = self.coil.curve.quadpoints + + self.J_jax = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + squared_mean_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual) + ) + + self.dJ_dgamma = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadashdash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dcurrent = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dB_mutual = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + super().__init__(depends_on=allcoils) + + def J(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + return self.J_jax(*args) + + @derivative_dec + def dJ(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + dJ_dB = self.dJ_dB_mutual(*args) + dB_dX = self.biotsavart.dB_by_dX() + dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + + return ( + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) + + self.biotsavart.B_vjp(dJ_dB) + ) + + return_fn_map = {'J': J, 'dJ': dJ} + +@jit +def mean_squared_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual): + r""" + """ + B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) + gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] + tangent = gammadash / gammadash_norm + force = jnp.cross(current * tangent, B_self + B_mutual) + torque = jnp.cross(gamma, force) + torque_norm = jnp.linalg.norm(torque, axis=1)[:, None] + return jnp.sum(gammadash_norm * torque_norm ** 2) / jnp.sum(gammadash_norm) + +class MeanSquaredTorque(Optimizable): + r"""Optimizable class to minimize the net Lorentz force on a coil. + + The objective function is + + .. math: + J = (\frac{\int \vec{F}_i d\ell)^2 + + where :math:`\vec{F}` is the Lorentz force and :math:`\ell` is arclength + along the coil. + """ + + def __init__(self, coil, allcoils, regularization): + self.coil = coil + self.allcoils = allcoils + self.othercoils = [c for c in allcoils if c is not coil] + self.biotsavart = BiotSavart(self.othercoils) + quadpoints = self.coil.curve.quadpoints + + self.J_jax = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + mean_squared_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual) + ) + + self.dJ_dgamma = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadashdash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dcurrent = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dB_mutual = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + super().__init__(depends_on=allcoils) + + def J(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + return self.J_jax(*args) + + @derivative_dec + def dJ(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + dJ_dB = self.dJ_dB_mutual(*args) + dB_dX = self.biotsavart.dB_by_dX() + dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + + return ( + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) + + self.biotsavart.B_vjp(dJ_dB) + ) + + return_fn_map = {'J': J, 'dJ': dJ} + +@jit +def lp_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold): + r"""Pure function for minimizing the Lorentz force on a coil. + + The function is + + .. math:: + J = \frac{1}{p}\left(\int \text{max}(|\vec{T}| - T_0, 0)^p d\ell\right) + + where :math:`\vec{T}` is the Lorentz torque, :math:`T_0` is a threshold torque, + and :math:`\ell` is arclength along the coil. + """ + B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) + gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] + tangent = gammadash / gammadash_norm + force = jnp.cross(current * tangent, B_self + B_mutual) + torque = jnp.cross(gamma, force) + torque_norm = jnp.linalg.norm(torque, axis=1)[:, None] + return (jnp.sum(jnp.maximum(torque_norm - threshold, 0)**p * gammadash_norm)) / jnp.sum(gammadash_norm) + + +class LpCurveTorque(Optimizable): + r""" Optimizable class to minimize the Lorentz force on a coil. + + The objective function is + + .. math:: + J = \frac{1}{p}\left(\int \text{max}(|\vec{F}| - F_0, 0)^p d\ell\right) + + where :math:`\vec{F}` is the Lorentz force, :math:`F_0` is a threshold force, + and :math:`\ell` is arclength along the coil. + """ + + def __init__(self, coil, allcoils, regularization, p=2.0, threshold=0.0): + self.coil = coil + self.allcoils = allcoils + self.othercoils = [c for c in allcoils if c is not coil] + self.biotsavart = BiotSavart(self.othercoils) + quadpoints = self.coil.curve.quadpoints + + self.J_jax = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + lp_torque_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold) + ) + + self.dJ_dgamma = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dgammadashdash = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dcurrent = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + self.dJ_dB_mutual = jit( + lambda gamma, gammadash, gammadashdash, current, B_mutual: + grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) + ) + + super().__init__(depends_on=allcoils) + + def J(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + return self.J_jax(*args) + + @derivative_dec + def dJ(self): + self.biotsavart.set_points(self.coil.curve.gamma()) + + args = [ + self.coil.curve.gamma(), + self.coil.curve.gammadash(), + self.coil.curve.gammadashdash(), + self.coil.current.get_value(), + self.biotsavart.B() + ] + + dJ_dB = self.dJ_dB_mutual(*args) + dB_dX = self.biotsavart.dB_by_dX() + dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + + return ( + self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) + + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) + + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) + + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) + + self.biotsavart.B_vjp(dJ_dB) + ) + + return_fn_map = {'J': J, 'dJ': dJ} + + +# @jit +# def tve_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold): +# r"""Pure function for minimizing the Lorentz force on a coil. + +# The function is + +# .. math:: +# J = \frac{1}{p}\left(\int \text{max}(|\vec{T}| - T_0, 0)^p d\ell\right) + +# where :math:`\vec{T}` is the Lorentz torque, :math:`T_0` is a threshold torque, +# and :math:`\ell` is arclength along the coil. +# """ +# B_self = B_regularized_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization) +# gammadash_norm = jnp.linalg.norm(gammadash, axis=1)[:, None] +# tangent = gammadash / gammadash_norm +# force = jnp.cross(current * tangent, B_self + B_mutual) +# torque = jnp.cross(gamma, force) +# torque_norm = jnp.linalg.norm(torque, axis=1)[:, None] +# return (jnp.sum(jnp.maximum(torque_norm - threshold, 0)**p * gammadash_norm))*(1./p) + +# class TVE(Optimizable): +# r""" Optimizable class to minimize the Lorentz force on a coil. + +# The objective function is + +# .. math:: +# J = 0.5 * I_i * I_j * Lij + +# where :math:`\vec{F}` is the Lorentz force, :math:`F_0` is a threshold force, +# and :math:`\ell` is arclength along the coil. +# """ + +# def __init__(self, coil, allcoils, regularization, p=1.0, threshold=0.0): +# self.coil = coil +# self.allcoils = allcoils +# self.othercoils = [c for c in allcoils if c is not coil] +# self.biotsavart = BiotSavart(self.othercoils) +# quadpoints = self.coil.curve.quadpoints + +# self.J_jax = jit( +# lambda gamma, gammadash, gammadashdash, current, B_mutual: +# tve_pure(gamma, gammadash, gammadashdash, quadpoints, current, regularization, B_mutual, p, threshold) +# ) + +# self.dJ_dgamma = jit( +# lambda gamma, gammadash, gammadashdash, current, B_mutual: +# grad(self.J_jax, argnums=0)(gamma, gammadash, gammadashdash, current, B_mutual) +# ) + +# self.dJ_dgammadash = jit( +# lambda gamma, gammadash, gammadashdash, current, B_mutual: +# grad(self.J_jax, argnums=1)(gamma, gammadash, gammadashdash, current, B_mutual) +# ) + +# self.dJ_dgammadashdash = jit( +# lambda gamma, gammadash, gammadashdash, current, B_mutual: +# grad(self.J_jax, argnums=2)(gamma, gammadash, gammadashdash, current, B_mutual) +# ) + +# self.dJ_dcurrent = jit( +# lambda gamma, gammadash, gammadashdash, current, B_mutual: +# grad(self.J_jax, argnums=3)(gamma, gammadash, gammadashdash, current, B_mutual) +# ) + +# self.dJ_dB_mutual = jit( +# lambda gamma, gammadash, gammadashdash, current, B_mutual: +# grad(self.J_jax, argnums=4)(gamma, gammadash, gammadashdash, current, B_mutual) +# ) + +# super().__init__(depends_on=allcoils) + +# def J(self): +# self.biotsavart.set_points(self.coil.curve.gamma()) + +# args = [ +# self.coil.curve.gamma(), +# self.coil.curve.gammadash(), +# self.coil.curve.gammadashdash(), +# self.coil.current.get_value(), +# self.biotsavart.B() +# ] + +# return self.J_jax(*args) + +# @derivative_dec +# def dJ(self): +# self.biotsavart.set_points(self.coil.curve.gamma()) + +# args = [ +# self.coil.curve.gamma(), +# self.coil.curve.gammadash(), +# self.coil.curve.gammadashdash(), +# self.coil.current.get_value(), +# self.biotsavart.B() +# ] + +# dJ_dB = self.dJ_dB_mutual(*args) +# dB_dX = self.biotsavart.dB_by_dX() +# dJ_dX = np.einsum('ij,ikj->ik', dJ_dB, dB_dX) + +# return ( +# self.coil.curve.dgamma_by_dcoeff_vjp(self.dJ_dgamma(*args) + dJ_dX) +# + self.coil.curve.dgammadash_by_dcoeff_vjp(self.dJ_dgammadash(*args)) +# + self.coil.curve.dgammadashdash_by_dcoeff_vjp(self.dJ_dgammadashdash(*args)) +# + self.coil.current.vjp(jnp.asarray([self.dJ_dcurrent(*args)])) +# + self.biotsavart.B_vjp(dJ_dB) +# ) + +# return_fn_map = {'J': J, 'dJ': dJ}