[BugFix] Make functorch.dim optional (#737)
vmoens authored Apr 20, 2024
1 parent 3a44928 commit d09626f
Showing 5 changed files with 62 additions and 6 deletions.
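
The change applies the same optional-dependency pattern in every touched module: try the real functorch.dim import, fall back to a lightweight stub, and record availability in a module-level flag. A minimal standalone sketch of that pattern, not verbatim repository code (the uses_first_class_dims helper is illustrative only):

try:
    from functorch import dim as ftdim

    _has_funcdim = True
except ImportError:
    # Fallback namespace added by this commit; it only mirrors the names the
    # codebase actually touches (Dim, Tensor, dims).
    from tensordict.utils import _ftdim_mock as ftdim

    _has_funcdim = False


def uses_first_class_dims(idx) -> bool:
    # isinstance checks can stay unconditional: the mock's Dim never matches a
    # plain tensor, so this simply returns False when functorch is absent.
    return isinstance(idx, ftdim.Dim)
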
10 changes: 9 additions & 1 deletion tensordict/_lazy.py
@@ -21,7 +21,15 @@
 import numpy as np
 import torch
 import torch.distributed as dist
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 from tensordict._td import _SubTensorDict, _TensorDictKeysView, TensorDict
 from tensordict._tensordict import _unravel_key_to_tuple, unravel_key_list
 from tensordict.base import (
16 changes: 14 additions & 2 deletions tensordict/_td.py
@@ -20,7 +20,15 @@
 
 import numpy as np
 import torch
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 
 from tensordict.base import (
     _ACCEPTED_CLASSES,
@@ -1561,7 +1569,11 @@ def _get_names_idx(self, idx):
         else:
 
             def is_boolean(idx):
-                from functorch import dim as ftdim
+                try:
+                    from functorch import dim as ftdim
+
+                except ImportError:
+                    from tensordict.utils import _ftdim_mock as ftdim
 
                 if isinstance(idx, ftdim.Dim):
                     return None
10 changes: 9 additions & 1 deletion tensordict/nn/params.py
@@ -14,7 +14,15 @@
 from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type
 
 import torch
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 
 from tensordict._lazy import _CustomOpTensorDict, LazyStackedTensorDict
 from tensordict._td import _SubTensorDict, TensorDict
21 changes: 20 additions & 1 deletion tensordict/utils.py
@@ -39,7 +39,13 @@
 
 import numpy as np
 import torch
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    _has_funcdim = False
 from packaging.version import parse
 from tensordict._contextlib import _DecoratorContextManager
 from tensordict._tensordict import ( # noqa: F401
@@ -2207,3 +2213,16 @@ def is_non_tensor(data):
 
 def _is_non_tensor(cls: type):
     return getattr(cls, "_is_non_tensor", False)
+
+
+if not _has_funcdim:
+
+    class _ftdim_mock:
+        class Dim:
+            pass
+
+        class Tensor:
+            pass
+
+        def dims(self, *args, **kwargs):
+            raise ImportError("functorch.dim not found")
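
The _ftdim_mock namespace above only has to satisfy the isinstance checks against ftdim.Dim and ftdim.Tensor scattered through the library, while anything that actually requests first-class dimensions fails with an explicit error. A hedged usage sketch, assuming tensordict at this commit is installed and functorch.dim is absent (the mock is only defined in that case):

import torch
from tensordict.utils import _ftdim_mock as ftdim

t = torch.randn(3, 4)
# The placeholder types never match real tensors, so feature probes report
# "not a first-class dim" instead of crashing at import time.
assert not isinstance(t, ftdim.Dim)
assert not isinstance(t, ftdim.Tensor)

# Asking for actual first-class dimensions raises the clear ImportError above.
try:
    ftdim.dims(2)
except ImportError as err:
    print(err)  # functorch.dim not found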
11 changes: 10 additions & 1 deletion test/test_tensordict.py
@@ -29,7 +29,15 @@
     prod,
     TestTensorDictsBase,
 )
-from functorch import dim as ftdim
+
+try:
+    from functorch import dim as ftdim
+
+    _has_funcdim = True
+except ImportError:
+    from tensordict.utils import _ftdim_mock as ftdim
+
+    _has_funcdim = False
 
 from tensordict import LazyStackedTensorDict, make_tensordict, TensorDict
 from tensordict._lazy import _CustomOpTensorDict
@@ -7869,6 +7877,7 @@ def _pool_fixt():
     yield pool
 
 
+@pytest.mark.skipif(not _has_funcdim, reason="functorch.dim could not be found")
 class TestFCD(TestTensorDictsBase):
     """Test stack for first-class dimension."""
 
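
The same flag gates the first-class-dimension tests, so a build without functorch.dim skips TestFCD instead of erroring at collection time. A minimal self-contained sketch of that guard (test name and assertions are illustrative, not taken from the suite):

import pytest

try:
    from functorch import dim as ftdim

    _has_funcdim = True
except ImportError:
    _has_funcdim = False


@pytest.mark.skipif(not _has_funcdim, reason="functorch.dim could not be found")
def test_first_class_dims_available():
    # Runs only when the real functorch.dim imports; otherwise pytest records
    # a skip rather than an ImportError.
    d0, d1 = ftdim.dims(2)
    assert isinstance(d0, ftdim.Dim)
    assert isinstance(d1, ftdim.Dim)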
