Skip to content

Commit

Permalink
Made some fixes and finally have test_jaxbiotsavart.py looking all goo…
Browse files Browse the repository at this point in the history
…d. One fix was calling coil.current.vjp directly instead of building Derivative({coil.current: ...}) by hand, since the latter is an issue with the ScaledCurrent class.
  • Loading branch information
akaptano committed Sep 25, 2024
1 parent eb86d16 commit 6c0d0ee
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 75 deletions.
112 changes: 45 additions & 67 deletions src/simsopt/field/biotsavart.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,25 +242,22 @@ def __init__(self, coils):
self.B_jax = lambda curve_dofs, currents, p: self.B_pure(curve_dofs, currents, p)
# self.B_impl = jit(lambda p: self.B_pure(self.get_gammas(), self.get_currents(), p))
self.dB_by_dcurvedofs_jax = jit(jacfwd(self.B_jax, argnums=0))

# B_vjp returns v * dB/dcurvedofs
# self.dB_vjp_jax = jit(lambda x, y, z, v: vjp(self.B_jax, x, y, z)[1](v)[0])
self.dB_by_dcoilcurrents_jax = jit(jacfwd(self.B_jax, argnums=1))
self.dB_by_dX_jax = jacfwd(self.B_jax, argnums=2) #jit(jacfwd(self.B_jax, argnums=2))
self.d2B_by_dXdX_jax = jit(jacfwd(self.dB_by_dX_jax, argnums=2))
self.d2B_by_dXdcoilcurrents_jax = jit(jacfwd(self.dB_by_dX_jax, argnums=1))
self.d2B_by_dXdcurvedofs_jax = jit(jacfwd(self.dB_by_dX_jax, argnums=0))
self.d3B_by_dXdXdcoilcurrents_jax = jit(jacfwd(self.d2B_by_dXdX_jax, argnums=1))

# Seems like B_jax and A_jax should not use jit, since then
# it does not update correctly when the dofs change.
self.A_jax = lambda curve_dofs, currents, p: self.A_pure(curve_dofs, currents, p)
# self.A_impl = jit(lambda p: self.A_pure(self.get_curve_dofs(), self.get_currents(), p))
# self.A_vjp_jax = jit(lambda x, v: vjp(self.A_pure, x)[1](v)[0])
self.dA_by_dcurvedofs_jax = jit(jacfwd(self.A_jax, argnums=0))
self.dA_by_dcoilcurrents_jax = jit(jacfwd(self.A_jax, argnums=1))
self.dA_by_dX_jax = jacfwd(self.A_jax, argnums=2) #jit(jacfwd(self.A_jax, argnums=2))
self.d2A_by_dXdX_jax = jit(hessian(self.A_jax, argnums=2))
self.d2A_by_dXdX_jax = jit(jacfwd(self.dA_by_dX_jax, argnums=2))
self.d2A_by_dXdcoilcurrents_jax = jit(jacfwd(self.dA_by_dX_jax, argnums=1))
self.d2A_by_dXdcurvedofs_jax = jit(jacfwd(self.dA_by_dX_jax, argnums=0))
self.d3A_by_dXdXdcoilcurrents_jax = jit(jacfwd(self.d2A_by_dXdX_jax, argnums=1))

# @jit
Expand Down Expand Up @@ -299,8 +296,8 @@ def B_vjp(self, v):
res_current = [np.sum(v * dB_by_dcoilcurrents[i]) for i in range(len(dB_by_dcoilcurrents))]
dB_by_dcurvedofs = self.dB_by_dcurvedofs()
res_curvedofs = [np.sum(np.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))]
return sum([Derivative({coils[i].curve: res_curvedofs[i]}) + Derivative({coils[i].current: res_current[i]}) for i in range(len(coils))])

return sum([Derivative({coils[i].curve: res_curvedofs[i]}) + coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))])
def A(self):
    """Return ``self.A_jax`` evaluated at the object's current state.

    The three arguments are gathered, in order, from ``get_curve_dofs()``,
    ``get_currents()`` and ``get_points_cart_ref()``.
    """
    curve_dofs = self.get_curve_dofs()
    coil_currents = self.get_currents()
    eval_points = self.get_points_cart_ref()
    return self.A_jax(curve_dofs, coil_currents, eval_points)

Expand All @@ -310,9 +307,7 @@ def A_vjp(self, v):
res_current = [np.sum(v * dA_by_dcoilcurrents[i]) for i in range(len(dA_by_dcoilcurrents))]
dA_by_dcurvedofs = self.dA_by_dcurvedofs()
res_curvedofs = [np.sum(np.sum(v[None, :, :] * dA_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dA_by_dcurvedofs))]
print(jnp.shape(v), jnp.shape(dA_by_dcoilcurrents), jnp.shape(res_current), jnp.shape(dA_by_dcurvedofs), jnp.shape(res_curvedofs))
print(Derivative({coils[0].current: res_current[0]}), coils[0].current, res_current[0])
return sum([Derivative({coils[i].curve: res_curvedofs[i]}) + Derivative({coils[i].current: res_current[i]}) for i in range(len(coils))])
return sum([Derivative({coils[i].curve: res_curvedofs[i]}) + coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))])

def get_dofs(self):
ll = [self._coils[i].current.get_value() for i in range(len(self._coils))]
Expand All @@ -329,7 +324,6 @@ def get_currents(self):
return jnp.array([c.current.get_value() for c in self._coils])

def get_gammadashs(self, dofs):
    """Stack every coil's ``gammadash_impl_jax`` output into one array.

    Args:
        dofs: Indexable collection; ``dofs[i]`` is handed to the curve of
            the i-th entry of ``self._coils``.

    Returns:
        ``jnp.array`` of the per-coil results, in coil order. Each result is
        evaluated at that curve's own ``quadpoints``.
    """
    per_coil = []
    for idx, coil in enumerate(self._coils):
        curve = coil.curve
        per_coil.append(curve.gammadash_impl_jax(dofs[idx], curve.quadpoints))
    return jnp.array(per_coil)
Expand Down Expand Up @@ -362,6 +356,13 @@ def d2B_by_dXdcoilcurrents(self):
return jnp.transpose(jnp.diagonal(self.d2B_by_dXdcoilcurrents_jax(
self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref()
), axis1=0, axis2=2), axes=[2, 3, 1, 0])

def d2B_by_dXdcurvedofs(self):
    r"""
    Evaluate ``self.d2B_by_dXdcurvedofs_jax`` at the current state and
    reduce it to its same-index diagonal.

    The raw Jacobian is computed from ``get_curve_dofs()``,
    ``get_currents()`` and ``get_points_cart_ref()``. Axes 0 and 2 of that
    result are collapsed onto their shared diagonal (``jnp.diagonal`` places
    the diagonal on the last axis), and the remaining axes are then permuted
    with ``axes=[2, 3, 4, 1, 0]``.

    NOTE(review): axes 0 and 2 are presumably both evaluation-point indices,
    so the diagonal keeps only same-point derivatives -- confirm against the
    definition of ``d2B_by_dXdcurvedofs_jax``.
    """
    raw_jacobian = self.d2B_by_dXdcurvedofs_jax(
        self.get_curve_dofs(),
        self.get_currents(),
        self.get_points_cart_ref(),
    )
    diagonal_part = jnp.diagonal(raw_jacobian, axis1=0, axis2=2)
    return jnp.transpose(diagonal_part, axes=[2, 3, 4, 1, 0])

def d2B_by_dXdX(self):
return jnp.transpose(jnp.diagonal(
Expand All @@ -387,28 +388,18 @@ def d3B_by_dXdXdcoilcurrents(self):
def B_and_dB_vjp(self, v, vgrad):
r"""
"""

coils = self._coils
gammas = [coil.curve.gamma() for coil in coils]
gammadashs = [coil.curve.gammadash() for coil in coils]
currents = [coil.current.get_value() for coil in coils]
res_gamma = [np.zeros_like(gamma) for gamma in gammas]
res_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs]
res_grad_gamma = [np.zeros_like(gamma) for gamma in gammas]
res_grad_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs]

points = self.get_points_cart_ref()
sopp.biot_savart_vjp_graph(points, gammas, gammadashs, currents, v,
res_gamma, res_gammadash, vgrad, res_grad_gamma, res_grad_gammadash)

dB_by_dcoilcurrents = self.dB_by_dcoilcurrents()
res_current = [np.sum(v * dB_by_dcoilcurrents[i]) for i in range(len(dB_by_dcoilcurrents))]
d2B_by_dXdcoilcurrents = self.d2B_by_dXdcoilcurrents()
res_grad_current = [np.sum(vgrad * d2B_by_dXdcoilcurrents[i]) for i in range(len(d2B_by_dXdcoilcurrents))]

dB_by_dcurvedofs = self.dB_by_dcurvedofs()
res_curvedofs = [np.sum(np.sum(v[None, :, :] * dB_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dB_by_dcurvedofs))]
d2B_by_dXdcurvedofs = self.d2B_by_dXdcurvedofs()
res_grad_curvedofs = [np.sum(np.sum(np.sum(vgrad[None, :, :, :] * d2B_by_dXdcurvedofs[i], axis=-1), axis=-1), axis=-1) for i in range(len(d2B_by_dXdcurvedofs))]
res = (
sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]),
sum([coils[i].vjp(res_grad_gamma[i], res_grad_gammadash[i], np.asarray([res_grad_current[i]])) for i in range(len(coils))])
sum([Derivative({coils[i].curve: res_curvedofs[i]}) + coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))]),
sum([Derivative({coils[i].curve: res_grad_curvedofs[i]}) + coils[i].current.vjp(np.array([res_grad_current[i]])) for i in range(len(coils))]),
)
return res

Expand Down Expand Up @@ -441,6 +432,13 @@ def d2A_by_dXdcoilcurrents(self):
self.get_curve_dofs(), self.get_currents(), self.get_points_cart_ref()
), axis1=0, axis2=2), axes=[2, 3, 1, 0])

def d2A_by_dXdcurvedofs(self):
    r"""
    Evaluate ``self.d2A_by_dXdcurvedofs_jax`` at the current state and
    reduce it to its same-index diagonal.

    The raw Jacobian is computed from ``get_curve_dofs()``,
    ``get_currents()`` and ``get_points_cart_ref()``. Axes 0 and 2 of that
    result are collapsed onto their shared diagonal (``jnp.diagonal`` places
    the diagonal on the last axis), and the remaining axes are then permuted
    with ``axes=[2, 3, 4, 1, 0]``.

    NOTE(review): axes 0 and 2 are presumably both evaluation-point indices,
    so the diagonal keeps only same-point derivatives -- confirm against the
    definition of ``d2A_by_dXdcurvedofs_jax``.
    """
    raw_jacobian = self.d2A_by_dXdcurvedofs_jax(
        self.get_curve_dofs(),
        self.get_currents(),
        self.get_points_cart_ref(),
    )
    diagonal_part = jnp.diagonal(raw_jacobian, axis1=0, axis2=2)
    return jnp.transpose(diagonal_part, axes=[2, 3, 4, 1, 0])

def d2A_by_dXdX(self):
return jnp.transpose(jnp.diagonal(
jnp.diagonal(self.d2A_by_dXdX_jax(
Expand All @@ -464,36 +462,20 @@ def d3A_by_dXdXdcoilcurrents(self):

def A_and_dA_vjp(self, v, vgrad):
r"""
Same as :obj:`simsopt.geo.biotsavart.BiotSavart.A_vjp` but returns the vector Jacobian product for :math:`A` and :math:`\nabla A`, i.e. it returns
.. math::
\{ \sum_{i=1}^{n} \mathbf{v}_i \cdot \partial_{\mathbf{c}_k} \mathbf{A}_i \}_k, \{ \sum_{i=1}^{n} {\mathbf{v}_\mathrm{grad}}_i \cdot \partial_{\mathbf{c}_k} \nabla \mathbf{A}_i \}_k.
"""

coils = self._coils
gammas = [coil.curve.gamma() for coil in coils]
gammadashs = [coil.curve.gammadash() for coil in coils]
currents = [coil.current.get_value() for coil in coils]
res_gamma = [np.zeros_like(gamma) for gamma in gammas]
res_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs]
res_grad_gamma = [np.zeros_like(gamma) for gamma in gammas]
res_grad_gammadash = [np.zeros_like(gammadash) for gammadash in gammadashs]

points = self.get_points_cart_ref()
sopp.biot_savart_vector_potential_vjp_graph(points, gammas, gammadashs, currents, v,
res_gamma, res_gammadash, vgrad, res_grad_gamma, res_grad_gammadash)

dA_by_dcoilcurrents = self.dA_by_dcoilcurrents()
res_current = [np.sum(v * dA_by_dcoilcurrents[i]) for i in range(len(dA_by_dcoilcurrents))]
d2A_by_dXdcoilcurrents = self.d2A_by_dXdcoilcurrents()
res_grad_current = [np.sum(vgrad * d2A_by_dXdcoilcurrents[i]) for i in range(len(d2A_by_dXdcoilcurrents))]

dA_by_dcurvedofs = self.dA_by_dcurvedofs()
res_curvedofs = [np.sum(np.sum(v[None, :, :] * dA_by_dcurvedofs[i], axis=-1), axis=-1) for i in range(len(dA_by_dcurvedofs))]
d2A_by_dXdcurvedofs = self.d2A_by_dXdcurvedofs()
res_grad_curvedofs = [np.sum(np.sum(np.sum(vgrad[None, :, :, :] * d2A_by_dXdcurvedofs[i], axis=-1), axis=-1), axis=-1) for i in range(len(d2A_by_dXdcurvedofs))]
res = (
sum([coils[i].vjp(res_gamma[i], res_gammadash[i], np.asarray([res_current[i]])) for i in range(len(coils))]),
sum([coils[i].vjp(res_grad_gamma[i], res_grad_gammadash[i], np.asarray([res_grad_current[i]])) for i in range(len(coils))])
sum([Derivative({coils[i].curve: res_curvedofs[i]}) + coils[i].current.vjp(np.array([res_current[i]])) for i in range(len(coils))]),
sum([Derivative({coils[i].curve: res_grad_curvedofs[i]}) + coils[i].current.vjp(np.array([res_grad_current[i]])) for i in range(len(coils))]),
)

return res

def as_dict(self, serial_objs_dict) -> dict:
Expand All @@ -515,23 +497,19 @@ def from_dict(cls, d, serial_objs_dict, recon_objs):
return bs

# @jit
# def coil_coil_forces_pure(self, gammas, currents):
# """
# """
# currents_i = currents
# currents_j = currents
# gammas_i = gammas
# gammas_j = gammas
# gammadashs_i = self.get_gammadashs()
# gammadashs_j = gammadashs_i
# r_ij = gammas_i[None, :, :] - gammas_j[:, None, :] # Note, do not use the i = j indices
# jnp.diag(r_ij[:, :, 0]) = 1e100
# jnp.diag(r_ij[:, :, 1]) = 1e100
# jnp.diag(r_ij[:, :, 2]) = 1e100
# F = jnp.cross(currents_i * gammadashs_i,
# jnp.cross(currents_j * gammadashs_j, r_ij)
# ) / jnp.linalg.norm(r_ij, axis=-1)[:, :, None] ** 3
# return F * 1e-7 / jnp.shape(gammas_i)[1] ** 2
def coil_coil_forces_pure(self, curve_dofs, currents):
    """
    Pairwise coil-coil force expression built from ``jnp`` primitives.

    Args:
        curve_dofs: Forwarded unchanged to ``self.get_gammadashs``.
        currents: Multiplied directly against the ``get_gammadashs`` output.

    Returns:
        ``F * 1e-7 / jnp.shape(gammas)[1] ** 2``, where ``F`` is the nested
        cross product below divided by ``|r_ij| ** 3``.

    NOTE(review): ``gammas`` is filled from ``get_gammadashs`` -- the same
    call used for ``gammadashs`` -- so ``r_ij`` is a difference of those
    values rather than of a separate position array. This looks like a
    copy/paste slip; confirm which helper was intended.
    NOTE(review): nothing visible here establishes that
    ``currents * gammadashs`` broadcasts against ``r_ij`` -- TODO confirm
    shapes.
    """
    gammas = self.get_gammadashs(curve_dofs)
    gammadashs = self.get_gammadashs(curve_dofs)
    # Pairwise differences along two new leading axes; the i = j entries are
    # an array minus itself, i.e. exactly zero.
    r_ij = gammas[None, :, :] - gammas[:, None, :] # Note, do not use the i = j indices
    # The three disabled lines below were an attempt to mask the i = j
    # entries. As written they assign to a call expression, which is not
    # valid Python, so the zero entries currently reach the division below.
    # jnp.diag(r_ij[:, :, 0]) = 1e100
    # jnp.diag(r_ij[:, :, 1]) = 1e100
    # jnp.diag(r_ij[:, :, 2]) = 1e100
    F = jnp.cross(currents * gammadashs,
                  jnp.cross(currents * gammadashs, r_ij)
                  ) / jnp.linalg.norm(r_ij, axis=-1)[:, :, None] ** 3
    # NOTE(review): 1e-7 is presumably mu_0 / (4 * pi) in SI units, and the
    # division by shape[1] ** 2 presumably a quadrature weight -- confirm.
    return F * 1e-7 / jnp.shape(gammas)[1] ** 2

def coil_coil_inductances_pure(self, curve_dofs):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/simsopt/field/coil.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,21 @@ def current(self):
return self.get_value()

self.current_pure = current_pure
self.current_jax = jit(lambda dofs: self.current_pure(dofs))
self.current_jax = lambda dofs: self.current_pure(dofs)
self.dcurrent_by_dcurrent_jax = jit(jacfwd(self.current_jax))
self.dcurrent_by_dcurrent_vjp_jax = jit(lambda x, v: vjp(self.current_jax, x)[1](v)[0])

def current_impl(self, dofs):
    """Return the result of ``self.current_jax`` applied to ``dofs``."""
    current_value = self.current_jax(dofs)
    return current_value

def vjp(self, v):
    r"""
    Wrap the vector-Jacobian product of this current in a ``Derivative``.

    Args:
        v: Cotangent passed, together with ``self.get_dofs()``, to
            ``self.dcurrent_by_dcurrent_vjp_jax``.

    Returns:
        A ``Derivative`` built from a one-entry dict that maps ``self`` to
        the value returned by ``dcurrent_by_dcurrent_vjp_jax``.
    """
    pulled_back = self.dcurrent_by_dcurrent_vjp_jax(self.get_dofs(), v)
    return Derivative({self: pulled_back})

def set_dofs(self, dofs):
self.local_x = dofs
sopp.Current.set_dofs(self, dofs)
sopp.Current.set_dofs(self, dofs)

class CurrentSum(sopp.CurrentBase, CurrentBase):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/field/test_biotsavart.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ def test_dA_by_dcoilcoeff_reverse_taylortest(self):
Jh = np.sum(Ah**2)
deriv_est = (Jh-J0)/eps
err_new = np.linalg.norm(deriv_est-dJ_dh)
print(err_new, err)
assert err_new < 0.55 * err
err = err_new

Expand Down
11 changes: 6 additions & 5 deletions tests/field/test_jaxbiotsavart.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_dB_by_dcoilcoeff_reverse_taylortest(self):
B2 = bs2.B()
J0 = np.sum(B**2)
dJ = bs.B_vjp(B)(curve)
dJ2 = bs.B_vjp(B2)(curve)
dJ2 = bs2.B_vjp(B2)(curve)
assert np.allclose(B, B2)
assert np.allclose(dJ, dJ2)

Expand Down Expand Up @@ -366,7 +366,7 @@ def test_biotsavart_coil_current_taylortest(self):
def test_dA_by_dcoilcoeff_reverse_taylortest(self):
np.random.seed(1)
curve = get_curve()
coil = Coil(curve, Current(1e4))
coil = Coil(curve, ScaledCurrent(Current(1), 1e4))
bs = JaxBiotSavart([coil])
bs2 = BiotSavart([coil])
points = np.asarray(17 * [[-1.41513202e-03, 8.99999382e-01, -3.14473221e-04]])
Expand All @@ -379,17 +379,18 @@ def test_dA_by_dcoilcoeff_reverse_taylortest(self):
A2 = bs2.A()
J0 = np.sum(A**2)
dJ = bs.A_vjp(A)(coil)
dJ2 = bs.A_vjp(A2)(coil)
dJ2 = bs2.A_vjp(A2)(coil)
assert np.allclose(A, A2)
assert np.allclose(dJ, dJ2)

h = 1e-2 * np.random.rand(len(coil_dofs)).reshape(coil_dofs.shape)
dJ_dh = 2*np.sum(dJ * h)
dJ_dh = 2*np.sum(dJ2 * h)
err = 1e6
for i in range(2, 10):
eps = 0.5**i
coil.x = coil_dofs + eps * h
Ah = bs.A()
Ah = bs2.A()
Ah2 = bs2.A()
Jh = np.sum(Ah**2)
deriv_est = (Jh-J0)/eps
err_new = np.linalg.norm(deriv_est-dJ_dh)
Expand Down

0 comments on commit 6c0d0ee

Please sign in to comment.