From d09626fc2d8a4000a011e62a7d5fba5891b7ec6d Mon Sep 17 00:00:00 2001
From: Vincent Moens <vmoens@meta.com>
Date: Sat, 20 Apr 2024 12:04:44 +0100
Subject: [PATCH] [BugFix] Make functorch.dim optional (#737)

---
 tensordict/_lazy.py     | 10 +++++++++-
 tensordict/_td.py       | 16 ++++++++++++++--
 tensordict/nn/params.py | 10 +++++++++-
 tensordict/utils.py     | 21 ++++++++++++++++++++-
 test/test_tensordict.py | 11 ++++++++++-
 5 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py
index f6c917130..577ac200a 100644
--- a/tensordict/_lazy.py
+++ b/tensordict/_lazy.py
@@ -21,7 +21,15 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
 from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
 from tensordict.base import (
diff --git a/tensordict/_td.py b/tensordict/_td.py
index b1fbbe966..5d742f728 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -20,7 +20,15 @@
 
 import numpy as np
 import torch
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 
 from tensordict.base import (
     _ACCEPTED_CLASSES,
@@ -1561,7 +1569,11 @@ def _get_names_idx(self, idx):
         else:
 
             def is_boolean(idx):
-                from functorch import dim as ftdim
+                try:
+                    from functorch import dim as ftdim
+
+                except ImportError:
+                    from tensordict.utils import _ftdim_mock as ftdim
 
                 if isinstance(idx, ftdim.Dim):
                     return None
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 359f94bb0..7aa1f6594 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -14,7 +14,15 @@
 from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type
 
 import torch
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 
 from tensordict._lazy import _CustomOpTensorDict, LazyStackedTensorDict
 from tensordict._td import _SubTensorDict, TensorDict
diff --git a/tensordict/utils.py b/tensordict/utils.py
index 8f59654d6..bac125377 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -39,7 +39,13 @@
 
 import numpy as np
 import torch
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    _has_funcdim = False
 from packaging.version import parse
 from tensordict._contextlib import _DecoratorContextManager
 from tensordict._tensordict import (  # noqa: F401
@@ -2207,3 +2213,16 @@ def is_non_tensor(data):
 
 def _is_non_tensor(cls: type):
     return getattr(cls, "_is_non_tensor", False)
+
+
+if not _has_funcdim:
+
+    class _ftdim_mock:
+        class Dim:
+            pass
+
+        class Tensor:
+            pass
+
+        def dims(self, *args, **kwargs):
+            raise ImportError("functorch.dim not found")
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index e0d453350..f906a17eb 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -29,7 +29,15 @@
     prod,
     TestTensorDictsBase,
 )
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 
 from tensordict import LazyStackedTensorDict, make_tensordict, TensorDict
 from tensordict._lazy import _CustomOpTensorDict
@@ -7869,6 +7877,7 @@ def _pool_fixt():
         yield pool
 
 
+@pytest.mark.skipif(not _has_funcdim, reason="functorch.dim could not be found")
 class TestFCD(TestTensorDictsBase):
     """Test stack for first-class dimension."""