diff --git a/src/simsopt/_core/derivative.py b/src/simsopt/_core/derivative.py index c3acfce6e..394677f92 100644 --- a/src/simsopt/_core/derivative.py +++ b/src/simsopt/_core/derivative.py @@ -176,20 +176,28 @@ def __call__(self, optim, as_derivative=False): """ from .optimizable import Optimizable # Import here to avoid circular import assert isinstance(optim, Optimizable) - derivs = [] - keys = [] - for k in optim.unique_dof_lineage: - if np.any(k.dofs_free_status): - local_derivs = np.zeros(k.local_dof_size) + + if not as_derivative: + derivs = [] + keys = [] + for k in optim.unique_dof_lineage: + if np.any(k.dofs_free_status): + local_derivs = np.zeros(k.local_dof_size) + for opt in k.dofs.dep_opts(): + local_derivs += self.data[opt][opt.local_dofs_free_status] + keys.append(opt) + derivs.append(local_derivs) + return np.concatenate(derivs) + else: + derivs = [] + keys = [] + for k in optim.unique_dof_lineage: + local_derivs = np.zeros(k.local_full_dof_size) for opt in k.dofs.dep_opts(): - local_derivs += self.data[opt][opt.local_dofs_free_status] + local_derivs += self.data[opt] keys.append(opt) derivs.append(local_derivs) - - if as_derivative: return Derivative({k: d for k, d in zip(keys, derivs)}) - else: - return np.concatenate(derivs) # https://stackoverflow.com/questions/11624955/avoiding-python-sum-default-start-arg-behavior def __radd__(self, other): diff --git a/tests/geo/test_surface_objectives.py b/tests/geo/test_surface_objectives.py index 0f7d82db6..990952638 100644 --- a/tests/geo/test_surface_objectives.py +++ b/tests/geo/test_surface_objectives.py @@ -318,14 +318,18 @@ def test_nonQSratio_derivative(self): """ for label in ["Volume", "ToroidalFlux"]: for axis in [False, True]: - with self.subTest(label=label, axis=axis): - self.subtest_nonQSratio_derivative(label, axis) + for fix_coil_dof in [True, False]: + with self.subTest(label=label, axis=axis, fix_coil_dof=fix_coil_dof): + self.subtest_nonQSratio_derivative(label, axis, fix_coil_dof) - def subtest_nonQSratio_derivative(self, label, axis): + def subtest_nonQSratio_derivative(self, label, axis, fix_coil_dof): bs, boozer_surface = get_boozer_surface(label=label) - coeffs = bs.x io = NonQuasiSymmetricRatio(boozer_surface, bs, quasi_poloidal=axis) + + if fix_coil_dof: + bs.coils[0].curve.fix('xc(0)') + coeffs = bs.x def f(dofs): bs.x = dofs return io.J()