Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add __abs__ docstrings, __neg__, __rxor__, __ror__, __invert__, __and__, __rand__, __radd__, __rtruediv__, __rmul__, __rsub__, __rpow__, bitwise_and, logical_and #1154

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 234 additions & 8 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,30 @@ class TensorDictBase(MutableMapping):
def __bool__(self) -> bool:
raise RuntimeError("Converting a tensordict to boolean value is not permitted")

def __abs__(self) -> T:
"""Returns a new TensorDict instance with absolute values of all tensors.

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having their absolute values computed.

.. seealso:: :meth:`~.abs`

"""
return self.abs()

def __neg__(self) -> T:
"""Returns a new TensorDict instance with negated values of all tensors.

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having their values negated.

.. seealso:: :meth:`~.neg`

"""
return self.neg()

@abc.abstractmethod
def __ne__(self, other: object) -> T:
"""NOT operation over two tensordicts, for evey key.
Expand All @@ -237,7 +261,7 @@ def __ne__(self, other: object) -> T:
...

@abc.abstractmethod
def __xor__(self, other: TensorDictBase | float):
def __xor__(self, other: TensorDictBase | torch.Tensor | float):
"""XOR operation over two tensordicts, for evey key.

The two tensordicts must have the same key set.
Expand All @@ -252,6 +276,13 @@ def __xor__(self, other: TensorDictBase | float):
"""
...

def __rxor__(self, other: TensorDictBase | torch.Tensor | float):
"""XOR operation over two tensordicts, for evey key.

Wraps `__xor__` as it is assumed to be commutative.
"""
return self.__xor__(other)

@abc.abstractmethod
def __or__(self, other: TensorDictBase | torch.Tensor) -> T:
"""OR operation over two tensordicts, for evey key.
Expand All @@ -268,6 +299,71 @@ def __or__(self, other: TensorDictBase | torch.Tensor) -> T:
"""
...

def __ror__(self, other: TensorDictBase | torch.Tensor) -> T:
"""Right-side OR operation over two tensordicts, for evey key.

This is a wrapper around `__or__` since it is assumed to be commutative.
"""
return self | other

def __invert__(self) -> T:
"""Returns a new TensorDict instance with all tensors inverted (i.e., bitwise NOT operation).

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having their bits inverted.
"""
keys, vals = self._items_list(True, True)
vals = [~v for v in vals]
items = dict(zip(keys, vals))

def get(name, val):
return items.get(name, val)

return self._fast_apply(
get,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
)

def __and__(self, other: TensorDictBase | torch.Tensor | float) -> T:
"""Returns a new TensorDict instance with all tensors performing a logical or bitwise AND operation with the given value.

Args:
other: The value to perform the AND operation with.

Returns:
A new TensorDict instance with the same key set as the original,
but with all tensors having performed a AND operation with the given value.
"""
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
new_keys, other_val = other._items_list(True, True, sorting_keys=keys)
vals = [(v1 & v2) for v1, v2 in zip(vals, other_val)]
else:
vals = [(v & other) for v in vals]
items = dict(zip(keys, vals))

def pop(name, val):
return items.pop(name, None)

result = self._fast_apply(
pop,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result

__rand__ = __and__

@abc.abstractmethod
def __eq__(self, other: object) -> T:
"""Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set.
Expand Down Expand Up @@ -8903,21 +8999,27 @@ def record(tensor):
def __add__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.add(other)

def __radd__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.add(other)

def __iadd__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.add_(other)

def __abs__(self):
return self.abs()

def __truediv__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.div(other)

def __itruediv__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.div_(other)

def __rtruediv__(self, other: TensorDictBase | torch.Tensor) -> T:
return other * self.reciprocal()

def __mul__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.mul(other)

def __rmul__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.mul(other)

def __imul__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.mul_(other)

Expand All @@ -8927,9 +9029,18 @@ def __sub__(self, other: TensorDictBase | torch.Tensor) -> T:
def __isub__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.sub_(other)

def __rsub__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.sub(other)

def __pow__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.pow(other)

def __rpow__(self, other: TensorDictBase | torch.Tensor) -> T:
raise NotImplementedError(
"rpow isn't implemented for tensordict yet. Make sure both elements are wrapped "
"in a tensordict for this to work."
)

def __ipow__(self, other: TensorDictBase | torch.Tensor) -> T:
return self.pow_(other)

Expand Down Expand Up @@ -9661,6 +9772,110 @@ def pop(name, val):
result.update(items)
return result

def bitwise_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: # noqa: D417
r"""Performs a bitwise AND operation between ``self`` and :attr:`other`.

.. math::
\text{{out}}_i = \text{{input}}_i \land \text{{other}}_i

Args:
other (TensorDictBase or torch.Tensor): the tensor or TensorDict to perform the bitwise AND with.

Keyword Args:
default (torch.Tensor or str, optional): the default value to use for exclusive entries.
If none is provided, the two tensordicts key list must match exactly.
If ``default="intersection"`` is passed, only the intersecting key sets will be considered
and other keys will be ignored.
In all other cases, ``default`` will be used for all missing entries on both sides of the
operation.
"""
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
new_keys, other_val = other._items_list(
True, True, sorting_keys=keys, default=default
)
if default is not None:
as_dict = dict(zip(keys, vals))
vals = [as_dict.get(key, default) for key in new_keys]
keys = new_keys
vals = [(v1.bitwise_and(v2)) for v1, v2 in zip(vals, other_val)]
else:
vals = [v.bitwise_and(other) for v in vals]
items = dict(zip(keys, vals))

def pop(name, val):
return items.pop(name, None)

result = self._fast_apply(
pop,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result

def logical_and(
self,
other: TensorDictBase | torch.Tensor,
*,
default: str | CompatibleType | None = None,
) -> TensorDictBase: # noqa: D417
r"""Performs a logical AND operation between ``self`` and :attr:`other`.

.. math::
\text{{out}}_i = \text{{input}}_i \land \text{{other}}_i

Args:
other (TensorDictBase or torch.Tensor): the tensor or TensorDict to perform the logical AND with.

Keyword Args:
default (torch.Tensor or str, optional): the default value to use for exclusive entries.
If none is provided, the two tensordicts key list must match exactly.
If ``default="intersection"`` is passed, only the intersecting key sets will be considered
and other keys will be ignored.
In all other cases, ``default`` will be used for all missing entries on both sides of the
operation.
"""
keys, vals = self._items_list(True, True)
if _is_tensor_collection(type(other)):
new_keys, other_val = other._items_list(
True, True, sorting_keys=keys, default=default
)
if default is not None:
as_dict = dict(zip(keys, vals))
vals = [as_dict.get(key, default) for key in new_keys]
keys = new_keys
vals = [(v1.logical_and(v2)) for v1, v2 in zip(vals, other_val)]
else:
vals = [v.logical_and(other) for v in vals]
items = dict(zip(keys, vals))

def pop(name, val):
return items.pop(name, None)

result = self._fast_apply(
pop,
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=True,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result

def add(
self,
other: TensorDictBase | torch.Tensor,
Expand Down Expand Up @@ -9719,7 +9934,12 @@ def pop(name, val):
result.update(items)
return result

def add_(self, other: TensorDictBase | float, *, alpha: float | None = None):
def add_(
self,
other: TensorDictBase | torch.Tensor | float,
*,
alpha: float | None = None,
):
"""In-place version of :meth:`~.add`.

.. note::
Expand Down Expand Up @@ -9781,7 +10001,11 @@ def get(name, val):
propagate_lock=True,
)

def lerp_(self, end: TensorDictBase | float, weight: TensorDictBase | float):
def lerp_(
self,
end: TensorDictBase | torch.Tensor | float,
weight: TensorDictBase | torch.Tensor | float,
):
"""In-place version of :meth:`~.lerp`."""
if _is_tensor_collection(type(end)):
end_val = end._values_list(True, True)
Expand Down Expand Up @@ -9917,7 +10141,7 @@ def addcmul_(self, other1, other2, *, value: float | None = 1):

def sub(
self,
other: TensorDictBase | float,
other: TensorDictBase | torch.Tensor | float,
*,
alpha: float | None = None,
default: str | CompatibleType | None = None,
Expand Down Expand Up @@ -9976,7 +10200,9 @@ def pop(name, val):
result.update(items)
return result

def sub_(self, other: TensorDictBase | float, alpha: float | None = None):
def sub_(
self, other: TensorDictBase | torch.Tensor | float, alpha: float | None = None
):
"""In-place version of :meth:`~.sub`.

.. note::
Expand Down
18 changes: 15 additions & 3 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,19 +166,30 @@ def __subclasscheck__(self, subclass):
_FALLBACK_METHOD_FROM_TD = [
"__abs__",
"__add__",
"__and__",
"__bool__",
"__eq__",
"__ge__",
"__gt__",
"__iadd__",
"__imul__",
"__invert__",
"__ipow__",
"__isub__",
"__itruediv__",
"__mul__",
"__ne__",
"__neg__",
"__or__",
"__pow__",
"__radd__",
"__rand__",
"__rmul__",
"__ror__",
"__rpow__",
"__rsub__",
"__rtruediv__",
"__rxor__",
"__sub__",
"__truediv__",
"__xor__",
Expand Down Expand Up @@ -228,6 +239,7 @@ def __subclasscheck__(self, subclass):
"atan_",
"auto_batch_size_",
"auto_device_",
"bitwise_and",
"ceil",
"ceil_",
"chunk",
Expand Down Expand Up @@ -291,10 +303,8 @@ def __subclasscheck__(self, subclass):
"log2",
"log2_",
"log_",
"map",
"logical_and" "map",
"map_iter",
"to_namedtuple",
"to_pytree",
"masked_fill",
"masked_fill_",
"max",
Expand Down Expand Up @@ -356,6 +366,8 @@ def __subclasscheck__(self, subclass):
"tanh_",
"to",
"to_module",
"to_namedtuple",
"to_pytree",
"transpose",
"trunc",
"trunc_",
Expand Down
Loading
Loading