From 9a771117376c0007182b2a76abfe3ee8a80c1664 Mon Sep 17 00:00:00 2001 From: Matt Graham Date: Thu, 26 Sep 2024 22:41:24 +0100 Subject: [PATCH] Set default backend dynamically based on availability --- src/mici/autodiff.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/mici/autodiff.py b/src/mici/autodiff.py index b3d53fb..4161ee3 100644 --- a/src/mici/autodiff.py +++ b/src/mici/autodiff.py @@ -48,18 +48,28 @@ class AutodiffBackend(NamedTuple): """Available autodifferentiation framework backends.""" _REGISTERED_BACKENDS = { - "autograd": AutodiffBackend(autograd_wrapper, autograd_wrapper.AUTOGRAD_AVAILABLE), "jax": AutodiffBackend( - jax_wrapper, jax_wrapper.JAX_AVAILABLE, jax_wrapper.jit_and_return_numpy_arrays, + jax_wrapper, + jax_wrapper.JAX_AVAILABLE, + jax_wrapper.jit_and_return_numpy_arrays, ), "jax_nojit": AutodiffBackend( - jax_wrapper, jax_wrapper.JAX_AVAILABLE, jax_wrapper.return_numpy_arrays, + jax_wrapper, + jax_wrapper.JAX_AVAILABLE, + jax_wrapper.return_numpy_arrays, ), + "autograd": AutodiffBackend(autograd_wrapper, autograd_wrapper.AUTOGRAD_AVAILABLE), "symnum": AutodiffBackend(symnum_wrapper, symnum_wrapper.SYMNUM_AVAILABLE), } -"""Name of default automatic differentiation backend to use.""" -DEFAULT_BACKEND = "jax" +"""Name of default automatic differentiation backend to use. + +Defaults to first available backend from `jax`, `autograd` and `symnum` (in that order) +or to `None` if none are available. +""" +DEFAULT_BACKEND = next( + (name for name, backend in _REGISTERED_BACKENDS.items() if backend.available), None, +) def _get_backend(name: str):