diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py
index 53df615aa..7d964f3df 100644
--- a/tensordict/nn/common.py
+++ b/tensordict/nn/common.py
@@ -25,7 +25,11 @@
     repopulate_module,
 )
 
-from tensordict.nn.utils import set_skip_existing
+from tensordict.nn.utils import (
+    _auto_make_functional,
+    _dispatch_td_nn_modules,
+    set_skip_existing,
+)
 from tensordict.utils import implement_for, NestedKey
 from torch import nn, Tensor
 
@@ -238,10 +242,16 @@ def __call__(self, func: Callable) -> Callable:
                     "named 'tensordict'."
                 )
                 break
+        # if dispatch was disabled through the env variable, we can skip the wrapper altogether
+        if not _dispatch_td_nn_modules():
+            return func
 
         @functools.wraps(func)
         def wrapper(_self, *args: Any, **kwargs: Any) -> Any:
+            if not _dispatch_td_nn_modules():
+                return func(_self, *args, **kwargs)
+
             source = self.source
             if isinstance(source, str):
                 source = getattr(_self, source)
@@ -829,29 +839,32 @@ def reset_parameters_recursive(
             lambda x: x.detach().requires_grad_(), inplace=False
         )
 
-        if not is_functional(self):
+        if _auto_make_functional() and not is_functional(self):
             make_functional(self, keep_params=True)
-        is_stateless = self._is_stateless
-        if is_stateless:
-            repopulate_module(self, sanitized_parameters)
-        else:
-            old_params = _swap_state(
-                self,
-                sanitized_parameters,
-                is_stateless=False,
-                return_old_tensordict=True,
-            )
+            is_stateless = self._is_stateless
+            if is_stateless:
+                repopulate_module(self, sanitized_parameters)
+            else:
+                old_params = _swap_state(
+                    self,
+                    sanitized_parameters,
+                    is_stateless=False,
+                    return_old_tensordict=True,
+                )
 
-        self._reset_parameters(self)
+            self._reset_parameters(self)
 
-        if is_stateless:
-            new_parameters = extract_weights_and_buffers(self)
+            if is_stateless:
+                new_parameters = extract_weights_and_buffers(self)
+            else:
+                new_parameters = _swap_state(
+                    self, old_params, is_stateless=False, return_old_tensordict=True
+                )
+            return new_parameters
         else:
-            new_parameters = _swap_state(
-                self, old_params, is_stateless=False, return_old_tensordict=True
-            )
-
-        return new_parameters
+            with sanitized_parameters.to_module(self):
+                self._reset_parameters(self)
+            return sanitized_parameters
 
     def _reset_parameters(self, module: nn.Module) -> None:
         for child in module.children():
@@ -865,10 +878,6 @@ def _reset_parameters(self, module: nn.Module) -> None:
 class TensorDictModule(TensorDictModuleBase):
     """A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict.
 
-    By default, :class:`TensorDictModule` subclasses are always functional,
-    meaning that they support the ``td_module(input, params=params)`` function
-    call signature.
-
     Args:
         module (Callable): a callable, typically a :class:`torch.nn.Module`,
             used to map the input to the output parameter space. Its forward method
@@ -966,14 +975,15 @@ class TensorDictModule(TensorDictModuleBase):
         >>> import torch
         >>> from tensordict import TensorDict
         >>> from tensordict.nn import TensorDictModule
-        >>> from tensordict.nn.functional_modules import make_functional
         >>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
         >>> module = torch.nn.GRUCell(4, 8)
         >>> td_module = TensorDictModule(
         ...     module=module, in_keys=["input", "hidden"], out_keys=["output"]
         ... )
-        >>> params = make_functional(td_module)
-        >>> td_functional = td_module(td.clone(), params=params)
+        >>> params = TensorDict.from_module(td_module)
+        >>> # functional API
+        >>> with params.to_module(td_module):
+        ...     td_functional = td_module(td.clone())
         >>> print(td_functional)
         TensorDict(
             fields={
@@ -1022,7 +1032,10 @@ class TensorDictModule(TensorDictModuleBase):
             batch_size=torch.Size([4]),
             device=None,
             is_shared=False)
-        >>> td_vmap = vmap(td_module, (None, 0))(td.clone(), params_repeat)
+        >>> def func(td, params):
+        ...     with params.to_module(td_module):
+        ...         return td_module(td)
+        >>> td_vmap = vmap(func, (None, 0))(td.clone(), params_repeat)
         >>> print(td_vmap)
         TensorDict(
             fields={
@@ -1089,7 +1102,8 @@ def __init__(
         )
 
         self.module = module
-        make_functional(self, keep_params=True, return_params=False)
+        if _auto_make_functional():
+            make_functional(self, keep_params=True, return_params=False)
 
     @property
     def is_functional(self) -> bool:
@@ -1200,9 +1214,9 @@ def forward(
             module = indent(str(module), 4 * " ")
             in_keys = indent(f"in_keys={self.in_keys}", 4 * " ")
             out_keys = indent(f"out_keys={self.out_keys}", 4 * " ")
-            raise RuntimeError(
+            raise err from RuntimeError(
                 f"TensorDictModule failed with operation\n{module}\n{in_keys}\n{out_keys}."
-            ) from err
+            )
 
     @property
     def device(self) -> torch.device:
diff --git a/tensordict/nn/ensemble.py b/tensordict/nn/ensemble.py
index 2cefe015e..8c28dc3d7 100644
--- a/tensordict/nn/ensemble.py
+++ b/tensordict/nn/ensemble.py
@@ -8,7 +8,6 @@
 import torch
 from tensordict import TensorDict
 from tensordict.nn.common import TensorDictBase, TensorDictModuleBase
-from tensordict.nn.functional_modules import make_functional
 from tensordict.nn.params import TensorDictParams
 
 
@@ -76,17 +75,21 @@ def __init__(
         super().__init__()
         self.in_keys = module.in_keys
         self.out_keys = module.out_keys
-        params_td = make_functional(module).expand(num_copies).to_tensordict()
+        params_td = TensorDict.from_module(module).expand(num_copies).to_tensordict()
         self.module = module
         if expand_input:
-            self.vmapped_forward = torch.vmap(self.module, (None, 0))
+            self.vmapped_forward = torch.vmap(self._func_module_call, (None, 0))
         else:
-            self.vmapped_forward = torch.vmap(self.module, 0)
+            self.vmapped_forward = torch.vmap(self._func_module_call, 0)
 
         self.reset_parameters_recursive(params_td)
         self.params_td = TensorDictParams(params_td)
 
+    def _func_module_call(self, input, params):
+        with params.to_module(self.module):
+            return self.module(input)
+
     def forward(self, tensordict: TensorDict) -> TensorDict:
         return self.vmapped_forward(tensordict, self.params_td)
diff --git a/tensordict/nn/probabilistic.py b/tensordict/nn/probabilistic.py
index ae40e6cb2..04a6aa914 100644
--- a/tensordict/nn/probabilistic.py
+++ b/tensordict/nn/probabilistic.py
@@ -225,8 +225,9 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
         >>> td_module = ProbabilisticTensorDictSequential(
         ...     module, normal_params, prob_module
         ... )
-        >>> params = make_functional(td_module, funs_to_decorate=["forward", "get_dist", "log_prob"])
-        >>> _ = td_module(td, params=params)
+        >>> params = TensorDict.from_module(td_module)
+        >>> with params.to_module(td_module):
+        ...     _ = td_module(td)
         >>> print(td)
         TensorDict(
             fields={
@@ -240,13 +241,17 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
             batch_size=torch.Size([3]),
             device=None,
             is_shared=False)
-        >>> dist = td_module.get_dist(td, params=params)
+        >>> with params.to_module(td_module):
+        ...     dist = td_module.get_dist(td)
         >>> print(dist)
         Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4]))
         >>> # we can also apply the module to the TensorDict with vmap
         >>> from torch import vmap
         >>> params = params.expand(4)
-        >>> td_vmap = vmap(td_module, (None, 0))(td, params)
+        >>> def func(td, params):
+        ...     with params.to_module(td_module):
+        ...         return td_module(td)
+        >>> td_vmap = vmap(func, (None, 0))(td, params)
         >>> print(td_vmap)
         TensorDict(
             fields={
diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py
index a5265d849..eb7e14cef 100644
--- a/tensordict/nn/sequence.py
+++ b/tensordict/nn/sequence.py
@@ -38,10 +38,6 @@
 class TensorDictSequential(TensorDictModule):
     """A sequence of TensorDictModules.
 
-    By default, :class:`TensorDictSequential` subclasses are always functional,
-    meaning that they support the ``td_module(input, params=params)`` function
-    call signature.
-
     Similarly to :obj:`nn.Sequence` which passes a tensor through a chain of mappings that read and write a single tensor
     each, this module will read and write over a tensordict by querying each of the input modules.
     When calling a :obj:`TensorDictSequencial` instance with a functional module, it is expected that the parameter lists (and
@@ -115,8 +111,9 @@ class TensorDictSequential(TensorDictModule):
         ...     module=module2, in_keys=["hidden"], out_keys=["output"]
         ... )
         >>> td_module = TensorDictSequential(td_module1, td_module2)
-        >>> params = make_functional(td_module)
-        >>> _ = td_module(td, params=params)
+        >>> params = TensorDict.from_module(td_module)
+        >>> with params.to_module(td_module):
+        ...     _ = td_module(td)
         >>> print(td)
         TensorDict(
             fields={
@@ -134,7 +131,10 @@ class TensorDictSequential(TensorDictModule):
     In the vmap case:
         >>> from torch import vmap
         >>> params = params.expand(4)
-        >>> td_vmap = vmap(td_module, (None, 0))(td, params)
+        >>> def func(td, params):
+        ...     with params.to_module(td_module):
+        ...         return td_module(td)
+        >>> td_vmap = vmap(func, (None, 0))(td, params)
         >>> print(td_vmap)
         TensorDict(
             fields={
diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py
index a15a71b28..3e1d698bd 100644
--- a/tensordict/nn/utils.py
+++ b/tensordict/nn/utils.py
@@ -7,12 +7,20 @@
 
 import functools
 import inspect
+import os
+from distutils.util import strtobool
 from typing import Any, Callable
 
 import torch
 from torch import nn
 
+AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "True"))
+
+
+DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True"))
+
 __all__ = ["mappings", "inv_softplus", "biased_softplus"]
+
 _SKIP_EXISTING = False
 
 from tensordict._contextlib import _DecoratorContextManager
@@ -287,3 +295,53 @@ def _rebuild_buffer(data, requires_grad, backward_hooks):
 
 # For backward compatibility in imports
 from tensordict.utils import Buffer  # noqa
+
+
+def _auto_make_functional():
+    """Returns ``True`` if TensorDictModuleBase subclasses are automatically made functional with the old API."""
+    global AUTO_MAKE_FUNCTIONAL
+    return AUTO_MAKE_FUNCTIONAL
+
+
+class _set_auto_make_functional(_DecoratorContextManager):
+    """Controls whether TensorDictModule subclasses should be made functional automatically with the old API."""
+
+    def __init__(self, mode):
+        self.mode = mode
+
+    def clone(self):
+        return self.__class__(self.mode)
+
+    def __enter__(self):
+        global AUTO_MAKE_FUNCTIONAL
+        self._saved_mode = AUTO_MAKE_FUNCTIONAL
+        AUTO_MAKE_FUNCTIONAL = self.mode
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global AUTO_MAKE_FUNCTIONAL
+        AUTO_MAKE_FUNCTIONAL = self._saved_mode
+
+
+def _dispatch_td_nn_modules():
+    """Returns ``True`` if @dispatch should be used. Not using dispatch is faster and more compatible with torch.compile."""
+    global DISPATCH_TDNN_MODULES
+    return DISPATCH_TDNN_MODULES
+
+
+class _set_dispatch_td_nn_modules(_DecoratorContextManager):
+    """Controls whether @dispatch should be used. Not using dispatch is faster and more compatible with torch.compile."""
+
+    def __init__(self, mode):
+        self.mode = mode
+
+    def clone(self):
+        return self.__class__(self.mode)
+
+    def __enter__(self):
+        global DISPATCH_TDNN_MODULES
+        self._saved_mode = DISPATCH_TDNN_MODULES
+        DISPATCH_TDNN_MODULES = self.mode
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global DISPATCH_TDNN_MODULES
+        DISPATCH_TDNN_MODULES = self._saved_mode
diff --git a/test/test_nn.py b/test/test_nn.py
index b894b3756..07e80c16d 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -34,7 +34,13 @@
 from tensordict.nn.ensemble import EnsembleModule
 from tensordict.nn.functional_modules import is_functional, make_functional
 from tensordict.nn.probabilistic import InteractionType, set_interaction_type
-from tensordict.nn.utils import Buffer, set_skip_existing, skip_existing
+from tensordict.nn.utils import (
+    _set_auto_make_functional,
+    _set_dispatch_td_nn_modules,
+    Buffer,
+    set_skip_existing,
+    skip_existing,
+)
 from torch import distributions as d, nn
 from torch.distributions import Normal
 from torch.utils._pytree import tree_map
@@ -414,6 +420,28 @@ def test_functional_before(self):
         assert td.shape == torch.Size([3])
         assert td.get("out").shape == torch.Size([3, 4])
 
+    @pytest.mark.skipif(
+        not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
+    )
+    def test_functional_deactivate(self):
+        torch.manual_seed(0)
+        param_multiplier = 1
+
+        net = nn.Linear(3, 4 * param_multiplier)
+
+        td = TensorDict({"in": torch.randn(3, 3)}, [3])
+
+        with _set_auto_make_functional(False):
+            tensordict_module = TensorDictModule(
+                module=net, in_keys=["in"], out_keys=["out"]
+            )
+        assert not is_functional(tensordict_module)
+        params = TensorDict.from_module(tensordict_module)
+        with pytest.raises(TypeError):
+            tensordict_module(td, params=params)
+        make_functional(tensordict_module)
+        tensordict_module(td, params=params)
+
     @pytest.mark.skipif(
         not _has_functorch, reason=f"functorch not found: err={FUNCTORCH_ERR}"
     )
@@ -735,6 +763,22 @@ def __deepcopy__(self, memodict=None):
         assert tdmodule.some_attribute == "a"
         assert isinstance(copy.deepcopy(tdmodule), TensorDictModule)
 
+    def test_dispatch_deactivate(self):
+        tdm = TensorDictModule(nn.Linear(1, 1), ["a"], ["b"])
+        td = TensorDict({"a": torch.zeros(1, 1)}, 1)
+        tdm(td)
+        with _set_dispatch_td_nn_modules(True):
+            out = tdm(a=torch.zeros(1, 1))
+        assert (out == td["b"]).all()
+        with _set_dispatch_td_nn_modules(False), pytest.raises(
+            TypeError, match="missing 1 required positional argument"
+        ):
+            tdm(a=torch.zeros(1, 1))
+
+        # checks that things are back in place
+        tdm = TensorDictModule(nn.Linear(1, 1), ["a"], ["b"])
+        tdm(a=torch.zeros(1, 1))
+
     def test_dispatch(self):
         tdm = TensorDictModule(nn.Linear(1, 1), ["a"], ["b"])
         td = TensorDict({"a": torch.zeros(1, 1)}, 1)
@@ -1881,12 +1925,10 @@ def test_probabilistic_sequential_type_checks():
 def test_keyerr_msg():
     module = TensorDictModule(nn.Linear(2, 3), in_keys=["a"], out_keys=["b"])
     with pytest.raises(
-        RuntimeError, match="TensorDictModule failed with operation"
-    ) as err:
+        KeyError,
+        match="Some tensors that are necessary for the module call may not have not been found in the input tensordict",
+    ):
         module(TensorDict({"c": torch.randn(())}, []))
-    assert "Some tensors that are necessary for the module call" in str(
-        err.value.__cause__
-    )
 
 
 def test_input():
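
Note (not part of the patch): the changes above consistently replace the old `make_functional(module)` / `td_module(td, params=params)` idiom with `TensorDict.from_module` plus the `params.to_module(module)` context manager. A minimal, self-contained sketch of that pattern, mirroring the updated docstrings; the `Linear(3, 4)` module, the key names and the two-copy parameter ensemble are illustrative only:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule

    module = TensorDictModule(torch.nn.Linear(3, 4), in_keys=["in"], out_keys=["out"])
    td = TensorDict({"in": torch.randn(5, 3)}, [5])

    # extract the parameters as a (nested) TensorDict ...
    params = TensorDict.from_module(module)
    # ... and swap them into the module only for the duration of the call
    with params.to_module(module):
        out = module(td.clone())

    # vmap over stacked parameters: expand the params along a new leading
    # dim and vmap a closure that swaps each slice in per call, as the
    # updated docstrings and EnsembleModule._func_module_call do
    params_stack = params.expand(2)

    def func(td, params):
        with params.to_module(module):
            return module(td)

    td_vmap = torch.vmap(func, (None, 0))(td.clone(), params_stack)
    assert td_vmap["out"].shape == torch.Size([2, 5, 4])

The closure is needed because, with dispatch and the automatic functionalization removed, the module itself no longer accepts a `params` positional argument; parameter injection is now explicit and scoped to the `with` block.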