Skip to content

Commit

Permalink
Still struggling to find out why JaxCurves seem to spawn so many opti…
Browse files Browse the repository at this point in the history
…mizable weak references, especially when B_vjp is used in the dJ calculation in the various Force objectives. For now, seem to have got around it. Finally implemented the corrected jacobian terms for the CurvePlanarFourier objects from Alex, and these are running much faster, including with forces. No issue with generating huge numbers of child processes IF one cleans up the children spawning after every call to the force J or dJ calls. Code ready for a dramatic clean up and finalization.
  • Loading branch information
akaptano committed Nov 4, 2024
1 parent cc527f7 commit 3f19636
Show file tree
Hide file tree
Showing 10 changed files with 552 additions and 1,379 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
# Create the initial coils:
base_curves = create_equally_spaced_planar_curves(
ncoils, s.nfp, stellsym=True, R0=R0, R1=R1, order=order,
jax_flag=True
jax_flag=False
)
# for i in range(len(base_curves)):
# base_curves[i].set('x' + str(2 * order + 1), np.random.rand(1) - 0.5)
Expand Down
75 changes: 42 additions & 33 deletions examples/3_Advanced/QH_reactorscale_DA.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def initialize_coils_QH(TEST_DIR, s):
base_curves = create_equally_spaced_curves(
ncoils, s.nfp, stellsym=True,
R0=R0, R1=R1, order=order, numquadpoints=256,
jax_flag=True,
jax_flag=False,
)

base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)]
Expand Down Expand Up @@ -142,12 +142,12 @@ def initialize_coils_QH(TEST_DIR, s):
aa = 0.05
bb = 0.05

Nx = 6
Nx = 7
Ny = Nx
Nz = Nx
# Create the initial coils:
base_curves, all_curves = create_planar_curves_between_two_toroidal_surfaces(
s, s_inner, s_outer, Nx, Ny, Nz, order=order, coil_coil_flag=True, jax_flag=True,
s, s_inner, s_outer, Nx, Ny, Nz, order=order, coil_coil_flag=True, jax_flag=False,
# numquadpoints=10 # Defaults is (order + 1) * 40 so this halves it
)
import warnings
Expand Down Expand Up @@ -309,8 +309,8 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
NetTorques=coil_net_torques(coils, coils + coils_TF, regularization_rect(np.ones(len(coils)) * aa, np.ones(len(coils)) * bb), np.ones(len(coils)) * nturns)
)
# Force and Torque calculations spawn a bunch of spurious BiotSavart child objects -- erase them!
for c in (coils + coils_TF):
c._children = set()
# for c in (coils + coils_TF):
# c._children = set()

pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]}
s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData)
Expand Down Expand Up @@ -346,7 +346,7 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):

# While the coil array is not moving around, they cannot
# interlink.
linkNum = LinkingNumber(curves + curves_TF)
linkNum = LinkingNumber(curves + curves_TF, downsample=4)

##### Note need coils_TF + coils below!!!!!!!
# Jforce2 = sum([LpCurveForce(c, coils + coils_TF,
Expand All @@ -359,7 +359,7 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
regularization_list2 = np.ones(len(coils_TF)) * regularization_rect(a, b)
# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)]
# Jforce = MixedSquaredMeanForce(coils, coils_TF)
Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)])
Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=4, threshold=4e5 * 100, downsample=1) for i, c in enumerate(base_coils + base_coils_TF)])
Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)])
Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=4e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)])
# Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 100) for i, c in enumerate(base_coils + base_coils_TF)])
Expand Down Expand Up @@ -407,29 +407,27 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
# print(btot.ancestors,len(btot.ancestors))
# print(JF.ancestors,len(JF.ancestors))

# Force and Torque calculations using JaxCurves spawn a bunch of spurious BiotSavart child objects
# erase them!
# for c in (coils + coils_TF):
# c._children = set()

def fun(dofs):
pr = cProfile.Profile()
pr.enable()
JF.x = dofs
# pr = cProfile.Profile()
# pr.enable()
J = JF.J()
grad = JF.dJ()
# pr.disable()
# sio = io.StringIO()
# sortby = SortKey.CUMULATIVE
# ps = pstats.Stats(pr, stream=sio).sort_stats(sortby)
# ps.print_stats(20)
# print(sio.getvalue())
# exit()
jf = Jf.J()
# print(JF._children, btot._children, btot.coils[0]._children)
length_val = LENGTH_WEIGHT.value * Jlength.J()
cc_val = CC_WEIGHT * Jccdist.J()
cs_val = CS_WEIGHT * Jcsdist.J()
link_val = LINK_WEIGHT * linkNum.J()
forces_val = FORCE_WEIGHT.value * Jforce.J()
forces_val2 = FORCE_WEIGHT2.value * Jforce2.J()
torques_val = TORQUE_WEIGHT.value * Jtorque.J()
torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J()
# forces_val2 = FORCE_WEIGHT2.value * Jforce2.J()
# torques_val = TORQUE_WEIGHT.value * Jtorque.J()
# torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J()
BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)))
BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))
) / np.mean(btot.AbsB())
Expand All @@ -442,18 +440,27 @@ def fun(dofs):
valuestr += f", csObj={cs_val:.2e}"
valuestr += f", Lk1Obj={link_val:.2e}"
valuestr += f", forceObj={forces_val:.2e}"
valuestr += f", forceObj2={forces_val2:.2e}"
valuestr += f", torqueObj={torques_val:.2e}"
valuestr += f", torqueObj2={torques_val2:.2e}"
# valuestr += f", forceObj2={forces_val2:.2e}"
# valuestr += f", torqueObj={torques_val:.2e}"
# valuestr += f", torqueObj2={torques_val2:.2e}"
outstr += f", F={Jforce.J():.2e}"
outstr += f", Fnet={Jforce2.J():.2e}"
outstr += f", T={Jtorque.J():.2e}"
outstr += f", Tnet={Jtorque2.J():.2e}"
# outstr += f", Fnet={Jforce2.J():.2e}"
# outstr += f", T={Jtorque.J():.2e}"
# outstr += f", Tnet={Jtorque2.J():.2e}"
outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist2.shortest_distance():.2f}"
outstr += f", Link Number = {linkNum.J()}"
outstr += f", ║∇J║={np.linalg.norm(grad):.1e}"
print(outstr)
print(valuestr)
pr.disable()
sio = io.StringIO()
sortby = SortKey.CUMULATIVE
ps = pstats.Stats(pr, stream=sio).sort_stats(sortby)
ps.print_stats(20)
print(sio.getvalue())
# for c in (coils + coils_TF):
# c._children = set()
# exit()
return J, grad


Expand Down Expand Up @@ -511,14 +518,14 @@ def fun(dofs):
Jlength.dJ()
t2 = time.time()
print('sum(Jls_TF) time = ', t2 - t1, ' s')
# t1 = time.time()
# Jforce.J()
# t2 = time.time()
# print('Jforces time = ', t2 - t1, ' s')
# t1 = time.time()
# Jforce.dJ()
# t2 = time.time()
# print('dJforces time = ', t2 - t1, ' s')
t1 = time.time()
Jforce.J()
t2 = time.time()
print('Jforces time = ', t2 - t1, ' s')
t1 = time.time()
Jforce.dJ()
t2 = time.time()
print('dJforces time = ', t2 - t1, ' s')
# t1 = time.time()
# Jforce2.J()
# t2 = time.time()
Expand Down Expand Up @@ -563,6 +570,8 @@ def fun(dofs):
NetForces=coil_net_forces(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF),
NetTorques=coil_net_torques(coils_TF, coils + coils_TF, regularization_rect(np.ones(len(coils_TF)) * a, np.ones(len(coils_TF)) * b), np.ones(len(coils_TF)) * nturns_TF),
)
# for c in (coils + coils_TF):
# c._children = set()

btot.set_points(s_plot.gamma().reshape((-1, 3)))
pointData = {"B_N": np.sum(btot.B().reshape((qphi, qtheta, 3)) * s_plot.unitnormal(), axis=2)[:, :, None]}
Expand Down
Loading

0 comments on commit 3f19636

Please sign in to comment.