From 74cae09876d2962417153eaeb81fc77cf8fef8e5 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 18:43:37 +0000 Subject: [PATCH] [Feature] UnbatchedTensor ghstack-source-id: fa25726d61e913a725a71f1579eb06b09455e7c8 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1170 --- docs/source/overview.rst | 56 +++++++--- tensordict/__init__.py | 1 + tensordict/_torch_func.py | 10 +- tensordict/_unbatched.py | 214 ++++++++++++++++++++++++++++++++++++++ tensordict/base.py | 16 ++- tensordict/tensorclass.py | 2 + tensordict/utils.py | 34 ++++-- test/test_tensordict.py | 104 ++++++++++++++++++ 8 files changed, 405 insertions(+), 32 deletions(-) create mode 100644 tensordict/_unbatched.py diff --git a/docs/source/overview.rst b/docs/source/overview.rst index 5b6d100a2..72e8942e0 100644 --- a/docs/source/overview.rst +++ b/docs/source/overview.rst @@ -71,25 +71,26 @@ Features -------- A :class:`~tensordict.TensorDict` is a dict-like container for tensors. To instantiate a :class:`~tensordict.TensorDict`, -you must specify key-value pairs as well as the batch size. The leading dimensions of any values in the :class:`~tensordict.TensorDict` must be compatible with the batch size. +you can specify key-value pairs +as well as the batch size (an empty tensordict can be created via `TensorDict()`). +The leading dimensions of any values in the :class:`~tensordict.TensorDict` must be compatible with the batch size. ->>> import torch ->>> from tensordict import TensorDict - ->>> tensordict = TensorDict( -... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)}, -... batch_size=[2, 3], -... ) + >>> import torch + >>> from tensordict import TensorDict + >>> tensordict = TensorDict( + ... {"zeros": torch.zeros(2, 3, 4), "ones": torch.ones(2, 3, 4, 5)}, + ... batch_size=[2, 3], + ... ) The syntax for setting or retrieving values is much like that for a regular dictionary. ->>> zeros = tensordict["zeros"] ->>> tensordict["twos"] = 2 * torch.ones(2, 3) + >>> zeros = tensordict["zeros"] + >>> tensordict["twos"] = 2 * torch.ones(2, 3) One can also index a tensordict along its batch_size which makes it possible to obtain congruent slices of data in just a few characters (notice that indexing the nth leading dimensions with tree_map using an ellipsis would require a bit more coding): ->>> sub_tensordict = tensordict[..., :2] + >>> sub_tensordict = tensordict[..., :2] One can also use the set method with ``inplace=True`` or the :meth:`~tensordict.TensorDict.set_` method to do inplace updates of the contents. The former is a fault-tolerant version of the latter: if no matching key is found, it will write a new one. @@ -97,14 +98,39 @@ The former is a fault-tolerant version of the latter: if no matching key is foun The contents of the TensorDict can now be manipulated collectively. For example, to place all of the contents onto a particular device one can simply do ->>> tensordict = tensordict.to("cuda:0") + >>> tensordict = tensordict.to("cuda:0") + +You can then assert that the device of the tensordict is `"cuda:0"`: + + >>> assert tensordict.device == torch.device("cuda:0") To reshape the batch dimensions one can do ->>> tensordict = tensordict.reshape(6) + >>> tensordict = tensordict.reshape(6) + +The class supports many other operations, including :func:`~torch.squeeze`, :func:`~torch.unsqueeze`, +:meth:`~tensordict.TensorDict.view`, :func:`~torch.permute`, :meth:`~tensordict.TensorDict.unbind`, +:func:`~torch.stack`, :func:`~torch.cat` and many more. 
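+
+For instance (a brief illustration; the output shapes follow from the ``reshape(6)`` call above):
+
+    >>> stacked = torch.stack([tensordict, tensordict], dim=0)
+    >>> stacked.batch_size
+    torch.Size([2, 6])
+    >>> tensordict.unsqueeze(0).batch_size
+    torch.Size([1, 6])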
+ +If an operation is not present, the :meth:`~tensordict.TensorDict.apply` method will usually provide the solution +that was needed. + +Escaping shape operations +~~~~~~~~~~~~~~~~~~~~~~~~~ + +In some cases, it may be desirable to store tensors in a TensorDict without enforcing batch size consistency during +shape operations. + +This can be achieved by wrapping the tensor in an :class:`~tensordict.UnbatchedTensor` instance. + +An :class:`~tensordict.UnbatchedTensor` ignores its shape during shape operations on the TensorDict, allowing for +flexible storage and manipulation of tensors with arbitrary shapes. -The class supports many other operations, including squeeze, unsqueeze, view, permute, unbind, stack, cat and many more. -If an operation is not present, the TensorDict.apply method will usually provide the solution that was needed. + >>> from tensordict import UnbatchedTensor + >>> tensordict = TensorDict({"zeros": UnbatchedTensor(torch.zeros(10))}, batch_size=[2, 3]) + >>> reshaped_td = tensordict.reshape(6) + >>> reshaped_td["zeros"] is tensordict["zeros"] + True Non-tensor data --------------- diff --git a/tensordict/__init__.py b/tensordict/__init__.py index c339365be..ae080a5f0 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -23,6 +23,7 @@ stack, TensorDict, ) +from tensordict._unbatched import UnbatchedTensor from tensordict.base import ( from_any, diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 2771d4985..be504b2b5 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -23,10 +23,10 @@ from tensordict.utils import ( _check_keys, _ErrorInteceptor, + _pass_through, _shape, _zip_strict, DeviceType, - is_non_tensor, is_tensorclass, lazy_legacy, set_lazy_legacy, @@ -454,10 +454,10 @@ def _stack( if maybe_dense_stack is None: maybe_dense_stack = lazy_legacy() is_tc = any(is_tensorclass(td) for td in list_of_tensordicts) - if all(is_non_tensor(td) for td in list_of_tensordicts): - from tensordict.tensorclass import NonTensorData - - return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + if all(_pass_through(td) for td in list_of_tensordicts): + return type(list_of_tensordicts[0])._stack_non_tensor( + list_of_tensordicts, dim=dim + ) if is_tc: tc_type = type(list_of_tensordicts[0]) list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts] diff --git a/tensordict/_unbatched.py b/tensordict/_unbatched.py new file mode 100644 index 000000000..70ff94e84 --- /dev/null +++ b/tensordict/_unbatched.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
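+#
+# This module implements ``UnbatchedTensor``, a pass-through tensorclass whose shape is
+# ignored by shape operations on the parent TensorDict: those operations leave the stored
+# tensor untouched and only update its advertised ``batch_size`` (see the "Escaping shape
+# operations" section added to the overview docs above).
+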
+from __future__ import annotations + +from functools import wraps +from typing import Any, Callable + +import torch +from tensordict.base import TensorDictBase + +from tensordict.tensorclass import ( + _arg_to_tensordict, + _from_tensordict_with_copy, + _TD_PASS_THROUGH, + TD_HANDLED_FUNCTIONS, + TensorClass, +) +from tensordict.utils import _getitem_batch_size, _is_tensorclass, unravel_key +from torch import Tensor + + +def _arg_to_tensordict_unbatched(arg, batch_size): + if _is_tensorclass(type(arg)): + arg = arg._tensordict.empty() + arg.batch_size = batch_size + return arg + elif isinstance(arg, (tuple, list)) and all( + _is_tensorclass(type(item)) for item in arg + ): + arg_list = [] + for item in arg: + item = item._tensordict.empty() + item.batch_size = batch_size + arg_list.append(item) + + return type(arg)(arg_list) + return arg + + +def _bypass(func): + @wraps(func) + def bypassed_func(self, *args, **kwargs): + meta_tensor = torch.zeros( + self.batch_size, dtype=self.dtype, device=torch.device("meta") + ) + name = func.__name__ + r = getattr(meta_tensor, name)(*args, **kwargs) + self_copy = self.copy() + self_copy.batch_size = r.shape + return self_copy + + return bypassed_func + + +_TORCH_SHAPE_OPS = ( + torch.gather, + torch.unbind, + torch.cat, + torch.stack, + torch.unflatten, + torch.flatten, + torch.split, + torch.squeeze, + torch.unsqueeze, +) + + +class UnbatchedTensor(TensorClass): + """A TensorClass that represents a tensor whose shape is ignored during shape operations. + + This class allows tensors to be stored in a TensorDict without enforcing batch size consistency. + Shape operations (e.g., reshape, unsqueeze, squeeze) on the TensorDict will return the same UnbatchedTensor instance, + while other operations (e.g., apply, key manipulation, pointwise arithmetic) may modify the underlying tensor content. + + Example: + >>> td = TensorDict(a=UnbatchedTensor(torch.randn(3, 4)), b=torch.randn(2, 3), batch_size=(2,)) + >>> td_reshaped = td.reshape((1, 2)) + >>> td_reshaped["a"] is td["a"] + True + + Note that accessing an UnbatchedTensor using `get()` and `__getitem__()` will return different results. + `get()` returns the UnbatchedTensor instance, while `__getitem__()` returns the underlying tensor content. + + Example: + >>> td.get("a") + + >>> td["a"] + tensor([[...]]) + + """ + + data: torch.Tensor | TensorDictBase + _pass_through = True + + @classmethod + def __torch_function__( + cls, + func: Callable, + types: tuple[type, ...], + args: tuple[Any, ...] 
= (), + kwargs: dict[str, Any] | None = None, + ) -> Callable: + if func not in _TD_PASS_THROUGH or not all( + issubclass(t, (Tensor, cls, TensorDictBase)) for t in types + ): + return NotImplemented + + if kwargs is None: + kwargs = {} + + # get the output type from the arguments / keyword arguments + if len(args) > 0: + tensorclass_instance = args[0] + else: + tensorclass_instance = kwargs.get("input", kwargs["tensors"]) + if isinstance(tensorclass_instance, (tuple, list)): + tensorclass_instance = tensorclass_instance[0] + + if func not in _TORCH_SHAPE_OPS: + args = tuple(_arg_to_tensordict(arg) for arg in args) + kwargs = {key: _arg_to_tensordict(value) for key, value in kwargs.items()} + result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + else: + # Get a brute force batch size + args = tuple( + _arg_to_tensordict_unbatched(arg, tensorclass_instance.batch_size) + for arg in args + ) + kwargs = { + key: _arg_to_tensordict_unbatched( + value, tensorclass_instance.batch_size + ) + for key, value in kwargs.items() + } + example_td = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + result = tensorclass_instance.copy() + result.batch_size = example_td.batch_size + return result + + if isinstance(result, (list, tuple)): + return type(result)( + _from_tensordict_with_copy(tensorclass_instance, tensordict_result) + for tensordict_result in result + ) + return _from_tensordict_with_copy(tensorclass_instance, result) + + def __getitem__(self, index): + if isinstance(index, (tuple, str)) and unravel_key(index): + raise ValueError( + "TensorClass fields must be accessed as attributes, not items." + ) + self_copy = self.copy() + self_copy.batch_size = _getitem_batch_size(self.batch_size, index) + return self_copy + + @property + def batch_size(self): + return self._batch_size + + @batch_size.setter + def batch_size(self, batch_size): + self.__dict__["_batch_size"] = torch.Size(batch_size) + + shape = batch_size + + def unbind(self, dim: int): + return tuple( + self[(slice(None),) * dim + (0,)] for _ in range(self.batch_size[dim]) + ) + + @_bypass + def reshape(self, *shape): ... + + @_bypass + def view(self, *shape): ... + + def unsqueeze(self, dim): + shape = list(self.batch_size) + shape.insert(dim, 0) + self_copy = self.copy() + self_copy.batch_size = shape + return self_copy + + def transpose(self, dim0, dim1): + batch_size = list(self.batch_size) + batch_size[dim1], batch_size[dim0] = batch_size[dim0], batch_size[dim1] + self_copy = self.copy() + self_copy.batch_size = batch_size + return self_copy + + def permute(self, *dims): + if len(dims) == 1 and not isinstance(dims[0], int): + return self.permute(*dims[0]) + batch_size = list(self.batch_size) + batch_size = [batch_size[d] for d in dims] + self_copy = self.copy() + self_copy.batch_size = batch_size + return self_copy + + @classmethod + def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False): + result = list_of_non_tensor[0].copy() + batch_size = list(result.batch_size) + batch_size.insert(dim, len(list_of_non_tensor)) + result.batch_size = torch.Size(batch_size) + return result + + @_bypass + def unflatten(self, dim, unflattened_size): ... + + @_bypass + def flatten(self, start_dim=0, end_dim=-1): ... 
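
As a quick sanity check of the semantics implemented above, here is a minimal sketch that mirrors the
assertions added to ``test_tensordict.py`` further down: shape operations on a TensorDict holding an
``UnbatchedTensor`` pass the stored tensor through unchanged and only track the new batch size.

    >>> import torch
    >>> from tensordict import TensorDict, UnbatchedTensor
    >>> td = TensorDict({"a": UnbatchedTensor(torch.randn(10)), "b": torch.randn(3)}, batch_size=[3])
    >>> td.reshape(1, 3)["a"] is td["a"]             # data is passed through, not reshaped
    True
    >>> td.unflatten(0, (1, 3)).get("a").batch_size  # only the advertised batch size follows the parent
    torch.Size([1, 3])
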
diff --git a/tensordict/base.py b/tensordict/base.py index 36af86611..e0b1d0411 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -66,6 +66,8 @@ _lock_warn, _make_dtype_promotion, _parse_to, + _pass_through, + _pass_through_cls, _pin_mem, _PIN_MEM_TIMEOUT, _prefix_last_key, @@ -6450,7 +6452,7 @@ def _get_tuple(self, key, default): ... def _get_tuple_maybe_non_tensor(self, key, default): result = self._get_tuple(key, default) - if is_non_tensor(result): + if _pass_through(result): # Only lazy stacks of non tensors are actually tensordict instances if isinstance(result, TensorDictBase): return result.tolist() @@ -7371,7 +7373,11 @@ def flatten(tensor): else: names = None out = self._fast_apply( - flatten, batch_size=batch_size, propagate_lock=True, names=names + flatten, + batch_size=batch_size, + propagate_lock=True, + names=names, + call_on_nested=True, ) return out @@ -7417,7 +7423,9 @@ def unflatten(tensor): else: batch_size = list(unflattened_size) + list(self.batch_size[1:]) # TODO: check that this works with nested tds of different batch size - out = self._fast_apply(unflatten, batch_size=batch_size, propagate_lock=True) + out = self._fast_apply( + unflatten, batch_size=batch_size, propagate_lock=True, call_on_nested=True + ) if self._has_names(): names = copy(self.names) for _ in range(len(unflattened_size) - 1): @@ -13317,7 +13325,7 @@ def _default_is_leaf(cls: Type) -> bool: def _is_leaf_nontensor(cls: Type) -> bool: if _is_tensor_collection(cls): - return _is_non_tensor(cls) + return _pass_through_cls(cls) # if issubclass(cls, KeyedJaggedTensor): # return False return issubclass(cls, torch.Tensor) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index e3599bd39..04ada3198 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -3427,6 +3427,8 @@ def _maybe_from_list(nontensor): def is_empty(self) -> bool: return False + _stack_non_tensor = NonTensorData._stack_non_tensor + @classmethod def from_nontensordata(cls, non_tensor: NonTensorData): data = non_tensor.data diff --git a/tensordict/utils.py b/tensordict/utils.py index 5813e5b7f..b92cd0f9f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1231,13 +1231,13 @@ def __call__(func): if attr is not None: @wraps(func) - def new_func(_self, *args, **kwargs): + def func_as_decorator(_self, *args, **kwargs): _attr_pre = getattr(_self, attr) out = func(_self, *args, **kwargs) _attr_post = getattr(_self, attr) if out is not None: if _attr_post is not _attr_pre: - out._last_op = (new_func.__name__, (args, kwargs, _self)) + out._last_op = (func.__name__, (args, kwargs, _self)) else: out._last_op = None return out @@ -1245,13 +1245,13 @@ def new_func(_self, *args, **kwargs): else: @wraps(func) - def new_func(_self, *args, **kwargs): + def func_as_decorator(_self, *args, **kwargs): out = func(_self, *args, **kwargs) if out is not None: - out._last_op = (new_func.__name__, (args, kwargs, _self)) + out._last_op = (func.__name__, (args, kwargs, _self)) return out - return new_func + return func_as_decorator return __call__ @@ -2452,9 +2452,13 @@ def __call__(self, mod: torch.nn.Module, args, kwargs): raise RuntimeError("did not find pre-hook") -def is_non_tensor(data): +def is_non_tensor(data) -> bool: """Checks if an item is a non-tensor.""" - return getattr(type(data), "_is_non_tensor", False) + return _is_non_tensor(type(data)) + + +def _pass_through(data) -> bool: + return _pass_through_cls(type(data)) _NON_TENSOR_MEMO = {} @@ -2466,7 +2470,21 @@ def _is_non_tensor(cls: type): if not 
is_dynamo: out = _NON_TENSOR_MEMO.get(cls) if out is None: - out = getattr(cls, "_is_non_tensor", False) + out = bool(getattr(cls, "_is_non_tensor", False)) + if not is_dynamo: + _NON_TENSOR_MEMO[cls] = out + return out + + +def _pass_through_cls(cls: type): + out = None + is_dynamo = is_compiling() + if not is_dynamo: + out = _NON_TENSOR_MEMO.get(cls) + if out is None: + out = bool(getattr(cls, "_is_non_tensor", False)) or getattr( + cls, "_pass_through", False + ) if not is_dynamo: _NON_TENSOR_MEMO[cls] = out return out diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 60ba211a9..ce17a33c1 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -45,6 +45,7 @@ set_get_defaults_to_none, TensorClass, TensorDict, + UnbatchedTensor, ) from tensordict._lazy import _CustomOpTensorDict from tensordict._reductions import _reduce_td @@ -59,6 +60,7 @@ from tensordict.utils import ( _getitem_batch_size, _LOCK_ERROR, + _pass_through, assert_allclose_td, convert_ellipsis_to_idx, is_non_tensor, @@ -11780,6 +11782,108 @@ class SubTC(NonTensorData): ... assert is_non_tensor(SubTC(data=1, batch_size=[])) +class TestUnbatchedTensor: + def test_unbatched(self): + assert UnbatchedTensor._pass_through + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + batch_size=(3,), + ) + assert _pass_through(td.get("a")) + assert isinstance(td["a"], torch.Tensor) + assert isinstance(td.get("a"), UnbatchedTensor) + + def test_unbatched_shape_ops(self): + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + batch_size=(3,), + ) + # get item + assert td[0]["a"] is td["a"] + assert td[:]["a"] is td["a"] + + unbind = td.unbind(0)[0] + assert unbind["a"] is td["a"] + assert unbind.batch_size == () + + split = td.split(1)[0] + assert split["a"] is td["a"] + assert split.batch_size == (1,) + assert td.split((2, 1))[0]["a"] is td["a"] + + reshape = td.reshape((1, 3)) + assert reshape["a"] is td["a"] + assert reshape.batch_size == (1, 3) + transpose = reshape.transpose(0, 1) + assert transpose["a"] is td["a"] + assert transpose.batch_size == (3, 1) + permute = reshape.permute(1, 0) + assert permute["a"] is td["a"] + assert permute.batch_size == (3, 1) + squeeze = reshape.squeeze() + assert squeeze["a"] is td["a"] + assert squeeze.batch_size == (3,) + + view = td.view((1, 3)) + assert view["a"] is td["a"] + assert view.batch_size == (1, 3) + unsqueeze = td.unsqueeze(0) + assert unsqueeze["a"] is td["a"] + assert unsqueeze.batch_size == (1, 3) + gather = td.gather(0, torch.tensor((0,))) + assert gather["a"] is td["a"] + assert gather.batch_size == (1,) + + unflatten = td.unflatten(0, (1, 3)) + assert unflatten["a"] is td["a"] + assert unflatten.batch_size == (1, 3) + assert unflatten.get("a").batch_size == (1, 3) + assert unflatten.get("a")._tensordict.batch_size == () + + flatten = unflatten.flatten(0, 1) + assert flatten["a"] is td["a"] + assert flatten.batch_size == (3,) + + def test_unbatched_torch_func(self): + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + batch_size=(3,), + ) + assert torch.unbind(td, 0)[0]["a"] is td["a"] + assert torch.stack([td, td], 0)[0]["a"] is td["a"] + assert torch.cat([td, td], 0)[0]["a"] is td["a"] + assert (torch.ones_like(td)["a"] == 1).all() + assert torch.unsqueeze(td, 0)["a"] is td["a"] + assert torch.squeeze(td)["a"] is td["a"] + unflatten = torch.unflatten(td, 0, (1, 3)) + assert unflatten["a"] is td["a"] + flatten = torch.flatten(unflatten, 0, 1) + assert flatten["a"] is td["a"] + 
permute = torch.permute(unflatten, (1, 0)) + assert permute["a"] is td["a"] + transpose = torch.transpose(unflatten, 1, 0) + assert transpose["a"] is td["a"] + + def test_unbatched_other_ops(self): + td = TensorDict( + a=UnbatchedTensor(torch.randn(10)), + b=torch.randn(3), + c_d=UnbatchedTensor(torch.randn(10)), + batch_size=(3,), + ) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + assert td.copy()["a"] is td["a"] + assert td.int()["a"].dtype == torch.int + assert td.to(device)["a"].device == device + assert td.select("a")["a"] is td["a"] + assert td.exclude("b")["a"] is td["a"] + assert td.unflatten_keys(separator="_")["c", "d"] is td["c_d"] + assert td.unflatten_keys(separator="_").flatten_keys()["c.d"] is td["c_d"] + + def _to_float(td, td_name, tmpdir): if hasattr(td, "_source"): td._source = td._source.float()