[Feature] Storing non-tensor data in tensordicts (#601)
vmoens authored Dec 20, 2023
1 parent 68e8f40 commit 7ac55a1
Showing 13 changed files with 975 additions and 107 deletions.
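
The headline feature is a new `NonTensorData` tensorclass that lets a tensordict carry arbitrary non-tensor payloads next to regular tensors. Below is a minimal sketch of the intended usage, inferred from the classes exported in this diff; the exact accessor behaviour (e.g. whether plain indexing unwraps the payload) is an assumption.

```python
import torch
from tensordict import NonTensorData, TensorDict

# A regular tensor entry next to a non-tensor payload (here, a plain string).
td = TensorDict(
    {
        "obs": torch.zeros(3, 4),
        "info": NonTensorData(data="collected on worker 0", batch_size=[3]),
    },
    batch_size=[3],
)

print(td["obs"].shape)   # torch.Size([3, 4])
entry = td.get("info")   # the NonTensorData tensorclass wrapping the string
print(entry.data)        # "collected on worker 0"
```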
3 changes: 2 additions & 1 deletion docs/source/reference/prototype.rst
@@ -253,4 +253,5 @@ Here is an example:
:toctree: generated/
:template: td_template.rst

@tensorclass
tensorclass
NonTensorData
3 changes: 2 additions & 1 deletion tensordict/__init__.py
@@ -16,7 +16,7 @@
from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership
from tensordict.persistent import PersistentTensorDict
from tensordict.tensorclass import tensorclass
from tensordict.tensorclass import NonTensorData, tensorclass
from tensordict.utils import assert_allclose_td, is_batchedtensor, is_tensorclass

try:
@@ -46,6 +46,7 @@
"PersistentTensorDict",
"tensorclass",
"dense_stack_tds",
"NonTensorData",
]

# from tensordict._pytree import *
30 changes: 19 additions & 11 deletions tensordict/_lazy.py
@@ -15,7 +15,7 @@
from copy import copy, deepcopy
from pathlib import Path
from textwrap import indent
from typing import Any, Callable, Iterator, Sequence
from typing import Any, Callable, Iterator, Sequence, Type

import numpy as np
import torch
@@ -183,11 +183,6 @@ def __init__(
"at least one tensordict must be provided to "
"StackedTensorDict to be instantiated"
)
if not isinstance(tensordicts[0], TensorDictBase):
raise TypeError(
f"Expected input to be TensorDictBase instance"
f" but got {type(tensordicts[0])} instead."
)
if stack_dim < 0:
raise RuntimeError(
f"stack_dim must be non negative, got stack_dim={stack_dim}"
@@ -196,7 +191,7 @@
device = tensordicts[0].device

for td in tensordicts[1:]:
if not isinstance(td, TensorDictBase):
if not is_tensor_collection(td):
raise TypeError(
"Expected all inputs to be TensorDictBase instances but got "
f"{type(td)} instead."
@@ -1057,10 +1052,16 @@ def _change_batch_size(self, new_size: torch.Size) -> None:
self._batch_size = new_size

def keys(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> _LazyStackedTensorDictKeysView:
keys = _LazyStackedTensorDictKeysView(
self, include_nested=include_nested, leaves_only=leaves_only
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
return keys

@@ -1970,6 +1971,7 @@ def _repr_exclusive_fields(self):
unlock = _renamed_inplace_method(unlock_)

__xor__ = TensorDict.__xor__
__or__ = TensorDict.__or__
_check_device = TensorDict._check_device
_check_is_shared = TensorDict._check_is_shared
_convert_to_tensordict = TensorDict._convert_to_tensordict
@@ -2195,9 +2197,14 @@ def __repr__(self) -> str:

# @cache # noqa: B019
def keys(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> _TensorDictKeysView:
return self._source.keys(include_nested=include_nested, leaves_only=leaves_only)
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
)

def select(
self, *keys: str, inplace: bool = False, strict: bool = True
@@ -2444,6 +2451,7 @@ def sorted_keys(self):
return self._source.sorted_keys

__xor__ = TensorDict.__xor__
__or__ = TensorDict.__or__
__eq__ = TensorDict.__eq__
__ne__ = TensorDict.__ne__
__setitem__ = TensorDict.__setitem__
106 changes: 86 additions & 20 deletions tensordict/_td.py
@@ -15,14 +15,15 @@
from numbers import Number
from pathlib import Path
from textwrap import indent
from typing import Any, Callable, Iterable, Iterator, List, Sequence
from typing import Any, Callable, Iterable, Iterator, List, Sequence, Type
from warnings import warn

import numpy as np
import torch
from functorch import dim as ftdim
from tensordict.base import (
_ACCEPTED_CLASSES,
_default_is_leaf,
_is_tensor_collection,
_register_tensor_class,
BEST_ATTEMPT_INPLACE,
@@ -381,18 +382,42 @@ def __xor__(self, other: object) -> T | bool:
)
return True

def __or__(self, other: object) -> T | bool:
if _is_tensorclass(other):
return other | self
if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
raise KeyError(
f"keys in {self} and {other} mismatch, got {keys1} and {keys2}"
)
d = {}
for key, item1 in self.items():
d[key] = item1 | other.get(key)
return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
if isinstance(other, (numbers.Number, Tensor)):
return TensorDict(
{key: value | other for key, value in self.items()},
self.batch_size,
device=self.device,
)
return False

def __eq__(self, other: object) -> T | bool:
if is_tensorclass(other):
return other == self
if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__):
if isinstance(other, (dict,)):
other = self.empty(recurse=True).update(other)
if _is_tensor_collection(other.__class__):
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
d = {}
for key, item1 in self.items():
d[key] = item1 == other.get(key)
return TensorDict(batch_size=self.batch_size, source=d, device=self.device)
return TensorDict(source=d, batch_size=self.batch_size, device=self.device)
if isinstance(other, (numbers.Number, Tensor)):
return TensorDict(
{key: value == other for key, value in self.items()},
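
This hunk adds an element-wise `__or__` to `TensorDict`, mirroring the existing `__xor__`/`__eq__` logic: when the other operand is a tensordict or dict, both sides must expose the same keys and the operator is applied entry by entry, while a number or tensor on the right-hand side is broadcast to every value. A small sketch of the resulting behaviour (key names are illustrative):

```python
import torch
from tensordict import TensorDict

a = TensorDict({"done": torch.tensor([True, False])}, batch_size=[2])
b = TensorDict({"done": torch.tensor([False, False])}, batch_size=[2])

# Key-by-key bitwise/logical OR; mismatched key sets raise a KeyError.
print((a | b)["done"])     # tensor([ True, False])

# Numbers and tensors are broadcast against every entry.
print((a | True)["done"])  # tensor([True, True])
```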
@@ -1737,21 +1762,30 @@ def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) -
return out

def keys(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> _TensorDictKeysView:
if not include_nested and not leaves_only:
return self._tensordict.keys()
else:
return self._nested_keys(
include_nested=include_nested, leaves_only=leaves_only
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
)

# @cache # noqa: B019
def _nested_keys(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> _TensorDictKeysView:
return _TensorDictKeysView(
self, include_nested=include_nested, leaves_only=leaves_only
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)

def __getstate__(self):
@@ -1780,21 +1814,31 @@ def __setstate__(self, state):

# some custom methods for efficiency
def items(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.items()
else:
return super().items(include_nested=include_nested, leaves_only=leaves_only)
return super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
)

def values(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.values()
else:
return super().values(
include_nested=include_nested, leaves_only=leaves_only
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)


@@ -1947,8 +1991,13 @@ def _set_str(
parent.batch_size, value, self.batch_dims, self.device
)
for _key, _tensor in value.items():
value_expand[_key] = _expand_to_match_shape(
parent.batch_size, _tensor, self.batch_dims, self.device
value_expand._set_str(
_key,
_expand_to_match_shape(
parent.batch_size, _tensor, self.batch_dims, self.device
),
inplace=inplace,
validated=validated,
)
else:
value_expand = torch.zeros(
@@ -1963,7 +2012,6 @@
value_expand.share_memory_()
elif self.is_memmap():
value_expand = MemoryMappedTensor.from_tensor(value_expand)

parent._set_str(key, value_expand, inplace=False, validated=validated)

parent._set_at_str(key, value, self.idx, validated=validated)
@@ -2021,9 +2069,14 @@ def _set_at_tuple(self, key, value, idx, *, validated):

# @cache # noqa: B019
def keys(
self, include_nested: bool = False, leaves_only: bool = False
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
) -> _TensorDictKeysView:
return self._source.keys(include_nested=include_nested, leaves_only=leaves_only)
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
)

def entry_class(self, key: NestedKey) -> type:
source_type = type(self._source.get(key))
@@ -2059,6 +2112,14 @@ def get(
) -> CompatibleType:
return self._source.get_at(key, self.idx, default=default)

def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
out = super()._get_non_tensor(key, default=default)
from tensordict.tensorclass import NonTensorData

if isinstance(out, _SubTensorDict) and isinstance(out._source, NonTensorData):
return out._source.data
return out

def _get_str(self, key, default):
if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx)
@@ -2366,6 +2427,7 @@ def save_metadata(prefix=prefix, self=self):
result = self
return result

@classmethod
def _load_memmap(cls, prefix: Path, metadata: dict):
index = metadata["index"]
return _SubTensorDict(
@@ -2430,6 +2492,7 @@ def _create_nested_str(self, key):
__ne__ = TensorDict.__ne__
__setitem__ = TensorDict.__setitem__
__xor__ = TensorDict.__xor__
__or__ = TensorDict.__or__
_check_device = TensorDict._check_device
_check_is_shared = TensorDict._check_is_shared
all = TensorDict.all
@@ -2495,10 +2558,14 @@ def __init__(
tensordict: T,
include_nested: bool,
leaves_only: bool,
is_leaf: Callable[[Type], bool] = None,
) -> None:
self.tensordict = tensordict
self.include_nested = include_nested
self.leaves_only = leaves_only
if is_leaf is None:
is_leaf = _default_is_leaf
self.is_leaf = is_leaf

def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
if not self.include_nested:
@@ -2522,12 +2589,11 @@ def _iter_helper(
for key, value in self._items(tensordict):
full_key = self._combine_keys(prefix, key)
cls = value.__class__
if self.include_nested and (
_is_tensor_collection(cls) or issubclass(cls, KeyedJaggedTensor)
):
is_leaf = self.is_leaf(cls)
if self.include_nested and not is_leaf:
subkeys = tuple(self._iter_helper(value, prefix=full_key))
yield from subkeys
if not self.leaves_only or not _is_tensor_collection(cls):
if not self.leaves_only or is_leaf:
yield full_key

def _combine_keys(self, prefix: tuple | None, key: str) -> tuple:
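
Across `_lazy.py` and `_td.py`, `keys()`, `values()` and `items()` gain an `is_leaf` callable that receives the value class and decides whether iteration should recurse into it. The sketch below shows the intended effect, based on the `_TensorDictKeysView` changes above; the default behaviour and key ordering in the comments are assumptions.

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3), "nested": {"b": torch.ones(3)}},
    batch_size=[3],
)

# Default: nested tensordicts are recursed into and are not leaves themselves.
print(list(td.keys(include_nested=True, leaves_only=True)))
# e.g. ['a', ('nested', 'b')]

# Treating every class as a leaf stops recursion and keeps the top-level keys.
print(list(td.keys(include_nested=True, leaves_only=True, is_leaf=lambda cls: True)))
# e.g. ['a', 'nested']
```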
6 changes: 6 additions & 0 deletions tensordict/_torch_func.py
@@ -337,6 +337,12 @@ def _stack(
) -> T:
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")

from tensordict.tensorclass import NonTensorData

if all(isinstance(tensordict, NonTensorData) for tensordict in list_of_tensordicts):
return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)

batch_size = list_of_tensordicts[0].batch_size
if dim < 0:
dim = len(batch_size) + dim + 1
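
With the check above, `torch.stack` short-circuits to `NonTensorData._stack_non_tensor` whenever every element of the list is a `NonTensorData`, instead of building a generic lazy stack. A hedged sketch follows; the result when the payloads differ is not shown in this diff, and the collapsing of identical payloads into a larger batch is an assumption.

```python
import torch
from tensordict import NonTensorData

a = NonTensorData(data="metadata", batch_size=[])
b = NonTensorData(data="metadata", batch_size=[])

# All inputs are NonTensorData, so the dedicated stacking path is taken.
stacked = torch.stack([a, b], dim=0)
print(stacked.batch_size)  # torch.Size([2]) -- assumption: identical payloads collapse
print(stacked.data)        # "metadata"
```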
The remaining 8 changed files are not shown.
