diff --git a/skfem/experimental/autodiff/__init__.py b/skfem/experimental/autodiff/__init__.py index c1d525f0..224449b1 100644 --- a/skfem/experimental/autodiff/__init__.py +++ b/skfem/experimental/autodiff/__init__.py @@ -114,9 +114,12 @@ def assemble(self, basis, x=None, **kwargs): dx = basis.dx defaults = basis.default_parameters() + # turn defaults into JaxDiscreteField to avoid np.ndarray + # to jnp.ndarray promotion issues w = FormExtraParams({ **{ - k: JaxDiscreteField(*tuple(jnp.asarray(x) for x in defaults[k].astuple)) + k: JaxDiscreteField(*tuple(jnp.asarray(x) + for x in defaults[k].astuple)) for k in defaults }, **self._normalize_asm_kwargs(kwargs, basis),