Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Giuliani committed Oct 6, 2023
1 parent 085177e commit e237c4d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 14 deletions.
28 changes: 18 additions & 10 deletions src/simsopt/_core/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,28 @@ def __call__(self, optim, as_derivative=False):
"""
from .optimizable import Optimizable # Import here to avoid circular import
assert isinstance(optim, Optimizable)
derivs = []
keys = []
for k in optim.unique_dof_lineage:
if np.any(k.dofs_free_status):
local_derivs = np.zeros(k.local_dof_size)

if not as_derivative:
derivs = []
keys = []
for k in optim.unique_dof_lineage:
if np.any(k.dofs_free_status):
local_derivs = np.zeros(k.local_dof_size)
for opt in k.dofs.dep_opts():
local_derivs += self.data[opt][opt.local_dofs_free_status]
keys.append(opt)
derivs.append(local_derivs)
return np.concatenate(derivs)
else:
derivs = []
keys = []
for k in optim.unique_dof_lineage:
local_derivs = np.zeros(k.local_full_dof_size)
for opt in k.dofs.dep_opts():
local_derivs += self.data[opt][opt.local_dofs_free_status]
local_derivs += self.data[opt]
keys.append(opt)
derivs.append(local_derivs)

if as_derivative:
return Derivative({k: d for k, d in zip(keys, derivs)})
else:
return np.concatenate(derivs)

# https://stackoverflow.com/questions/11624955/avoiding-python-sum-default-start-arg-behavior
def __radd__(self, other):
Expand Down
12 changes: 8 additions & 4 deletions tests/geo/test_surface_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,18 @@ def test_nonQSratio_derivative(self):
"""
for label in ["Volume", "ToroidalFlux"]:
for axis in [False, True]:
with self.subTest(label=label, axis=axis):
self.subtest_nonQSratio_derivative(label, axis)
for fix_coil_dof in [True, False]:
with self.subTest(label=label, axis=axis, fix_coil_dof=fix_coil_dof):
self.subtest_nonQSratio_derivative(label, axis, fix_coil_dof)

def subtest_nonQSratio_derivative(self, label, axis):
def subtest_nonQSratio_derivative(self, label, axis, fix_coil_dof):
bs, boozer_surface = get_boozer_surface(label=label)
coeffs = bs.x
io = NonQuasiSymmetricRatio(boozer_surface, bs, quasi_poloidal=axis)

if fix_coil_dof:
bs.coils[0].curve.fix('xc(0)')

coeffs = bs.x
def f(dofs):
bs.x = dofs
return io.J()
Expand Down

0 comments on commit e237c4d

Please sign in to comment.