Skip to content

Commit

Permalink
Set default backend dynamically based on availability
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-graham committed Sep 26, 2024
1 parent 3cee530 commit 9a77111
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/mici/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,28 @@ class AutodiffBackend(NamedTuple):

"""Available autodifferentiation framework backends."""
_REGISTERED_BACKENDS = {
"autograd": AutodiffBackend(autograd_wrapper, autograd_wrapper.AUTOGRAD_AVAILABLE),
"jax": AutodiffBackend(
jax_wrapper, jax_wrapper.JAX_AVAILABLE, jax_wrapper.jit_and_return_numpy_arrays,
jax_wrapper,
jax_wrapper.JAX_AVAILABLE,
jax_wrapper.jit_and_return_numpy_arrays,
),
"jax_nojit": AutodiffBackend(
jax_wrapper, jax_wrapper.JAX_AVAILABLE, jax_wrapper.return_numpy_arrays,
jax_wrapper,
jax_wrapper.JAX_AVAILABLE,
jax_wrapper.return_numpy_arrays,
),
"autograd": AutodiffBackend(autograd_wrapper, autograd_wrapper.AUTOGRAD_AVAILABLE),
"symnum": AutodiffBackend(symnum_wrapper, symnum_wrapper.SYMNUM_AVAILABLE),
}

"""Name of default automatic differentiation backend to use."""
DEFAULT_BACKEND = "jax"
"""Name of default automatic differentiation backend to use.
Defaults to first available backend from `jax`, `autograd` and `symnum` (in that order)
or to `None` if none are available.
"""
DEFAULT_BACKEND = next(
(name for name, backend in _REGISTERED_BACKENDS.items() if backend.available), None,
)


def _get_backend(name: str):
Expand Down

0 comments on commit 9a77111

Please sign in to comment.