[Feature] _auto_make_functional and _dispatch_td_nn_modules (#633)
vmoens authored Jan 23, 2024
1 parent 5c6f298 commit a66e05c
Showing 6 changed files with 174 additions and 52 deletions.
76 changes: 45 additions & 31 deletions tensordict/nn/common.py
@@ -25,7 +25,11 @@
repopulate_module,
)

from tensordict.nn.utils import set_skip_existing
from tensordict.nn.utils import (
_auto_make_functional,
_dispatch_td_nn_modules,
set_skip_existing,
)
from tensordict.utils import implement_for, NestedKey
from torch import nn, Tensor

@@ -238,10 +242,16 @@ def __call__(self, func: Callable) -> Callable:
"named 'tensordict'."
)
break
# if the env variable was used, we can skip the wrapper altogether
if not _dispatch_td_nn_modules():
return func

@functools.wraps(func)
def wrapper(_self, *args: Any, **kwargs: Any) -> Any:

if not _dispatch_td_nn_modules():
return func(_self, *args, **kwargs)

source = self.source
if isinstance(source, str):
source = getattr(_self, source)
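
For context, a hedged sketch of the call convention gated by ``_dispatch_td_nn_modules()`` (the Linear module below is illustrative, not part of this commit). When dispatch is enabled, the in-keys can be passed as plain tensors or keyword arguments and the out-key values come back as tensors; when it is disabled, only the tensordict signature remains:

>>> import torch
>>> from tensordict.nn import TensorDictModule
>>> net = TensorDictModule(torch.nn.Linear(4, 8), in_keys=["input"], out_keys=["output"])
>>> out = net(input=torch.randn(3, 4))  # @dispatch maps the kwarg to in_keys
>>> out.shape
torch.Size([3, 8])
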
@@ -829,29 +839,32 @@ def reset_parameters_recursive(
lambda x: x.detach().requires_grad_(), inplace=False
)

if not is_functional(self):
if _auto_make_functional() and not is_functional(self):
make_functional(self, keep_params=True)
is_stateless = self._is_stateless
if is_stateless:
repopulate_module(self, sanitized_parameters)
else:
old_params = _swap_state(
self,
sanitized_parameters,
is_stateless=False,
return_old_tensordict=True,
)
is_stateless = self._is_stateless
if is_stateless:
repopulate_module(self, sanitized_parameters)
else:
old_params = _swap_state(
self,
sanitized_parameters,
is_stateless=False,
return_old_tensordict=True,
)

self._reset_parameters(self)
self._reset_parameters(self)

if is_stateless:
new_parameters = extract_weights_and_buffers(self)
if is_stateless:
new_parameters = extract_weights_and_buffers(self)
else:
new_parameters = _swap_state(
self, old_params, is_stateless=False, return_old_tensordict=True
)
return new_parameters
else:
new_parameters = _swap_state(
self, old_params, is_stateless=False, return_old_tensordict=True
)

return new_parameters
with sanitized_parameters.to_module(self):
self._reset_parameters(self)
return sanitized_parameters

def _reset_parameters(self, module: nn.Module) -> None:
for child in module.children():
@@ -865,10 +878,6 @@ def _reset_parameters(self, module: nn.Module) -> None:
class TensorDictModule(TensorDictModuleBase):
"""A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict.
By default, :class:`TensorDictModule` subclasses are always functional,
meaning that they support the ``td_module(input, params=params)`` function
call signature.
Args:
module (Callable): a callable, typically a :class:`torch.nn.Module`,
used to map the input to the output parameter space. Its forward method
@@ -966,14 +975,15 @@ class TensorDictModule(TensorDictModuleBase):
>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> from tensordict.nn.functional_modules import make_functional
>>> td = TensorDict({"input": torch.randn(3, 4), "hidden": torch.randn(3, 8)}, [3,])
>>> module = torch.nn.GRUCell(4, 8)
>>> td_module = TensorDictModule(
... module=module, in_keys=["input", "hidden"], out_keys=["output"]
... )
>>> params = make_functional(td_module)
>>> td_functional = td_module(td.clone(), params=params)
>>> params = TensorDict.from_module(td_module)
>>> # functional API
>>> with params.to_module(td_module):
... td_functional = td_module(td.clone())
>>> print(td_functional)
TensorDict(
fields={
@@ -1022,7 +1032,10 @@ class TensorDictModule(TensorDictModuleBase):
batch_size=torch.Size([4]),
device=None,
is_shared=False)
>>> td_vmap = vmap(td_module, (None, 0))(td.clone(), params_repeat)
>>> def func(td, params):
... with params.to_module(td_module):
... return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td.clone(), params_repeat)
>>> print(td_vmap)
TensorDict(
fields={
@@ -1089,7 +1102,8 @@ def __init__(
)

self.module = module
make_functional(self, keep_params=True, return_params=False)
if _auto_make_functional():
make_functional(self, keep_params=True, return_params=False)

@property
def is_functional(self) -> bool:
@@ -1200,9 +1214,9 @@ def forward(
module = indent(str(module), 4 * " ")
in_keys = indent(f"in_keys={self.in_keys}", 4 * " ")
out_keys = indent(f"out_keys={self.out_keys}", 4 * " ")
raise RuntimeError(
raise err from RuntimeError(
f"TensorDictModule failed with operation\n{module}\n{in_keys}\n{out_keys}."
) from err
)

@property
def device(self) -> torch.device:
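
For orientation before the remaining files, a minimal end-to-end sketch of the functional pattern this commit migrates the docstrings to, using only APIs that appear in this diff (``TensorDict.from_module`` and ``params.to_module``):

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> td_module = TensorDictModule(torch.nn.Linear(4, 8), in_keys=["input"], out_keys=["output"])
>>> params = TensorDict.from_module(td_module)  # extract parameters as a TensorDict
>>> with params.to_module(td_module):           # temporarily bind them to the module
...     td = td_module(TensorDict({"input": torch.randn(3, 4)}, [3]))

On exit from the context manager the module's original parameters are restored, which is what makes the vmap examples in the docstrings safe to run with batched parameter stacks.
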
11 changes: 7 additions & 4 deletions tensordict/nn/ensemble.py
@@ -8,7 +8,6 @@
import torch
from tensordict import TensorDict
from tensordict.nn.common import TensorDictBase, TensorDictModuleBase
from tensordict.nn.functional_modules import make_functional

from tensordict.nn.params import TensorDictParams

@@ -76,17 +75,21 @@ def __init__(
super().__init__()
self.in_keys = module.in_keys
self.out_keys = module.out_keys
params_td = make_functional(module).expand(num_copies).to_tensordict()
params_td = TensorDict.from_module(module).expand(num_copies).to_tensordict()

self.module = module
if expand_input:
self.vmapped_forward = torch.vmap(self.module, (None, 0))
self.vmapped_forward = torch.vmap(self._func_module_call, (None, 0))
else:
self.vmapped_forward = torch.vmap(self.module, 0)
self.vmapped_forward = torch.vmap(self._func_module_call, 0)

self.reset_parameters_recursive(params_td)
self.params_td = TensorDictParams(params_td)

def _func_module_call(self, input, params):
with params.to_module(self.module):
return self.module(input)

def forward(self, tensordict: TensorDict) -> TensorDict:
return self.vmapped_forward(tensordict, self.params_td)

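
A hedged paraphrase of the ensemble mechanics introduced above, inlined so it can run outside the class (``num_copies`` and the ``(None, 0)`` in-dims follow the diff; the Linear layer is illustrative):

>>> import torch
>>> from tensordict import TensorDict
>>> from tensordict.nn import TensorDictModule
>>> module = TensorDictModule(torch.nn.Linear(4, 4), in_keys=["x"], out_keys=["y"])
>>> num_copies = 3
>>> params_td = TensorDict.from_module(module).expand(num_copies).to_tensordict()
>>> def _func_module_call(input, params):  # mirrors the helper added above
...     with params.to_module(module):
...         return module(input)
>>> td = TensorDict({"x": torch.randn(4)}, [])
>>> out = torch.vmap(_func_module_call, (None, 0))(td, params_td)  # shared input, per-member params

Note that ``expand(num_copies)`` alone yields identical members; the module therefore calls ``reset_parameters_recursive(params_td)`` afterwards so that each copy is re-initialized independently.
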
13 changes: 9 additions & 4 deletions tensordict/nn/probabilistic.py
@@ -225,8 +225,9 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
>>> td_module = ProbabilisticTensorDictSequential(
... module, normal_params, prob_module
... )
>>> params = make_functional(td_module, funs_to_decorate=["forward", "get_dist", "log_prob"])
>>> _ = td_module(td, params=params)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... _ = td_module(td)
>>> print(td)
TensorDict(
fields={
@@ -240,13 +241,17 @@
batch_size=torch.Size([3]),
device=None,
is_shared=False)
>>> dist = td_module.get_dist(td, params=params)
>>> with params.to_module(td_module):
... dist = td_module.get_dist(td)
>>> print(dist)
Normal(loc: torch.Size([3, 4]), scale: torch.Size([3, 4]))
>>> # we can also apply the module to the TensorDict with vmap
>>> from torch import vmap
>>> params = params.expand(4)
>>> td_vmap = vmap(td_module, (None, 0))(td, params)
>>> def func(td, params):
... with params.to_module(td_module):
... return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
fields={
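
Continuing the docstring example above, a small hedged sketch of sampling and scoring under functional parameters, using the un-expanded ``params`` from before the vmap step (``get_dist`` appears in this diff; ``sample`` and ``log_prob`` are standard ``torch.distributions.Normal`` methods):

>>> with params.to_module(td_module):
...     dist = td_module.get_dist(td)
...     sample = dist.sample()
...     log_prob = dist.log_prob(sample)  # elementwise for the Normal above
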
14 changes: 7 additions & 7 deletions tensordict/nn/sequence.py
@@ -38,10 +38,6 @@
class TensorDictSequential(TensorDictModule):
"""A sequence of TensorDictModules.
By default, :class:`TensorDictSequential` subclasses are always functional,
meaning that they support the ``td_module(input, params=params)`` function
call signature.
Similarly to :obj:`nn.Sequential`, which passes a tensor through a chain of mappings that read and write a single tensor
each, this module will read and write over a tensordict by querying each of the input modules.
When calling a :obj:`TensorDictSequential` instance with a functional module, it is expected that the parameter lists (and
@@ -115,8 +111,9 @@ class TensorDictSequential(TensorDictModule):
... module=module2, in_keys=["hidden"], out_keys=["output"]
... )
>>> td_module = TensorDictSequential(td_module1, td_module2)
>>> params = make_functional(td_module)
>>> _ = td_module(td, params=params)
>>> params = TensorDict.from_module(td_module)
>>> with params.to_module(td_module):
... _ = td_module(td)
>>> print(td)
TensorDict(
fields={
Expand All @@ -134,7 +131,10 @@ class TensorDictSequential(TensorDictModule):
In the vmap case:
>>> from torch import vmap
>>> params = params.expand(4)
>>> td_vmap = vmap(td_module, (None, 0))(td, params)
>>> def func(td, params):
... with params.to_module(td_module):
... return td_module(td)
>>> td_vmap = vmap(func, (None, 0))(td, params)
>>> print(td_vmap)
TensorDict(
fields={
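
A hedged addendum to the sequential example: because the extracted parameters are a plain ``TensorDict``, alternative parameter sets can be swapped in functionally. ``TensorDict.apply`` is standard tensordict API; the zeroed weights are purely illustrative:

>>> params = TensorDict.from_module(td_module)
>>> params_zero = params.apply(lambda p: torch.zeros_like(p))  # illustrative alternative weights
>>> with params_zero.to_module(td_module):
...     td_zero = td_module(td.clone())  # runs the whole chain with zeroed weights
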
58 changes: 58 additions & 0 deletions tensordict/nn/utils.py
@@ -7,12 +7,20 @@

import functools
import inspect
import os
from distutils.util import strtobool
from typing import Any, Callable

import torch
from torch import nn

AUTO_MAKE_FUNCTIONAL = strtobool(os.environ.get("AUTO_MAKE_FUNCTIONAL", "True"))


DISPATCH_TDNN_MODULES = strtobool(os.environ.get("DISPATCH_TDNN_MODULES", "True"))

__all__ = ["mappings", "inv_softplus", "biased_softplus"]

_SKIP_EXISTING = False

from tensordict._contextlib import _DecoratorContextManager
@@ -287,3 +295,53 @@ def _rebuild_buffer(data, requires_grad, backward_hooks):

# For backward compatibility in imports
from tensordict.utils import Buffer # noqa


def _auto_make_functional():
"""Returns ``True`` if TensorDictModuleBase subclasses are automatically made functional with the old API."""
global AUTO_MAKE_FUNCTIONAL
return AUTO_MAKE_FUNCTIONAL


class _set_auto_make_functional(_DecoratorContextManager):
"""Controls if TensorDictModule subclasses should be made functional automatically with the old API."""

def __init__(self, mode):
self.mode = mode

def clone(self):
return self.__class__(self.mode)

def __enter__(self):
global AUTO_MAKE_FUNCTIONAL
self._saved_mode = AUTO_MAKE_FUNCTIONAL
AUTO_MAKE_FUNCTIONAL = self.mode

def __exit__(self, exc_type, exc_val, exc_tb):
global AUTO_MAKE_FUNCTIONAL
AUTO_MAKE_FUNCTIONAL = self._saved_mode


def _dispatch_td_nn_modules():
"""Returns ``True`` if @dispatch should be used. Not using dispatch is faster and also better compatible with torch.compile."""
global DISPATCH_TDNN_MODULES
return DISPATCH_TDNN_MODULES


class _set_dispatch_td_nn_modules(_DecoratorContextManager):
"""Controls whether @dispatch should be used. Not using dispatch is faster and also better compatible with torch.compile."""

def __init__(self, mode):
self.mode = mode

def clone(self):
return self.__class__(self.mode)

def __enter__(self):
global DISPATCH_TDNN_MODULES
self._saved_mode = DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self.mode

def __exit__(self, exc_type, exc_val, exc_tb):
global DISPATCH_TDNN_MODULES
DISPATCH_TDNN_MODULES = self._saved_mode
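
Finally, a usage sketch for the two toggles defined above (hedged: both are private helpers, the module-level flags are read from the environment once at import time, so the environment-variable route must be set before ``tensordict`` is first imported, and ``net`` and ``td`` are placeholders):

>>> # shell: AUTO_MAKE_FUNCTIONAL=False DISPATCH_TDNN_MODULES=False python script.py
>>> from tensordict.nn.utils import (
...     _set_auto_make_functional,
...     _set_dispatch_td_nn_modules,
... )
>>> with _set_auto_make_functional(False):
...     module = TensorDictModule(net, in_keys=["x"], out_keys=["y"])  # skips make_functional at init
>>> with _set_dispatch_td_nn_modules(False):
...     out = module(td)  # tensordict signature only; kwarg dispatch is bypassed

Since both classes derive from ``_DecoratorContextManager``, they should also be usable as function decorators.
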
