diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index f6c917130..577ac200a 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -21,7 +21,15 @@ import numpy as np import torch import torch.distributed as dist -from functorch import dim as ftdim + +try: + from functorch import dim as ftdim + + _has_funcdim = True +except ImportError: + from tensordict.utils import _ftdim_mock as ftdim + + _has_funcdim = False from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list from tensordict.base import ( diff --git a/tensordict/_td.py b/tensordict/_td.py index b1fbbe966..5d742f728 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -20,7 +20,15 @@ import numpy as np import torch -from functorch import dim as ftdim + +try: + from functorch import dim as ftdim + + _has_funcdim = True +except ImportError: + from tensordict.utils import _ftdim_mock as ftdim + + _has_funcdim = False from tensordict.base import ( _ACCEPTED_CLASSES, @@ -1561,7 +1569,11 @@ def _get_names_idx(self, idx): else: def is_boolean(idx): - from functorch import dim as ftdim + try: + from functorch import dim as ftdim + + except ImportError: + from tensordict.utils import _ftdim_mock as ftdim if isinstance(idx, ftdim.Dim): return None diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 359f94bb0..7aa1f6594 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -14,7 +14,15 @@ from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type import torch -from functorch import dim as ftdim + +try: + from functorch import dim as ftdim + + _has_funcdim = True +except ImportError: + from tensordict.utils import _ftdim_mock as ftdim + + _has_funcdim = False from tensordict._lazy import _CustomOpTensorDict, LazyStackedTensorDict from tensordict._td import _SubTensorDict, TensorDict diff --git a/tensordict/utils.py b/tensordict/utils.py index 8f59654d6..bac125377 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -39,7 +39,13 @@ import numpy as np import torch -from functorch import dim as ftdim + +try: + from functorch import dim as ftdim + + _has_funcdim = True +except ImportError: + _has_funcdim = False from packaging.version import parse from tensordict._contextlib import _DecoratorContextManager from tensordict._tensordict import ( # noqa: F401 @@ -2207,3 +2213,16 @@ def is_non_tensor(data): def _is_non_tensor(cls: type): return getattr(cls, "_is_non_tensor", False) + + +if not _has_funcdim: + + class _ftdim_mock: + class Dim: + pass + + class Tensor: + pass + + def dims(self, *args, **kwargs): + raise ImportError("functorch.dim not found") diff --git a/test/test_tensordict.py b/test/test_tensordict.py index e0d453350..f906a17eb 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -29,7 +29,15 @@ prod, TestTensorDictsBase, ) -from functorch import dim as ftdim + +try: + from functorch import dim as ftdim + + _has_funcdim = True +except ImportError: + from tensordict.utils import _ftdim_mock as ftdim + + _has_funcdim = False from tensordict import LazyStackedTensorDict, make_tensordict, TensorDict from tensordict._lazy import _CustomOpTensorDict @@ -7869,6 +7877,7 @@ def _pool_fixt(): yield pool +@pytest.mark.skipif(not _has_funcdim, reason="functorch.dim could not be found") class TestFCD(TestTensorDictsBase): """Test stack for first-class dimension."""