diff --git a/skfem/experimental/autodiff/__init__.py b/skfem/experimental/autodiff/__init__.py index 4664f6b8..c1d525f0 100644 --- a/skfem/experimental/autodiff/__init__.py +++ b/skfem/experimental/autodiff/__init__.py @@ -4,6 +4,7 @@ import numpy as np from jax import jvp, linearize, config from jax.tree_util import register_pytree_node +import jax.numpy as jnp config.update("jax_enable_x64", True) @@ -51,6 +52,9 @@ def __rmul__(self, other): return self.value * other.value return self.value * other + def __pow__(self, ix): + return self.value ** ix + def __array__(self): return self.value @@ -108,8 +112,13 @@ def assemble(self, basis, x=None, **kwargs): nt = basis.nelems dx = basis.dx + + defaults = basis.default_parameters() w = FormExtraParams({ - **basis.default_parameters(), + **{ + k: JaxDiscreteField(*tuple(jnp.asarray(x) for x in defaults[k].astuple)) + for k in defaults + }, **self._normalize_asm_kwargs(kwargs, basis), })