Skip to content

Commit

Permalink
Added updated examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
akaptano committed Oct 23, 2024
1 parent 97d9acb commit 9801fe5
Show file tree
Hide file tree
Showing 5 changed files with 1,127 additions and 62 deletions.
71 changes: 43 additions & 28 deletions examples/3_Advanced/QA_reactorScale_dipoleArrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ def initialize_coils_QA(TEST_DIR, s):
jax_flag=True,
)

# base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)]
base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils)]
base_currents[0].fix_all()
base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils - 1)]
# base_currents = [(Current(total_current / ncoils * 1e-7) * 1e7) for _ in range(ncoils)]
# base_currents[0].fix_all()

# total_current = Current(total_current)
# total_current.fix_all()
# base_currents += [total_current - sum(base_currents)]
total_current = Current(total_current)
total_current.fix_all()
base_currents += [total_current - sum(base_currents)]
coils = coils_via_symmetries(base_curves, base_currents, s.nfp, True)
# for c in coils:
# c.current.fix_all()
Expand All @@ -121,7 +121,8 @@ def initialize_coils_QA(TEST_DIR, s):

# initialize the coils
base_curves_TF, curves_TF, coils_TF, currents_TF = initialize_coils_QA(TEST_DIR, s)
base_coils_TF = [coils_TF[0]]
num_TF_unique_coils = len(coils_TF) // 4
base_coils_TF = coils_TF[:num_TF_unique_coils]
currents_TF = np.array([coil.current.get_value() for coil in coils_TF])

# # Set up BiotSavart fields
Expand All @@ -134,12 +135,12 @@ def initialize_coils_QA(TEST_DIR, s):
# Only need this if make self forces and TVE nonzero in the objective!
a = 0.2
b = 0.2
nturns = 40
nturns = 100
nturns_TF = 200

# wire cross section for the dipole coils should be more like 5 cm x 5 cm
aa = 0.04
bb = 0.04
aa = 0.05
bb = 0.05

Nx = 6
Ny = Nx
Expand Down Expand Up @@ -242,16 +243,18 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
CS_THRESHOLD = 1.5
CS_WEIGHT = 1e2
# Weight for the Coil Coil forces term
FORCE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
FORCE_WEIGHT2 = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
TORQUE_WEIGHT = Weight(0.0) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
FORCE_WEIGHT = Weight(1e-22) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
FORCE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
TORQUE_WEIGHT = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
TORQUE_WEIGHT2 = Weight(1e-24) # Forces are in Newtons, and typical values are ~10^5, 10^6 Newtons
# Directory for output
OUT_DIR = ("./QA_n{:d}_p{:.2e}_c{:.2e}_lw{:.2e}_lt{:.2e}_lkw{:.2e}" + \
"_cct{:.2e}_ccw{:.2e}_cst{:.2e}_csw{:.2e}_fw{:.2e}_fww{:2e}_tw{:.2e}/").format(
"_cct{:.2e}_ccw{:.2e}_cst{:.2e}_csw{:.2e}_fw{:.2e}_fww{:2e}_tw{:.2e}_tww{:2e}/").format(
ncoils, poff, coff, LENGTH_WEIGHT.value, LENGTH_TARGET, LINK_WEIGHT,
CC_THRESHOLD, CC_WEIGHT, CS_THRESHOLD, CS_WEIGHT, FORCE_WEIGHT.value,
FORCE_WEIGHT2.value,
TORQUE_WEIGHT.value)
TORQUE_WEIGHT.value,
TORQUE_WEIGHT2.value)
if os.path.exists(OUT_DIR):
shutil.rmtree(OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)
Expand Down Expand Up @@ -323,7 +326,10 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
# Jforce = MixedLpCurveForce(coils, coils_TF, regularization_list, regularization_list2) # [SquaredMeanForce2(c, coils) for c in (base_coils)]
# Jforce = MixedSquaredMeanForce(coils, coils_TF)
Jforce = sum([LpCurveForce(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 40) for i, c in enumerate(base_coils + base_coils_TF)])
Jforce2 = sum([SquaredMeanForce(c, coils + coils_TF) for c in (base_coils + base_coils_TF)])
Jtorque = sum([LpCurveTorque(c, coils + coils_TF, regularization_rect(a_list[i], b_list[i]), p=2, threshold=1e5 * 40) for i, c in enumerate(base_coils + base_coils_TF)])
Jtorque2 = sum([SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)])

# Jtorque = SquaredMeanTorque2(coils, coils_TF) # [SquaredMeanForce2(c, coils) for c in (base_coils)]
# Jtorque = [SquaredMeanTorque(c, coils + coils_TF) for c in (base_coils + base_coils_TF)]

Expand All @@ -335,10 +341,15 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):

if FORCE_WEIGHT.value > 0.0:
JF += FORCE_WEIGHT.value * Jforce #\
# + FORCE_WEIGHT2.value * Jforce2

if FORCE_WEIGHT2.value > 0.0:
JF += FORCE_WEIGHT2.value * Jforce2 #\

if TORQUE_WEIGHT.value > 0.0:
JF += TORQUE_WEIGHT * Jtorque

if TORQUE_WEIGHT2.value > 0.0:
JF += TORQUE_WEIGHT2 * Jtorque2
# + TVE_WEIGHT * Jtve
# + SF_WEIGHT * Jsf
# + CURRENTS_WEIGHT * DipoleCurrentsObj
Expand Down Expand Up @@ -376,8 +387,9 @@ def fun(dofs):
cs_val = CS_WEIGHT * Jcsdist.J()
link_val = LINK_WEIGHT * linkNum.J()
forces_val = FORCE_WEIGHT.value * Jforce.J()
# forces_val2 = FORCE_WEIGHT2.value * Jforce2.J()
forces_val2 = FORCE_WEIGHT2.value * Jforce2.J()
torques_val = TORQUE_WEIGHT.value * Jtorque.J()
torques_val2 = TORQUE_WEIGHT2.value * Jtorque2.J()
BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)))
BdotN_over_B = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2))
) / np.mean(btot.AbsB())
Expand All @@ -390,11 +402,13 @@ def fun(dofs):
valuestr += f", csObj={cs_val:.2e}"
valuestr += f", Lk1Obj={link_val:.2e}"
valuestr += f", forceObj={forces_val:.2e}"
# valuestr += f", forceObj2={forces_val2:.2e}"
valuestr += f", forceObj2={forces_val2:.2e}"
valuestr += f", torqueObj={torques_val:.2e}"
valuestr += f", torqueObj2={torques_val2:.2e}"
outstr += f", F={Jforce.J():.2e}"
# outstr += f", Fpointwise={Jforce2.J():.2e}"
outstr += f", Fnet={Jforce2.J():.2e}"
outstr += f", T={Jtorque.J():.2e}"
outstr += f", Tnet={Jtorque2.J():.2e}"
outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}"
outstr += f", Link Number = {linkNum.J()}"
outstr += f", ║∇J║={np.linalg.norm(grad):.1e}"
Expand Down Expand Up @@ -465,14 +479,14 @@ def fun(dofs):
Jforce.dJ()
t2 = time.time()
print('dJforces time = ', t2 - t1, ' s')
# t1 = time.time()
# Jforce2.J()
# t2 = time.time()
# print('Jforces2 time = ', t2 - t1, ' s')
# t1 = time.time()
# Jforce2.dJ()
# t2 = time.time()
# print('dJforces2 time = ', t2 - t1, ' s')
t1 = time.time()
Jforce2.J()
t2 = time.time()
print('Jforces2 time = ', t2 - t1, ' s')
t1 = time.time()
Jforce2.dJ()
t2 = time.time()
print('dJforces2 time = ', t2 - t1, ' s')
t1 = time.time()
Jtorque.J()
t2 = time.time()
Expand All @@ -487,7 +501,7 @@ def fun(dofs):
for i in range(1, n_saves + 1):
print('Iteration ' + str(i) + ' / ' + str(n_saves))
res = minimize(fun, dofs, jac=True, method='L-BFGS-B',
options={'maxiter': MAXITER, 'maxcor': 1000}, tol=1e-15)
options={'maxiter': MAXITER, 'maxcor': 300}, tol=1e-15)
# dofs = res.x

dipole_currents = [c.current.get_value() for c in bs.coils]
Expand Down Expand Up @@ -534,4 +548,5 @@ def fun(dofs):
t2 = time.time()
print('Total time = ', t2 - t1)
btot.save("biot_savart_optimized_QA.json")
print(OUT_DIR)

Loading

0 comments on commit 9801fe5

Please sign in to comment.