[Feature] broadcast pointwise ops for tensor/tensordict mixed inputs
ghstack-source-id: bbefbb1a2e9841847c618bb9cf49160ff1a5c36a
Pull Request resolved: #1166
vmoens committed Jan 8, 2025
1 parent 646683c commit aeff837
Showing 4 changed files with 404 additions and 6 deletions.
108 changes: 108 additions & 0 deletions docs/source/reference/tensordict.rst
@@ -109,6 +109,114 @@ However, physical storage of PyTorch tensors should not be any different:

MemoryMappedTensor

Pointwise Operations
--------------------

Tensordict supports various pointwise operations, allowing you to perform element-wise computations on the tensors
stored within it. These operations are similar to those performed on regular PyTorch tensors.

Supported Operations
~~~~~~~~~~~~~~~~~~~~

The following pointwise operations are currently supported:

- Left and right addition (`+`)
- Left and right subtraction (`-`)
- Left and right multiplication (`*`)
- Left and right division (`/`)
- Left power (`**`)

Many other pointwise ops, such as :meth:`~tensordict.TensorDict.clamp` and :meth:`~tensordict.TensorDict.sqrt`, are also supported as methods.

Performing Pointwise Operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can perform pointwise operations between two Tensordicts or between a Tensordict and a tensor/scalar value.

Example 1: Tensordict-Tensordict Operation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>>> import torch
>>> from tensordict import TensorDict
>>> td1 = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> td2 = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> result = td1 * td2

In this example, the ``*`` operator is applied element-wise to the corresponding tensors in ``td1`` and ``td2``.

Example 2: Tensordict-Tensor Operation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> tensor = torch.randn(4)
>>> result = td * tensor

Here, the ``*`` operator is applied element-wise to each tensor in ``td`` and the provided tensor. The tensor is broadcast to match the shape of each tensor in the Tensordict.

Example 3: Tensordict-Scalar Operation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict(
... a=torch.randn(3, 4),
... b=torch.zeros(3, 4, 5),
... c=torch.ones(3, 4, 5, 6),
... batch_size=(3, 4),
... )
>>> scalar = 2.0
>>> result = td * scalar

In this case, the ``*`` operator is applied element-wise to each tensor in ``td`` and the provided scalar.

Broadcasting Rules
~~~~~~~~~~~~~~~~~~

When performing pointwise operations between a Tensordict and a tensor/scalar, the tensor/scalar is broadcast to match
the shape of each tensor in the Tensordict: the tensor is first broadcast on the left to match the tensordict's batch
size, then individually broadcast on the right to match each tensor's shape. This follows the standard broadcasting
rules used in PyTorch if one thinks of the ``TensorDict`` as a single tensor instance.

For example, if you have a Tensordict with tensors of shape ``(3, 4)`` and you multiply it by a tensor of shape ``(4,)``,
the tensor will be broadcast to shape ``(3, 4)`` before the operation is applied. If the tensordict contains a tensor of
shape ``(3, 4, 5)``, the tensor used for the multiplication will be broadcast to ``(3, 4, 5)`` on the right for that
multiplication.
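The two-step broadcast described above can be illustrated with plain tensors (a minimal sketch of the rule, not of tensordict internals):

```python
import torch

t = torch.randn(4)
# Step 1: broadcast on the left against the tensordict batch size (3, 4)
left = t.expand(3, 4)
# Step 2: for an entry of shape (3, 4, 5), broadcast on the right by
# appending a trailing singleton dimension before multiplying
entry = torch.ones(3, 4, 5)
result = entry * left.unsqueeze(-1)
```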

If the pointwise operation is executed across multiple tensordicts and their batch sizes differ, they will be
broadcast to a common shape.

Efficiency of pointwise operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When possible, ``torch._foreach_<op>`` fused kernels will be used to speed up the computation of the pointwise
operation.
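The underlying primitive operates on a flat list of tensors in one fused call; a minimal sketch (note that ``torch._foreach_mul`` is a private PyTorch API, shown here only for illustration):

```python
import torch

tensors = [torch.ones(3, 4), torch.ones(3, 4, 5)]
# One fused call multiplies every tensor in the list, instead of
# launching a separate kernel per tensor
out = torch._foreach_mul(tensors, 2.0)
```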

Handling Missing Entries
~~~~~~~~~~~~~~~~~~~~~~~~

When performing pointwise operations between two Tensordicts, they must have matching key sets.
Some operations, such as :meth:`~tensordict.TensorDict.add`, accept a ``default`` keyword argument that controls how
tensordicts with exclusive (non-shared) entries are handled.
If ``default=None`` (the default), the two Tensordicts must have exactly matching key sets.
If ``default="intersection"``, only the intersecting keys are considered, and other keys are ignored.
In all other cases, the value of ``default`` is used for all missing entries on both sides of the operation.

Utils
-----
9 changes: 9 additions & 0 deletions tensordict/_td.py
@@ -30,6 +30,7 @@
_is_leaf_nontensor,
_is_tensor_collection,
_load_metadata,
_maybe_broadcast_other,
_NESTED_TENSORS_AS_LISTS,
_register_tensor_class,
BEST_ATTEMPT_INPLACE,
@@ -611,6 +612,7 @@ def _quick_set(swap_dict, swap_td):
else:
return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))

@_maybe_broadcast_other("__ne__")
def __ne__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other != self
@@ -635,6 +637,7 @@ def __ne__(self, other: Any) -> T | bool:
)
return True

@_maybe_broadcast_other("__xor__")
def __xor__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other ^ self
@@ -659,6 +662,7 @@ def __xor__(self, other: Any) -> T | bool:
)
return True

@_maybe_broadcast_other("__or__")
def __or__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other | self
@@ -683,6 +687,7 @@ def __or__(self, other: Any) -> T | bool:
)
return False

@_maybe_broadcast_other("__eq__")
def __eq__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other == self
@@ -705,6 +710,7 @@ def __eq__(self, other: Any) -> T | bool:
)
return False

@_maybe_broadcast_other("__ge__")
def __ge__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other <= self
@@ -727,6 +733,7 @@ def __ge__(self, other: Any) -> T | bool:
)
return False

@_maybe_broadcast_other("__gt__")
def __gt__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other < self
@@ -749,6 +756,7 @@ def __gt__(self, other: Any) -> T | bool:
)
return False

@_maybe_broadcast_other("__le__")
def __le__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other >= self
@@ -771,6 +779,7 @@ def __le__(self, other: Any) -> T | bool:
)
return False

@_maybe_broadcast_other("__lt__")
def __lt__(self, other: Any) -> T | bool:
if is_tensorclass(other):
return other > self
