From aeff8376045135bb7741ae07cd8a6efba1cecec0 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Wed, 8 Jan 2025 13:00:37 +0000
Subject: [PATCH] [Feature] broadcast pointwise ops for tensor/tensordict
 mixed inputs

ghstack-source-id: bbefbb1a2e9841847c618bb9cf49160ff1a5c36a
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1166
---
 docs/source/reference/tensordict.rst | 108 ++++++++++++++++
 tensordict/_td.py                    |   9 ++
 tensordict/base.py                   | 112 ++++++++++++++++-
 test/test_tensordict.py              | 181 ++++++++++++++++++++++++++-
 4 files changed, 404 insertions(+), 6 deletions(-)

diff --git a/docs/source/reference/tensordict.rst b/docs/source/reference/tensordict.rst
index d6e941c7e..bd9c5bf09 100644
--- a/docs/source/reference/tensordict.rst
+++ b/docs/source/reference/tensordict.rst
@@ -109,6 +109,114 @@ However, physical storage of PyTorch tensors should not be any different:
 
     MemoryMappedTensor
 
+Pointwise Operations
+--------------------
+
+Tensordict supports various pointwise operations, allowing you to perform element-wise computations on the tensors
+stored within it. These operations are similar to those performed on regular PyTorch tensors.
+
+Supported Operations
+~~~~~~~~~~~~~~~~~~~~
+
+The following pointwise operations are currently supported:
+
+- Left and right addition (`+`)
+- Left and right subtraction (`-`)
+- Left and right multiplication (`*`)
+- Left and right division (`/`)
+- Left power (`**`)
+
+Many other ops, like :meth:`~tensordict.TensorDict.clamp`, :meth:`~tensordict.TensorDict.sqrt`, etc., are supported.
+
+Performing Pointwise Operations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+You can perform pointwise operations between two Tensordicts or between a Tensordict and a tensor/scalar value.
+
+Example 1: Tensordict-Tensordict Operation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> td1 = TensorDict(
+    ...     a=torch.randn(3, 4),
+    ...     b=torch.zeros(3, 4, 5),
+    ...     c=torch.ones(3, 4, 5, 6),
+    ...     batch_size=(3, 4),
+    ... )
+    >>> td2 = TensorDict(
+    ...     a=torch.randn(3, 4),
+    ...     b=torch.zeros(3, 4, 5),
+    ...     c=torch.ones(3, 4, 5, 6),
+    ...     batch_size=(3, 4),
+    ... )
+    >>> result = td1 * td2
+
+In this example, the * operator is applied element-wise to the corresponding tensors in td1 and td2.
+
+Example 2: Tensordict-Tensor Operation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> td = TensorDict(
+    ...     a=torch.randn(3, 4),
+    ...     b=torch.zeros(3, 4, 5),
+    ...     c=torch.ones(3, 4, 5, 6),
+    ...     batch_size=(3, 4),
+    ... )
+    >>> tensor = torch.randn(4)
+    >>> result = td * tensor
+
+Here, the * operator is applied element-wise to each tensor in td and the provided tensor. The tensor is broadcast to
+match the shape of each tensor in the Tensordict.
+
+Example 3: Tensordict-Scalar Operation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> td = TensorDict(
+    ...     a=torch.randn(3, 4),
+    ...     b=torch.zeros(3, 4, 5),
+    ...     c=torch.ones(3, 4, 5, 6),
+    ...     batch_size=(3, 4),
+    ... )
+    >>> scalar = 2.0
+    >>> result = td * scalar
+
+In this case, the * operator is applied element-wise to each tensor in td and the provided scalar.
+
+Broadcasting Rules
+~~~~~~~~~~~~~~~~~~
+
+When performing pointwise operations between a Tensordict and a tensor/scalar, the tensor/scalar is broadcast to match
+the shape of each tensor in the Tensordict: the tensor is broadcast on the left to match the tensordict shape, then
+individually broadcast on the right to match the tensors' shapes. This follows the standard broadcasting rules used in
+PyTorch if one thinks of the ``TensorDict`` as a single tensor instance.
+
+For example, if you have a Tensordict with tensors of shape ``(3, 4)`` and you multiply it by a tensor of shape ``(4,)``,
+the tensor will be broadcast to shape ``(3, 4)`` before the operation is applied. If the tensordict contains a tensor of
+shape ``(3, 4, 5)``, the tensor used for the multiplication will be broadcast to ``(3, 4, 5)`` on the right for that
+multiplication.
+
+If the pointwise operation is executed across multiple tensordicts and their batch sizes differ, they will be
+broadcast to a common shape.
+
+Efficiency of pointwise operations
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+When possible, fused ``torch._foreach_`` kernels will be used to speed up the computation of the pointwise
+operation.
+
+Handling Missing Entries
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+When performing pointwise operations between two Tensordicts, they must have the same keys.
+Some operations, like :meth:`~tensordict.TensorDict.add`, have a ``default`` keyword argument that can be used
+to operate on tensordicts with exclusive entries.
+If ``default=None`` (the default), the two Tensordicts must have exactly matching key sets.
+If ``default="intersection"``, only the intersecting key sets will be considered, and other keys will be ignored.
+In all other cases, ``default`` will be used for all missing entries on both sides of the operation.
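+
+As a minimal sketch of the behaviour described above (assuming only what this section states, namely that
+:meth:`~tensordict.TensorDict.add` accepts ``default="intersection"`` or a fill value for missing entries), consider two
+tensordicts whose key sets only partially overlap:
+
+    >>> import torch
+    >>> from tensordict import TensorDict
+    >>> td1 = TensorDict(a=torch.ones(3, 4), b=torch.ones(3, 4), batch_size=(3, 4))
+    >>> td2 = TensorDict(a=torch.ones(3, 4), c=torch.ones(3, 4), batch_size=(3, 4))
+    >>> # Keep only the intersecting key ("a"); "b" and "c" are ignored
+    >>> result = td1.add(td2, default="intersection")
+    >>> # Use 0.0 for the missing entries ("c" in td1, "b" in td2) on both sides
+    >>> result = td1.add(td2, default=0.0)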
Utils ----- diff --git a/tensordict/_td.py b/tensordict/_td.py index ccc5ffbe3..fa2db9ac9 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -30,6 +30,7 @@ _is_leaf_nontensor, _is_tensor_collection, _load_metadata, + _maybe_broadcast_other, _NESTED_TENSORS_AS_LISTS, _register_tensor_class, BEST_ATTEMPT_INPLACE, @@ -611,6 +612,7 @@ def _quick_set(swap_dict, swap_td): else: return TensorDict._new_unsafe(_swap, batch_size=torch.Size(())) + @_maybe_broadcast_other("__ne__") def __ne__(self, other: Any) -> T | bool: if is_tensorclass(other): return other != self @@ -635,6 +637,7 @@ def __ne__(self, other: Any) -> T | bool: ) return True + @_maybe_broadcast_other("__xor__") def __xor__(self, other: Any) -> T | bool: if is_tensorclass(other): return other ^ self @@ -659,6 +662,7 @@ def __xor__(self, other: Any) -> T | bool: ) return True + @_maybe_broadcast_other("__or__") def __or__(self, other: Any) -> T | bool: if is_tensorclass(other): return other | self @@ -683,6 +687,7 @@ def __or__(self, other: Any) -> T | bool: ) return False + @_maybe_broadcast_other("__eq__") def __eq__(self, other: Any) -> T | bool: if is_tensorclass(other): return other == self @@ -705,6 +710,7 @@ def __eq__(self, other: Any) -> T | bool: ) return False + @_maybe_broadcast_other("__ge__") def __ge__(self, other: Any) -> T | bool: if is_tensorclass(other): return other <= self @@ -727,6 +733,7 @@ def __ge__(self, other: Any) -> T | bool: ) return False + @_maybe_broadcast_other("__gt__") def __gt__(self, other: Any) -> T | bool: if is_tensorclass(other): return other < self @@ -749,6 +756,7 @@ def __gt__(self, other: Any) -> T | bool: ) return False + @_maybe_broadcast_other("__le__") def __le__(self, other: Any) -> T | bool: if is_tensorclass(other): return other >= self @@ -771,6 +779,7 @@ def __le__(self, other: Any) -> T | bool: ) return False + @_maybe_broadcast_other("__lt__") def __lt__(self, other: Any) -> T | bool: if is_tensorclass(other): return other > self diff --git a/tensordict/base.py b/tensordict/base.py index 5c0476560..a37ecaf31 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -205,6 +205,59 @@ def has_transfer(self): _device_recorder = _RecordDeviceTransfer() +def _maybe_broadcast_other(op: str, n_other: int = 1): + """Ensures that elementwise ops are broadcast when an nd tensor is passed.""" + + def wrap_func(func): + @wraps(func) + def new_func(self, *others, **kwargs): + others, args = others[:n_other], others[n_other:] + need_broadcast = False + for other in others: + if other is None: + continue + if (isinstance(other, torch.Tensor) and other.ndim) or ( + _is_tensor_collection(type(other)) + and other.ndim + and other.shape != self.shape + ): + need_broadcast = True + break + if not need_broadcast: + return func(self, *others, *args, **kwargs) + others_map = [] + shape = self.shape + self_expand = self + shapes = [shape, *[other.shape for other in others if other is not None]] + shape = torch.broadcast_shapes(*shapes) + if shape != self_expand.shape: + self_expand = self_expand.expand(shape) + for other in others: + if other is None: + others_map.append(other) + continue + # broadcast dims + if shape != other.shape: + other = other.expand(shape) + others_map.append(other) + if any(isinstance(other, torch.Tensor) for other in others_map): + return self_expand._fast_apply( + lambda x: getattr(x, op)( + *[ + expand_as_right(other, x) if other is not None else None + for other in others_map + ], + *args, + **kwargs, + ) + ) + return getattr(self_expand, op)(*others_map, *args, 
**kwargs) + + return new_func + + return wrap_func + + class TensorDictBase(MutableMapping): """TensorDictBase is an abstract parent class for TensorDicts, a torch.Tensor data container.""" @@ -9849,6 +9902,7 @@ def pop(name, val): result.update(items) return result + @_maybe_broadcast_other("bitwise_and") def bitwise_and( self, other: TensorDictBase | torch.Tensor, @@ -9901,6 +9955,7 @@ def pop(name, val): result.update(items) return result + @_maybe_broadcast_other("logical_and") def logical_and( self, other: TensorDictBase | torch.Tensor, @@ -9953,6 +10008,7 @@ def pop(name, val): result.update(items) return result + @_maybe_broadcast_other("add") def add( self, other: TensorDictBase | torch.Tensor, @@ -10034,6 +10090,7 @@ def add_( torch._foreach_add_(vals, other_val) return self + @_maybe_broadcast_other("lerp", 2) def lerp( self, end: TensorDictBase | torch.Tensor, @@ -10097,6 +10154,7 @@ def lerp_( torch._foreach_lerp_(self._values_list(True, True), end_val, weight_val) return self + @_maybe_broadcast_other("addcdiv", 2) def addcdiv( self, other1: TensorDictBase | torch.Tensor, @@ -10159,7 +10217,14 @@ def addcdiv_(self, other1, other2, *, value: float | None = 1): ) return self - def addcmul(self, other1, other2, *, value: float | None = 1): # noqa: D417 + @_maybe_broadcast_other("addcmul", 2) + def addcmul( + self, + other1: TensorDictBase | torch.Tensor, + other2: TensorDictBase | torch.Tensor, + *, + value: float | None = 1, + ): # noqa: D417 r"""Performs the element-wise multiplication of :attr:`other1` by :attr:`other2`, multiplies the result by the scalar :attr:`value` and adds it to ``self``. .. math:: @@ -10216,6 +10281,7 @@ def addcmul_(self, other1, other2, *, value: float | None = 1): ) return self + @_maybe_broadcast_other("sub") def sub( self, other: TensorDictBase | torch.Tensor | float, @@ -10314,6 +10380,7 @@ def mul_(self, other: TensorDictBase | torch.Tensor) -> T: torch._foreach_mul_(vals, other_val) return self + @_maybe_broadcast_other("mul") def mul( self, other: TensorDictBase | torch.Tensor, @@ -10385,6 +10452,7 @@ def maximum_(self, other: TensorDictBase | torch.Tensor) -> T: torch._foreach_maximum_(vals, other_val) return self + @_maybe_broadcast_other("maximum") def maximum( self, other: TensorDictBase | torch.Tensor, @@ -10451,6 +10519,7 @@ def minimum_(self, other: TensorDictBase | torch.Tensor) -> T: torch._foreach_minimum_(vals, other_val) return self + @_maybe_broadcast_other("minimum") def minimum( self, other: TensorDictBase | torch.Tensor, @@ -10524,6 +10593,7 @@ def clamp_max_(self, other: TensorDictBase | torch.Tensor) -> T: ) return self + @_maybe_broadcast_other("clamp_max") def clamp_max( self, other: TensorDictBase | torch.Tensor, @@ -10605,6 +10675,7 @@ def clamp_min_(self, other: TensorDictBase | torch.Tensor) -> T: return self + @_maybe_broadcast_other("clamp_min") def clamp_min( self, other: TensorDictBase | torch.Tensor, @@ -10662,7 +10733,14 @@ def pop(name, val): result.update(items) return result - def clamp(self, min=None, max=None, *, out=None): # noqa: W605 + @_maybe_broadcast_other("clamp", 2) + def clamp( + self, + min: TensorDictBase | torch.Tensor = None, + max: TensorDictBase | torch.Tensor = None, + *, + out=None, + ): # noqa: W605 r"""Clamps all elements in :attr:`self` into the range `[` :attr:`min`, :attr:`max` `]`. 
Letting min_value and max_value be :attr:`min` and :attr:`max`, respectively, this returns: @@ -10690,11 +10768,33 @@ def clamp(self, min=None, max=None, *, out=None): # noqa: W605 "clamp() with min/max=None isn't implemented with specified output." ) return self.clamp_min(min) + + is_tc_min = is_tensor_collection(min) + is_tc_max = is_tensor_collection(max) + + if is_tc_min ^ is_tc_max: + raise ValueError( + "Mixed tensordict and non-tensordict min/max values are not authorized." + ) + if out is None: + if is_tc_min and is_tc_max: + return self._fast_apply( + lambda x, low, high: x.clamp(low, high), min, max, default=None + ) return self._fast_apply(lambda x: x.clamp(min, max)) - result = self._fast_apply( - lambda x, y: x.clamp(min, max, out=y), out, default=None - ) + if is_tc_min and is_tc_max: + result = self._fast_apply( + lambda x, y, low, high: x.clamp(low, high, out=y), + out, + min, + max, + default=None, + ) + else: + result = self._fast_apply( + lambda x, y: x.clamp(min, max, out=y), out, default=None + ) with out.unlock_() if out.is_locked else contextlib.nullcontext(): return out.update(result) @@ -10714,6 +10814,7 @@ def pow_(self, other: TensorDictBase | torch.Tensor) -> T: torch._foreach_pow_(vals, other_val) return self + @_maybe_broadcast_other("pow") def pow( self, other: TensorDictBase | torch.Tensor, @@ -10785,6 +10886,7 @@ def div_(self, other: TensorDictBase | torch.Tensor) -> T: torch._foreach_div_(vals, other_val) return self + @_maybe_broadcast_other("div") def div( self, other: TensorDictBase | torch.Tensor, diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 97191db16..a77cb11e7 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3487,6 +3487,185 @@ def test_clamp_max_default(self): assert "d" not in tdpow assert "b" in tdpow + @pytest.mark.parametrize("shape", [(4,), (3, 4), (2, 3, 4)]) + def test_broadcast_tensor(self, shape): + torch.manual_seed(0) + td = TensorDict( + a=torch.randn(3, 4), + b=torch.zeros(3, 4, 5), + c=torch.ones(3, 4, 5, 6), + batch_size=(3, 4), + ) + broadcast_shape = torch.broadcast_shapes(shape, td.shape) + td_mul = td * torch.ones(shape) + assert td_mul.shape == broadcast_shape + assert (td_mul == td).all() + td_add = td + torch.ones(shape) + assert td_add.shape == broadcast_shape + assert (td_add == td + 1).all() + td_sub = td - torch.ones(shape) + assert td_sub.shape == broadcast_shape + assert (td_sub == td - 1).all() + td_div = td / torch.ones(shape) + assert td_div.shape == broadcast_shape + assert (td_div == td).all() + td_max = td.maximum(torch.ones(shape)) + assert td_max.shape == broadcast_shape + assert (td_max == td.maximum(torch.ones_like(td))).all() + td_min = td.minimum(torch.ones(shape)) + assert td_min.shape == broadcast_shape + assert (td_min == td.minimum(torch.ones_like(td))).all() + td_max = td.clamp_max(torch.ones(shape)) + assert td_max.shape == broadcast_shape + assert (td_max == td.clamp_max(torch.ones_like(td))).all() + td_min = td.clamp_min(torch.ones(shape)) + assert td_min.shape == broadcast_shape + assert (td_min == td.clamp_min(torch.ones_like(td))).all() + + td_clamp = td.clamp(-torch.ones(shape), torch.ones(shape)) + assert td_clamp.shape == broadcast_shape + assert_allclose_td( + td_clamp, + td.clamp(-torch.ones_like(td), torch.ones_like(td)).expand(broadcast_shape), + ) + td_clamp = td.clamp(None, torch.ones(shape)) + assert td_clamp.shape == broadcast_shape + assert_allclose_td( + td_clamp, td.clamp(None, torch.ones_like(td)).expand(broadcast_shape) + ) + td_clamp = 
td.clamp(-torch.ones(shape), None) + assert td_clamp.shape == broadcast_shape + assert_allclose_td( + td_clamp, td.clamp(-torch.ones_like(td), None).expand(broadcast_shape) + ) + + td_pow = td.pow(torch.ones(shape)) + assert td_pow.shape == broadcast_shape + assert (td_pow == td.pow(torch.ones_like(td))).all() + + td_ba = td.bool().bitwise_and(torch.ones(shape, dtype=torch.bool)) + assert td_ba.shape == broadcast_shape + assert (td_ba == td.bool().bitwise_and(torch.ones_like(td.bool()))).all() + + td_la = td.logical_and(torch.ones(shape)) + assert td_la.shape == broadcast_shape + assert (td_la == td.logical_and(torch.ones_like(td))).all() + + td_lerp = td.lerp(-torch.ones(shape), torch.ones(shape)) + assert td_lerp.shape == broadcast_shape + assert_allclose_td( + td_lerp, + td.lerp(-torch.ones_like(td), torch.ones_like(td)).expand(broadcast_shape), + ) + + td_addcdiv = td.addcdiv(-torch.ones(shape), torch.ones(shape)) + assert td_addcdiv.shape == broadcast_shape + assert_allclose_td( + td_addcdiv, + td.addcdiv(-torch.ones_like(td), torch.ones_like(td)).expand( + broadcast_shape + ), + ) + + td_addcmul = td.addcmul(-torch.ones(shape), torch.ones(shape)) + assert td_addcmul.shape == broadcast_shape + assert_allclose_td( + td_addcmul, + td.addcmul(-torch.ones_like(td), torch.ones_like(td)).expand( + broadcast_shape + ), + ) + + @pytest.mark.parametrize("shape", [(4,), (3, 4), (2, 3, 4)]) + def test_broadcast_tensordict(self, shape): + torch.manual_seed(0) + td = TensorDict( + a=torch.randn(3, 4), + b=torch.zeros(3, 4, 5), + c=torch.ones(3, 4, 5, 6), + batch_size=(3, 4), + ) + td_mul = td * torch.ones(shape) + td_mul = td * td.new_ones(shape) + broadcast_shape = torch.broadcast_shapes(shape, td.shape) + assert td_mul.shape == broadcast_shape + assert (td_mul == td).all() + td_add = td + td.new_ones(shape) + assert td_add.shape == broadcast_shape + assert (td_add == td + 1).all() + td_sub = td - td.new_ones(shape) + assert td_sub.shape == broadcast_shape + assert (td_sub == td - 1).all() + td_div = td / td.new_ones(shape) + assert td_div.shape == broadcast_shape + assert (td_div == td).all() + td_max = td.maximum(td.new_ones(shape)) + assert td_max.shape == broadcast_shape + assert (td_max == td.maximum(torch.ones_like(td))).all() + td_min = td.minimum(td.new_ones(shape)) + assert td_min.shape == broadcast_shape + assert (td_min == td.minimum(torch.ones_like(td))).all() + td_max = td.clamp_max(td.new_ones(shape)) + assert td_max.shape == broadcast_shape + assert (td_max == td.clamp_max(torch.ones_like(td))).all() + td_min = td.clamp_min(td.new_ones(shape)) + assert td_min.shape == broadcast_shape + assert (td_min == td.clamp_min(torch.ones_like(td))).all() + + td_clamp = td.clamp(-td.new_ones(shape), td.new_ones(shape)) + assert td_clamp.shape == broadcast_shape + assert_allclose_td( + td_clamp, + td.clamp(-torch.ones_like(td), torch.ones_like(td)).expand(broadcast_shape), + ) + td_clamp = td.clamp(None, td.new_ones(shape)) + assert td_clamp.shape == broadcast_shape + assert_allclose_td( + td_clamp, td.clamp(None, torch.ones_like(td)).expand(broadcast_shape) + ) + td_clamp = td.clamp(-torch.ones(shape), None) + assert td_clamp.shape == broadcast_shape + assert_allclose_td( + td_clamp, td.clamp(-torch.ones_like(td), None).expand(broadcast_shape) + ) + + td_pow = td.pow(td.new_ones(shape)) + assert td_pow.shape == broadcast_shape + assert (td_pow == td.pow(torch.ones_like(td))).all() + + td_ba = td.bool().bitwise_and(td.new_ones(shape, dtype=torch.bool)) + assert td_ba.shape == broadcast_shape + 
assert (td_ba == td.bool().bitwise_and(torch.ones_like(td.bool()))).all() + + td_la = td.logical_and(td.new_ones(shape)) + assert td_la.shape == broadcast_shape + assert (td_la == td.logical_and(torch.ones_like(td))).all() + + td_lerp = td.lerp(-td.new_ones(shape), td.new_ones(shape)) + assert td_lerp.shape == broadcast_shape + assert_allclose_td( + td_lerp, + td.lerp(-torch.ones_like(td), torch.ones_like(td)).expand(broadcast_shape), + ) + + td_addcdiv = td.addcdiv(-td.new_ones(shape), td.new_ones(shape)) + assert td_addcdiv.shape == broadcast_shape + assert_allclose_td( + td_addcdiv, + td.addcdiv(-torch.ones_like(td), torch.ones_like(td)).expand( + broadcast_shape + ), + ) + + td_addcmul = td.addcmul(-td.new_ones(shape), td.new_ones(shape)) + assert td_addcmul.shape == broadcast_shape + assert_allclose_td( + td_addcmul, + td.addcmul(-torch.ones_like(td), torch.ones_like(td)).expand( + broadcast_shape + ), + ) + @pytest.mark.parametrize( "td_name,device", @@ -7005,7 +7184,7 @@ def test_unflatten_keys(self, td_name, device, inplace, separator): td_unflatten = td_flatten.unflatten_keys( inplace=inplace, separator=separator ) - assert (td == td_unflatten).all() + assert (td == td.empty(recurse=True).update(td_unflatten)).all() if inplace: assert td is td_unflatten