Skip to content

Commit

Permalink
Tried speeding up some of the biotsavart stuff but no luck with jax s…
Browse files Browse the repository at this point in the history
…o far.
  • Loading branch information
akaptano committed Oct 3, 2024
1 parent c18c648 commit 2de04a3
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# filename = TEST_DIR / input_name

# Directory for output
OUT_DIR = "./dipole_array_optimization_QA_reactorScale_noForceTorqueOptimization/"
OUT_DIR = "./dipole_array_optimization_QA_reactorScale_debug/"
if os.path.exists(OUT_DIR):
shutil.rmtree(OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)
Expand Down Expand Up @@ -189,9 +189,9 @@ def initialize_coils_QA(TEST_DIR, s):
close=True,
NetForces=bs.coil_coil_forces(),
NetTorques=bs.coil_coil_torques(),
MixedCoilForces=CoilCoilNetForces12(bs, bs_TF).coil_coil_forces12()[:len(curves), :],
MixedCoilTorques=CoilCoilNetTorques12(bs, bs_TF).coil_coil_torques12()[:len(curves), :],
NetSelfForces=bs.coil_self_forces(a, b)
# MixedCoilForces=CoilCoilNetForces12(bs, bs_TF).coil_coil_forces12()[:len(curves), :],
# MixedCoilTorques=CoilCoilNetTorques12(bs, bs_TF).coil_coil_torques12()[:len(curves), :],
# NetSelfForces=bs.coil_self_forces(a, b)
)
pointData = {"B_N": np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)[:, :, None]}
s.to_vtk(OUT_DIR + "surf_init_DA", extra_data=pointData)
Expand Down Expand Up @@ -253,8 +253,8 @@ def initialize_coils_QA(TEST_DIR, s):
# interlink.
linkNum = LinkingNumber(curves_TF)
linkNum2 = LinkingNumber(curves)
# Jforces = CoilCoilNetForces(bs) + CoilCoilNetForces12(bs, bs_TF) + CoilCoilNetForces(bs_TF)
# Jtorques = CoilCoilNetTorques(bs) + CoilCoilNetTorques12(bs, bs_TF) + CoilCoilNetTorques(bs_TF)
Jforces = CoilCoilNetForces(bs) + CoilCoilNetForces12(bs, bs_TF) + CoilCoilNetForces(bs_TF)
Jtorques = CoilCoilNetTorques(bs) + CoilCoilNetTorques12(bs, bs_TF) + CoilCoilNetTorques(bs_TF)
# Jtve = TotalVacuumEnergy(bs, a=a, b=b)
# Jsf = CoilSelfNetForces(bs, a=a, b=b)

Expand Down Expand Up @@ -288,9 +288,9 @@ def initialize_coils_QA(TEST_DIR, s):
+ CS_WEIGHT * Jcsdist \
+ LINK_WEIGHT * linkNum \
+ LINK_WEIGHT2 * linkNum2 \
+ LENGTH_WEIGHT * sum(Jls_TF) # \
# + FORCES_WEIGHT * Jforces \
# + TORQUES_WEIGHT * Jtorques
+ LENGTH_WEIGHT * sum(Jls_TF) \
+ FORCES_WEIGHT * Jforces \
+ TORQUES_WEIGHT * Jtorques
# + TVE_WEIGHT * Jtve
# + SF_WEIGHT * Jsf
# + CURRENTS_WEIGHT * DipoleJaxCurrentsObj
Expand All @@ -299,6 +299,51 @@ def initialize_coils_QA(TEST_DIR, s):
# + MSC_WEIGHT * sum(QuadraticPenalty(J, MSC_THRESHOLD) for J in Jmscs) \
# + CURVATURE_WEIGHT * sum(Jcs) \

print('Timing calls: ')
t1 = time.time()
Jf.J()
t2 = time.time()
print('Jf time = ', t2 - t1, ' s')
t1 = time.time()
Jf.dJ()
t2 = time.time()
print('dJf time = ', t2 - t1, ' s')
t1 = time.time()
Jccdist.J()
Jccdist.dJ()
t2 = time.time()
print('Jcc time = ', t2 - t1, ' s')
t1 = time.time()
Jcsdist.J()
Jcsdist.dJ()
t2 = time.time()
print('Jcs time = ', t2 - t1, ' s')
t1 = time.time()
linkNum.J()
linkNum.dJ()
t2 = time.time()
print('linkNum time = ', t2 - t1, ' s')
t1 = time.time()
linkNum2.J()
linkNum2.dJ()
t2 = time.time()
print('linkNum2 time = ', t2 - t1, ' s')
t1 = time.time()
sum(Jls_TF).J()
sum(Jls_TF).dJ()
t2 = time.time()
print('sum(Jls_TF) time = ', t2 - t1, ' s')
t1 = time.time()
Jforces.J()
Jforces.dJ()
t2 = time.time()
print('Jforces time = ', t2 - t1, ' s')
t1 = time.time()
Jtorques.J()
Jtorques.dJ()
t2 = time.time()
print('Jtorques time = ', t2 - t1, ' s')

# We don't have a general interface in SIMSOPT for optimisation problems that
# are not in least-squares form, so we write a little wrapper function that we
# pass directly to scipy.optimize.minimize
Expand All @@ -314,8 +359,8 @@ def fun(dofs):
cs_val = CS_WEIGHT * Jcsdist.J()
link_val1 = LINK_WEIGHT * linkNum.J()
link_val2 = LINK_WEIGHT2 * linkNum2.J()
# forces_val = FORCES_WEIGHT * Jforces.J()
# torques_val = TORQUES_WEIGHT * Jtorques.J()
forces_val = FORCES_WEIGHT * Jforces.J()
torques_val = TORQUES_WEIGHT * Jtorques.J()
# tve_val = TVE_WEIGHT * Jtve.J()
# sf_val = SF_WEIGHT * Jsf.J()
BdotN = np.mean(np.abs(np.sum(btot.B().reshape((nphi, ntheta, 3)) * s.unitnormal(), axis=2)))
Expand All @@ -330,8 +375,8 @@ def fun(dofs):
valuestr += f", csObj={cs_val:.2e}"
valuestr += f", Lk1Obj={link_val1:.2e}"
valuestr += f", Lk2Obj={link_val2:.2e}"
# valuestr += f", forceObj={forces_val:.2e}"
# valuestr += f", torqueObj={torques_val:.2e}"
valuestr += f", forceObj={forces_val:.2e}"
valuestr += f", torqueObj={torques_val:.2e}"
# valuestr += f", tveObj={tve_val:.2e}"
# valuestr += f", sfObj={sf_val:.2e}"
# valuestr += f", currObj={curr_val:.2e}"
Expand All @@ -343,8 +388,8 @@ def fun(dofs):
outstr += f", C-C-Sep={Jccdist.shortest_distance():.2f}, C-S-Sep={Jcsdist.shortest_distance():.2f}"
outstr += f", Link Number = {linkNum.J()}"
outstr += f", Link Number 2 = {linkNum2.J()}"
# outstr += f", C-C-Forces={Jforces.J():.1e}"
# outstr += f", C-C-Torques={Jtorques.J():.1e}"
outstr += f", C-C-Forces={Jforces.J():.1e}"
outstr += f", C-C-Torques={Jtorques.J():.1e}"
# outstr += f", TVE={Jtve.J():.1e}"
# outstr += f", TotalSelfForces={Jsf.J():.1e}"
outstr += f", ║∇J║={np.linalg.norm(grad):.1e}"
Expand All @@ -369,6 +414,52 @@ def fun(dofs):
J2, _ = f(dofs - eps*h)
print("err", (J1-J2)/(2*eps) - dJh) #(J1-J2)/(2*eps), dJh, (J1-J2)/(2*eps) - dJh)

print('Timing calls, Round Two: ')
t1 = time.time()
Jf.J()
t2 = time.time()
print('Jf time = ', t2 - t1, ' s')
t1 = time.time()
Jf.dJ()
t2 = time.time()
print('dJf time = ', t2 - t1, ' s')
t1 = time.time()
Jccdist.J()
Jccdist.dJ()
t2 = time.time()
print('Jcc time = ', t2 - t1, ' s')
t1 = time.time()
Jcsdist.J()
Jcsdist.dJ()
t2 = time.time()
print('Jcs time = ', t2 - t1, ' s')
t1 = time.time()
linkNum.J()
linkNum.dJ()
t2 = time.time()
print('linkNum time = ', t2 - t1, ' s')
t1 = time.time()
linkNum2.J()
linkNum2.dJ()
t2 = time.time()
print('linkNum2 time = ', t2 - t1, ' s')
t1 = time.time()
sum(Jls_TF).J()
sum(Jls_TF).dJ()
t2 = time.time()
print('sum(Jls_TF) time = ', t2 - t1, ' s')
t1 = time.time()
Jforces.J()
Jforces.dJ()
t2 = time.time()
print('Jforces time = ', t2 - t1, ' s')
t1 = time.time()
Jtorques.J()
Jtorques.dJ()
t2 = time.time()
print('Jtorques time = ', t2 - t1, ' s')
exit()

print("""
################################################################################
### Run the optimisation #######################################################
Expand All @@ -389,18 +480,18 @@ def fun(dofs):
I=dipole_currents,
NetForces=np.array(bs.coil_coil_forces()),
NetTorques=bs.coil_coil_torques(),
MixedCoilForces=CoilCoilNetForces12(bs, bs_TF).coil_coil_forces12()[:len(curves), :],
MixedCoilTorques=CoilCoilNetTorques12(bs, bs_TF).coil_coil_torques12()[:len(curves), :],
NetSelfForces=bs.coil_self_forces(a, b)
# MixedCoilForces=CoilCoilNetForces12(bs, bs_TF).coil_coil_forces12()[:len(curves), :],
# MixedCoilTorques=CoilCoilNetTorques12(bs, bs_TF).coil_coil_torques12()[:len(curves), :],
# NetSelfForces=bs.coil_self_forces(a, b)
)
curves_to_vtk([c.curve for c in bs_TF.coils], OUT_DIR + "curves_TF_{0:d}".format(i),
close=True,
I=[c.current.get_value() for c in bs_TF.coils],
NetForces=np.array(bs_TF.coil_coil_forces()),
NetTorques=bs_TF.coil_coil_torques(),
MixedCoilForces=CoilCoilNetForces12(bs, bs_TF).coil_coil_forces12()[len(curves):, :],
MixedCoilTorques=CoilCoilNetTorques12(bs, bs_TF).coil_coil_torques12()[len(curves):, :],
NetSelfForces=bs_TF.coil_self_forces(a, b)
# MixedCoilForces=CoilCoilNetForces12(bs, bs_TF).coil_coil_forces12()[len(curves):, :],
# MixedCoilTorques=CoilCoilNetTorques12(bs, bs_TF).coil_coil_torques12()[len(curves):, :],
# NetSelfForces=bs_TF.coil_self_forces(a, b)
)

btot.set_points(s_plot.gamma().reshape((-1, 3)))
Expand Down
121 changes: 105 additions & 16 deletions src/simsopt/field/biotsavart.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,13 @@ def __init__(self, coils):
MagneticField.__init__(self, depends_on=coils)

self.B_jax = jit(lambda curve_dofs, currents, p: self.B_pure(curve_dofs, currents, p))
self.B_jax_reduced = jit(lambda curve_dofs: self.B_pure_reduced(curve_dofs))
# self.get_dofs_jax = jit(self.get_dofs)
# self.B_impl = jit(lambda p: self.B_pure(self.get_gammas(), self.get_currents(), p))
self.dB_by_dcurvedofs_jax = jit(jacfwd(self.B_jax, argnums=0))
self.dB_by_dcoilcurrents_jax = jit(jacfwd(self.B_jax, argnums=1))
self.dB_by_dX_jax = jit(jacfwd(self.B_jax, argnums=2)) #jit(jacfwd(self.B_jax, argnums=2))
self.dB_by_dcurvedofs_vjp_jax = jit(lambda x, v: vjp(self.B_jax_reduced, x)[1](v)[0])
self.d2B_by_dXdX_jax = jit(jacfwd(self.dB_by_dX_jax, argnums=2))
self.d2B_by_dXdcoilcurrents_jax = jit(jacfwd(self.dB_by_dX_jax, argnums=1))
self.d2B_by_dXdcurvedofs_jax = jit(jacfwd(self.dB_by_dX_jax, argnums=0))
Expand Down Expand Up @@ -280,39 +283,118 @@ def B_pure(self, curve_dofs, currents, points):
"""
# First two axes, over the total number of coils and the integral over the coil
# # points, should be summed
if currents is None:
currents = self.get_currents()
if points is None:
points = self.get_points_cart_ref()
gammas = self.get_gammas(curve_dofs)
p_minus_g = points[None, None, :, :] - gammas[:, :, None, :]
denom = 1.0 / jnp.linalg.norm(p_minus_g
, axis=-1)[:, :, :, None] ** 3
B = jnp.sum(currents[:, None, None] * jnp.sum(jnp.cross(
-p_minus_g,
self.get_gammadashs(curve_dofs)[:, :, None, :]
) * denom, axis=1), axis=0)
self.get_gammadashs(curve_dofs)[:, :, None, :],
p_minus_g
) / jnp.linalg.norm(p_minus_g
, axis=-1)[:, :, :, None] ** 3, axis=1), axis=0)
return B * 1e-7 / jnp.shape(gammas)[1] # Divide by number of quadpoints

def B_pure_reduced(self, curve_dofs):
"""
Assumes that the quadpoints are uniformly space on the curve!
"""
# First two axes, over the total number of coils and the integral over the coil
# # points, should be summed
points = self.get_points_cart_ref()
currents = self.get_currents()
gammas = self.get_gammas(curve_dofs)
p_minus_g = points[None, None, :, :] - gammas[:, :, None, :]
B = jnp.sum(currents[:, None, None] * jnp.sum(jnp.cross(
self.get_gammadashs(curve_dofs)[:, :, None, :],
p_minus_g
) / jnp.linalg.norm(p_minus_g
, axis=-1)[:, :, :, None] ** 3, axis=1), axis=0)
return B * 1e-7 / jnp.shape(gammas)[1] # Divide by number of quadpoints

# @jit
def A_pure(self, curve_dofs, currents, points):
"""
"""
gammas = self.get_gammas(curve_dofs)
gammadashs = self.get_gammadashs(curve_dofs)
A = jnp.sum(currents[:, None, None] * jnp.sum(gammadashs[:, :, None, :] / jnp.linalg.norm(
gammas[:, :, None, :] - points[None, None, :, :], axis=-1)[:, :, :, None], axis=1), axis=0)
return A * 1e-7 / jnp.shape(gammas)[1] # Divide by number of quadpoints
return jnp.sum(currents[:, None, None] * jnp.sum(
self.get_gammadashs(curve_dofs)[:, :, None, :] / jnp.linalg.norm(
gammas[:, :, None, :] - points[None, None, :, :], axis=-1
)[:, :, :, None], axis=1), axis=0) * 1e-7 / jnp.shape(gammas)[1] # Divide by number of quadpoints

# @jit
def B(self):
return np.array(self.B_jax(self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref()))
return self.B_jax(self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref())
#np.array(self.B_jax(self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref()))

def B_vjp(self, v):
import time
t1 = time.time()
coils = self._coils
# dB_by_dcoilcurrents = self.dB_by_dcoilcurrents()
# print(jnp.shape(v), jnp.shape(dB_by_dcoilcurrents), jnp.shape(v * dB_by_dcoilcurrents[0]))
res_current = jnp.sum(jnp.sum(v[None, :, :] * self.dB_by_dcoilcurrents(), axis=-1), axis=-1)
# print(jnp.shape(res_current))
t2 = time.time()
print('Current dJ time = ', t2 - t1)
t1 = time.time()
# dB_by_dcurvedofs = self.dB_by_dcurvedofs()
# t2 = time.time()
# print('Curve dJ time = ', t2 - t1)
# print(jnp.shape(dB_by_dcurvedofs))
# t1 = time.time()
# print(jnp.shape(v), jnp.shape(self.get_curve_dofs()))
# print(jnp.shape(vjp(self.B_pure_reduced, self.get_curve_dofs())[1]))
# print(vjp(self.B_pure_reduced, self.get_curve_dofs())[1](v),
# jnp.shape(vjp(self.B_jax_reduced, self.get_curve_dofs())[1](v)))
# res_curvedofs = self.dB_by_dcurvedofs_vjp_impl(v)
# print('res_curvedofs size = ', jnp.shape(res_curvedofs))
print(jnp.shape(v), jnp.shape(self.dB_by_dcurvedofs()))
res_curvedofs = jnp.sum(jnp.sum(jnp.sum(v[None, None, :, :] * self.dB_by_dcurvedofs(), axis=-1), axis=-1), axis=-1)
print(jnp.shape(res_curvedofs), len(res_curvedofs))
curve_derivs = [Derivative({coils[i].curve: res_curvedofs[i]}) for i in range(len(res_curvedofs))]
current_derivs = [coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))]
# t2 = time.time()
# print('dJ construction time = ', t2 - t1)
return sum(curve_derivs + current_derivs)

def B_vjp_jax(self, v):
import time
t1 = time.time()
coils = self._coils
dB_by_dcoilcurrents = self.dB_by_dcoilcurrents()
res_current = [np.sum(v * dB_by_dcoilcurrents[i]) for i in range(len(dB_by_dcoilcurrents))]
dB_by_dcurvedofs = self.dB_by_dcurvedofs()
res_curvedofs = [np.sum(np.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))]
t2 = time.time()
print('Current dJ time = ', t2 - t1)
t1 = time.time()
# dB_by_dcurvedofs = self.dB_by_dcurvedofs()
# t2 = time.time()
# print('Curve dJ time = ', t2 - t1)
# print(jnp.shape(dB_by_dcurvedofs))
# t1 = time.time()
print(jnp.shape(v), jnp.shape(self.get_curve_dofs()))
# print(jnp.shape(vjp(self.B_pure_reduced, self.get_curve_dofs())[1]))
# print(vjp(self.B_pure_reduced, self.get_curve_dofs())[1](v),
# jnp.shape(vjp(self.B_jax_reduced, self.get_curve_dofs())[1](v)))
res_curvedofs = self.dB_by_dcurvedofs_vjp_impl(v)
print('res_curvedofs size = ', jnp.shape(res_curvedofs))
#res_curvedofs = [np.sum(np.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))]
curve_derivs = [Derivative({coils[i].curve: res_curvedofs[i]}) for i in range(len(res_curvedofs))]
current_derivs = [coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))]
t2 = time.time()
print('dJ construction time = ', t2 - t1)
return sum(curve_derivs + current_derivs)


def dB_by_dcurvedofs_vjp_impl(self, v):
r"""
"""
return self.dB_by_dcurvedofs_vjp_jax(
self.get_curve_dofs(), v
)

# @jit
def A(self):
return self.A_jax(self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref())

Expand All @@ -326,10 +408,10 @@ def A_vjp(self, v):
current_derivs = [coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))]
return sum(curve_derivs + current_derivs)

def get_dofs(self):
ll = [self._coils[i].current.get_value() for i in range(len(self._coils))]
lc = [self._coils[i].curve.get_dofs() for i in range(len(self._coils))]
return (ll + lc)
# def get_dofs(self):
# ll = [self._coils[i].current.get_value() for i in range(len(self._coils))]
# lc = [self._coils[i].curve.get_dofs() for i in range(len(self._coils))]
# return (ll + lc)

def get_curve_dofs(self):
# get the dofs of the UNIQUE coils (the rest of the coils don't have dofs because
Expand Down Expand Up @@ -394,7 +476,14 @@ def dB_by_dcoilcurrents(self):

def dB_by_dcurvedofs(self):
r"""
Returns matrix of shape (ncoils, 3, 3, npoints)
"""
# import time
# t1 = time.time()
# self.dB_by_dcurvedofs_jax(
# self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref())
# t2 = time.time()
# print('Time just to call the function = ', t2 - t1)
return jnp.transpose(self.dB_by_dcurvedofs_jax(
self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref()),
axes=[2, 3, 0, 1])
Expand Down
2 changes: 1 addition & 1 deletion src/simsopt/field/coil.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def current(self):
return self.get_value()

self.current_pure = current_pure
self.current_jax = lambda dofs: self.current_pure(dofs)
self.current_jax = jit(lambda dofs: self.current_pure(dofs))
self.dcurrent_by_dcurrent_jax = jit(jacfwd(self.current_jax))
self.dcurrent_by_dcurrent_vjp_jax = jit(lambda x, v: vjp(self.current_jax, x)[1](v)[0])

Expand Down
Loading

0 comments on commit 2de04a3

Please sign in to comment.