From 7ac55a10357b440b612d1d3093c51896d97fff8e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 20 Dec 2023 16:46:55 +0000 Subject: [PATCH] [Feature] Storing non-tensor data in tensordicts (#601) --- docs/source/reference/prototype.rst | 3 +- tensordict/__init__.py | 3 +- tensordict/_lazy.py | 30 ++- tensordict/_td.py | 106 ++++++-- tensordict/_torch_func.py | 6 + tensordict/base.py | 280 +++++++++++++++++++-- tensordict/nn/params.py | 25 +- tensordict/persistent.py | 35 ++- tensordict/tensorclass.py | 363 ++++++++++++++++++++++++++-- tensordict/utils.py | 27 ++- test/_utils_internal.py | 27 ++- test/test_tensorclass.py | 2 +- test/test_tensordict.py | 175 +++++++++++++- 13 files changed, 975 insertions(+), 107 deletions(-) diff --git a/docs/source/reference/prototype.rst b/docs/source/reference/prototype.rst index 11c3c5961..20841eab5 100644 --- a/docs/source/reference/prototype.rst +++ b/docs/source/reference/prototype.rst @@ -253,4 +253,5 @@ Here is an example: :toctree: generated/ :template: td_template.rst - @tensorclass + tensorclass + NonTensorData diff --git a/tensordict/__init__.py b/tensordict/__init__.py index b5aaf60d4..2238fb0a3 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -16,7 +16,7 @@ from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import is_memmap, MemmapTensor, set_transfer_ownership from tensordict.persistent import PersistentTensorDict -from tensordict.tensorclass import tensorclass +from tensordict.tensorclass import NonTensorData, tensorclass from tensordict.utils import assert_allclose_td, is_batchedtensor, is_tensorclass try: @@ -46,6 +46,7 @@ "PersistentTensorDict", "tensorclass", "dense_stack_tds", + "NonTensorData", ] # from tensordict._pytree import * diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index 065400e6f..a44aa6fff 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -15,7 +15,7 @@ from copy import copy, deepcopy from pathlib import Path from textwrap import indent -from typing import Any, Callable, Iterator, Sequence +from typing import Any, Callable, Iterator, Sequence, Type import numpy as np import torch @@ -183,11 +183,6 @@ def __init__( "at least one tensordict must be provided to " "StackedTensorDict to be instantiated" ) - if not isinstance(tensordicts[0], TensorDictBase): - raise TypeError( - f"Expected input to be TensorDictBase instance" - f" but got {type(tensordicts[0])} instead." - ) if stack_dim < 0: raise RuntimeError( f"stack_dim must be non negative, got stack_dim={stack_dim}" @@ -196,7 +191,7 @@ def __init__( device = tensordicts[0].device for td in tensordicts[1:]: - if not isinstance(td, TensorDictBase): + if not is_tensor_collection(td): raise TypeError( "Expected all inputs to be TensorDictBase instances but got " f"{type(td)} instead." 
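For orientation, a minimal sketch of the new public surface this commit adds (assuming the patched package is importable; names as exported above):

>>> from tensordict import NonTensorData, TensorDict
>>> nt = NonTensorData("a string!", batch_size=[])
>>> nt.data
'a string!'
>>> td = TensorDict({"nt": nt}, batch_size=[])
>>> td.get_non_tensor("nt")
'a string!'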
@@ -1057,10 +1052,16 @@ def _change_batch_size(self, new_size: torch.Size) -> None: self._batch_size = new_size def keys( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> _LazyStackedTensorDictKeysView: keys = _LazyStackedTensorDictKeysView( - self, include_nested=include_nested, leaves_only=leaves_only + self, + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, ) return keys @@ -1970,6 +1971,7 @@ def _repr_exclusive_fields(self): unlock = _renamed_inplace_method(unlock_) __xor__ = TensorDict.__xor__ + __or__ = TensorDict.__or__ _check_device = TensorDict._check_device _check_is_shared = TensorDict._check_is_shared _convert_to_tensordict = TensorDict._convert_to_tensordict @@ -2195,9 +2197,14 @@ def __repr__(self) -> str: # @cache # noqa: B019 def keys( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> _TensorDictKeysView: - return self._source.keys(include_nested=include_nested, leaves_only=leaves_only) + return self._source.keys( + include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + ) def select( self, *keys: str, inplace: bool = False, strict: bool = True @@ -2444,6 +2451,7 @@ def sorted_keys(self): return self._source.sorted_keys __xor__ = TensorDict.__xor__ + __or__ = TensorDict.__or__ __eq__ = TensorDict.__eq__ __ne__ = TensorDict.__ne__ __setitem__ = TensorDict.__setitem__ diff --git a/tensordict/_td.py b/tensordict/_td.py index 784467ba4..883a02c3d 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -15,7 +15,7 @@ from numbers import Number from pathlib import Path from textwrap import indent -from typing import Any, Callable, Iterable, Iterator, List, Sequence +from typing import Any, Callable, Iterable, Iterator, List, Sequence, Type from warnings import warn import numpy as np @@ -23,6 +23,7 @@ from functorch import dim as ftdim from tensordict.base import ( _ACCEPTED_CLASSES, + _default_is_leaf, _is_tensor_collection, _register_tensor_class, BEST_ATTEMPT_INPLACE, @@ -381,10 +382,34 @@ def __xor__(self, other: object) -> T | bool: ) return True + def __or__(self, other: object) -> T | bool: + if _is_tensorclass(other): + return other | self + if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__): + keys1 = set(self.keys()) + keys2 = set(other.keys()) + if len(keys1.difference(keys2)) or len(keys1) != len(keys2): + raise KeyError( + f"keys in {self} and {other} mismatch, got {keys1} and {keys2}" + ) + d = {} + for key, item1 in self.items(): + d[key] = item1 | other.get(key) + return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + if isinstance(other, (numbers.Number, Tensor)): + return TensorDict( + {key: value | other for key, value in self.items()}, + self.batch_size, + device=self.device, + ) + return False + def __eq__(self, other: object) -> T | bool: if is_tensorclass(other): return other == self - if isinstance(other, (dict,)) or _is_tensor_collection(other.__class__): + if isinstance(other, (dict,)): + other = self.empty(recurse=True).update(other) + if _is_tensor_collection(other.__class__): keys1 = set(self.keys()) keys2 = set(other.keys()) if len(keys1.difference(keys2)) or len(keys1) != len(keys2): @@ -392,7 +417,7 @@ def __eq__(self, other: object) -> T | bool: d = {} for key, item1 in self.items(): 
d[key] = item1 == other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + return TensorDict(source=d, batch_size=self.batch_size, device=self.device) if isinstance(other, (numbers.Number, Tensor)): return TensorDict( {key: value == other for key, value in self.items()}, @@ -1737,21 +1762,30 @@ def select(self, *keys: NestedKey, inplace: bool = False, strict: bool = True) - return out def keys( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> _TensorDictKeysView: if not include_nested and not leaves_only: return self._tensordict.keys() else: return self._nested_keys( - include_nested=include_nested, leaves_only=leaves_only + include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf ) # @cache # noqa: B019 def _nested_keys( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> _TensorDictKeysView: return _TensorDictKeysView( - self, include_nested=include_nested, leaves_only=leaves_only + self, + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, ) def __getstate__(self): @@ -1780,21 +1814,31 @@ def __setstate__(self, state): # some custom methods for efficiency def items( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> Iterator[tuple[str, CompatibleType]]: if not include_nested and not leaves_only: return self._tensordict.items() else: - return super().items(include_nested=include_nested, leaves_only=leaves_only) + return super().items( + include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf + ) def values( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> Iterator[tuple[str, CompatibleType]]: if not include_nested and not leaves_only: return self._tensordict.values() else: return super().values( - include_nested=include_nested, leaves_only=leaves_only + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, ) @@ -1947,8 +1991,13 @@ def _set_str( parent.batch_size, value, self.batch_dims, self.device ) for _key, _tensor in value.items(): - value_expand[_key] = _expand_to_match_shape( - parent.batch_size, _tensor, self.batch_dims, self.device + value_expand._set_str( + _key, + _expand_to_match_shape( + parent.batch_size, _tensor, self.batch_dims, self.device + ), + inplace=inplace, + validated=validated, ) else: value_expand = torch.zeros( @@ -1963,7 +2012,6 @@ def _set_str( value_expand.share_memory_() elif self.is_memmap(): value_expand = MemoryMappedTensor.from_tensor(value_expand) - parent._set_str(key, value_expand, inplace=False, validated=validated) parent._set_at_str(key, value, self.idx, validated=validated) @@ -2021,9 +2069,14 @@ def _set_at_tuple(self, key, value, idx, *, validated): # @cache # noqa: B019 def keys( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> _TensorDictKeysView: - return self._source.keys(include_nested=include_nested, leaves_only=leaves_only) + return self._source.keys( 
+ include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
+ )

 def entry_class(self, key: NestedKey) -> type:
 source_type = type(self._source.get(key))
@@ -2059,6 +2112,14 @@ def get(
 ) -> CompatibleType:
 return self._source.get_at(key, self.idx, default=default)

+ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
+ out = super()._get_non_tensor(key, default=default)
+ from tensordict.tensorclass import NonTensorData
+
+ if isinstance(out, _SubTensorDict) and isinstance(out._source, NonTensorData):
+ return out._source.data
+ return out
+
 def _get_str(self, key, default):
 if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
 return _SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx)
@@ -2366,6 +2427,7 @@ def save_metadata(prefix=prefix, self=self):
 result = self
 return result

+ @classmethod
 def _load_memmap(cls, prefix: Path, metadata: dict):
 index = metadata["index"]
 return _SubTensorDict(
@@ -2430,6 +2492,7 @@ def _create_nested_str(self, key):
 __ne__ = TensorDict.__ne__
 __setitem__ = TensorDict.__setitem__
 __xor__ = TensorDict.__xor__
+ __or__ = TensorDict.__or__
 _check_device = TensorDict._check_device
 _check_is_shared = TensorDict._check_is_shared
 all = TensorDict.all
@@ -2495,10 +2558,14 @@ def __init__(
 tensordict: T,
 include_nested: bool,
 leaves_only: bool,
+ is_leaf: Callable[[Type], bool] = None,
 ) -> None:
 self.tensordict = tensordict
 self.include_nested = include_nested
 self.leaves_only = leaves_only
+ if is_leaf is None:
+ is_leaf = _default_is_leaf
+ self.is_leaf = is_leaf

 def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
 if not self.include_nested:
@@ -2522,12 +2589,11 @@ def _iter_helper(
 for key, value in self._items(tensordict):
 full_key = self._combine_keys(prefix, key)
 cls = value.__class__
- if self.include_nested and (
- _is_tensor_collection(cls) or issubclass(cls, KeyedJaggedTensor)
- ):
+ is_leaf = self.is_leaf(cls)
+ if self.include_nested and not is_leaf:
 subkeys = tuple(self._iter_helper(value, prefix=full_key))
 yield from subkeys
- if not self.leaves_only or not _is_tensor_collection(cls):
+ if not self.leaves_only or is_leaf:
 yield full_key

 def _combine_keys(self, prefix: tuple | None, key: str) -> tuple:
diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index 2b3053959..c6f53dabb 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -337,6 +337,12 @@ def _stack(
 ) -> T:
 if not list_of_tensordicts:
 raise RuntimeError("list_of_tensordicts cannot be empty")
+
+ from tensordict.tensorclass import NonTensorData
+
+ if all(isinstance(tensordict, NonTensorData) for tensordict in list_of_tensordicts):
+ return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim)
+
 batch_size = list_of_tensordicts[0].batch_size
 if dim < 0:
 dim = len(batch_size) + dim + 1
diff --git a/tensordict/base.py b/tensordict/base.py
index 2fb5dc11f..550bdc881 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -29,6 +29,7 @@
 OrderedDict,
 overload,
 Sequence,
+ Type,
 TypeVar,
 Union,
 )
@@ -53,11 +54,13 @@
 IndexType,
 infer_size_impl,
 int_generator,
+ KeyedJaggedTensor,
 lazy_legacy,
 lock_blocked,
 NestedKey,
 prod,
 TensorDictFuture,
+ unravel_key,
 unravel_key_list,
 )
 from torch import distributed as dist, multiprocessing as mp, nn, Tensor
@@ -138,6 +141,22 @@ def __xor__(self, other):
 """
 ...

+ @abc.abstractmethod
+ def __or__(self, other):
+ """OR operation over two tensordicts, for every key.
+
+ The two tensordicts must have the same key set.
+
+ Args:
+ other (TensorDictBase, dict, or float): the value to compare against.
+
+ Returns:
+ a new TensorDict instance whose tensors are boolean
+ tensors of the same shape as the original tensors.
+
+ """
+ ...
+
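An illustrative sketch of the `__or__` semantics specified above, assuming two tensordicts with identical key sets:

>>> import torch
>>> from tensordict import TensorDict
>>> a = TensorDict({"mask": torch.tensor([True, False])}, batch_size=[2])
>>> b = TensorDict({"mask": torch.tensor([False, False])}, batch_size=[2])
>>> (a | b)["mask"]  # elementwise OR; mismatched key sets raise a KeyError
tensor([ True, False])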
 @abc.abstractmethod
 def __eq__(self, other: object) -> T:
 """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set.
@@ -1072,8 +1091,8 @@ def _legacy_permute(
 source=self,
 custom_op="permute",
 inv_op="permute",
- custom_op_kwargs={"dims": dims_list},
- inv_op_kwargs={"dims": dims_list},
+ custom_op_kwargs={"dims": list(map(int, dims_list))},
+ inv_op_kwargs={"dims": list(map(int, dims_list))},
 )

 # Cache functionality
@@ -1842,6 +1861,120 @@ def _set_str(self, key, value, *, inplace, validated):
 def _set_tuple(self, key, value, *, inplace, validated):
 ...

+ @lock_blocked
+ def set_non_tensor(self, key: NestedKey, value: Any):
+ """Registers a non-tensor value in the tensordict using :class:`tensordict.tensorclass.NonTensorData`.
+
+ The value can be retrieved using :meth:`TensorDictBase.get_non_tensor`
+ or directly using `get`, which will return the :class:`tensordict.tensorclass.NonTensorData`
+ object.
+
+ Returns: self
+
+ Examples:
+ >>> data = TensorDict({}, batch_size=[])
+ >>> data.set_non_tensor(("nested", "the string"), "a string!")
+ >>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
+ >>> # regular `get` works but returns a NonTensorData object
+ >>> data.get(("nested", "the string"))
+ NonTensorData(
+ data='a string!',
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
+
+ """
+ key = unravel_key(key)
+ return self._set_non_tensor(key, value)
+
+ def _set_non_tensor(self, key: NestedKey, value: Any):
+ if isinstance(key, tuple):
+ if len(key) == 1:
+ return self._set_non_tensor(key[0], value)
+ sub_td = self._get_str(key[0], None)
+ if sub_td is None:
+ sub_td = self._create_nested_str(key[0])
+ sub_td._set_non_tensor(key[1:], value)
+ return self
+ from tensordict.tensorclass import NonTensorData
+
+ self._set_str(
+ key,
+ NonTensorData(
+ value,
+ batch_size=self.batch_size,
+ device=self.device,
+ names=self.names if self._has_names() else None,
+ ),
+ validated=True,
+ inplace=False,
+ )
+ return self
+
+ def get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
+ """Gets a non-tensor value, if it exists, or `default` if the non-tensor value is not found.
+
+ This method is robust to tensor/TensorDict values, meaning that if the
+ value gathered is a regular tensor it will be returned too (although
+ this method comes with some overhead and should not be used outside its
+ natural scope).
+
+ See :meth:`~tensordict.TensorDictBase.set_non_tensor` for more information
+ on how to set non-tensor values in a tensordict.
+
+ Args:
+ key (NestedKey): the location of the NonTensorData object.
+ default (Any, optional): the value to be returned if the key cannot
+ be found.
+
+ Returns: the content of the :class:`tensordict.tensorclass.NonTensorData`,
+ or the entry corresponding to the ``key`` if it isn't a
+ :class:`tensordict.tensorclass.NonTensorData` (or ``default`` if the
+ entry cannot be found).
+
+ Examples:
+ >>> data = TensorDict({}, batch_size=[])
+ >>> data.set_non_tensor(("nested", "the string"), "a string!")
+ >>> assert data.get_non_tensor(("nested", "the string")) == "a string!"
+ >>> # regular `get` works but returns a NonTensorData object
+ >>> data.get(("nested", "the string"))
+ NonTensorData(
+ data='a string!',
+ batch_size=torch.Size([]),
+ device=None,
+ is_shared=False)
+
+ """
+ key = unravel_key(key)
+ return self._get_non_tensor(key, default=default)
+
+ def _get_non_tensor(self, key: NestedKey, default=NO_DEFAULT):
+ if isinstance(key, tuple):
+ if len(key) == 1:
+ return self._get_non_tensor(key[0], default=default)
+ subtd = self._get_str(key[0], default=default)
+ if subtd is default:
+ return subtd
+ return subtd._get_non_tensor(key[1:], default=default)
+ value = self._get_str(key, default=default)
+ from tensordict.tensorclass import NonTensorData
+
+ if isinstance(value, NonTensorData):
+ return value.data
+ return value
+
+ def filter_non_tensor_data(self) -> T:
+ """Filters out all non-tensor-data."""
+ from tensordict.tensorclass import NonTensorData
+
+ def _filter(x):
+ if not isinstance(x, NonTensorData):
+ if is_tensor_collection(x):
+ return x.filter_non_tensor_data()
+ return x
+
+ return self._apply_nest(_filter, call_on_nested=True)
+
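A minimal usage sketch of the setter/getter pair defined above (an in-memory `TensorDict` is assumed; the key `("meta", "note")` is hypothetical):

>>> from tensordict import NonTensorData, TensorDict
>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor(("meta", "note"), "hello")
>>> data.get_non_tensor(("meta", "note"))
'hello'
>>> assert isinstance(data.get(("meta", "note")), NonTensorData)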
 def _convert_inplace(self, inplace, key):
 if inplace is not False:
 has_key = key in self.keys()
@@ -2378,18 +2511,33 @@ def setdefault(
 return self.get(key)

 def items(
- self, include_nested: bool = False, leaves_only: bool = False
+ self, include_nested: bool = False, leaves_only: bool = False, is_leaf=None
 ) -> Iterator[tuple[str, CompatibleType]]:
- """Returns a generator of key-value pairs for the tensordict."""
+ """Returns a generator of key-value pairs for the tensordict.
+
+ Args:
+ include_nested (bool, optional): if ``True``, nested values will be returned.
+ Defaults to ``False``.
+ leaves_only (bool, optional): if ``True``, only leaves will be
+ returned. Defaults to ``False``.
+ is_leaf: an optional callable that indicates if a class is to be considered a
+ leaf or not.
+
+ """
+ if is_leaf is None:
+ is_leaf = _default_is_leaf
+
 # check the conditions once only
 if include_nested and leaves_only:
 for k in self.keys():
 val = self._get_str(k, NO_DEFAULT)
- if _is_tensor_collection(val.__class__):
+ if not is_leaf(val.__class__):
 yield from (
 (_unravel_key_to_tuple((k, _key)), _val)
 for _key, _val in val.items(
- include_nested=include_nested, leaves_only=leaves_only
+ include_nested=include_nested,
+ leaves_only=leaves_only,
+ is_leaf=is_leaf,
 )
 )
 else:
@@ -2398,33 +2546,52 @@ def items(
 for k in self.keys():
 val = self._get_str(k, NO_DEFAULT)
 yield k, val
- if _is_tensor_collection(val.__class__):
+ if not is_leaf(val.__class__):
 yield from (
 (_unravel_key_to_tuple((k, _key)), _val)
 for _key, _val in val.items(
- include_nested=include_nested, leaves_only=leaves_only
+ include_nested=include_nested,
+ leaves_only=leaves_only,
+ is_leaf=is_leaf,
 )
 )
 elif leaves_only:
 for k in self.keys():
 val = self._get_str(k, NO_DEFAULT)
- if not _is_tensor_collection(val.__class__):
+ if is_leaf(val.__class__):
 yield k, val
 else:
 for k in self.keys():
 yield k, self._get_str(k, NO_DEFAULT)

 def values(
- self, include_nested: bool = False, leaves_only: bool = False
+ self,
+ include_nested: bool = False,
+ leaves_only: bool = False,
+ is_leaf=None,
 ) -> Iterator[CompatibleType]:
- """Returns a generator representing the values for the tensordict."""
+ """Returns a generator representing the values for the tensordict.
+
+ Args:
+ include_nested (bool, optional): if ``True``, nested values will be returned.
+ Defaults to ``False``.
+ leaves_only (bool, optional): if ``True``, only leaves will be
+ returned. Defaults to ``False``.
+ is_leaf: an optional callable that indicates if a class is to be considered a
+ leaf or not.
+
+ """
+ if is_leaf is None:
+ is_leaf = _default_is_leaf
 # check the conditions once only
 if include_nested and leaves_only:
 for k in self.keys():
 val = self._get_str(k, NO_DEFAULT)
- if _is_tensor_collection(val.__class__):
+ if not is_leaf(val.__class__):
 yield from val.values(
- include_nested=include_nested, leaves_only=leaves_only
+ include_nested=include_nested,
+ leaves_only=leaves_only,
+ is_leaf=is_leaf,
 )
 else:
 yield val
@@ -2432,22 +2599,48 @@ def values(
 for k in self.keys():
 val = self._get_str(k, NO_DEFAULT)
 yield val
- if _is_tensor_collection(val.__class__):
+ if not is_leaf(val.__class__):
 yield from val.values(
- include_nested=include_nested, leaves_only=leaves_only
+ include_nested=include_nested,
+ leaves_only=leaves_only,
+ is_leaf=is_leaf,
 )
 elif leaves_only:
 for k in self.keys():
 val = self._get_str(k, NO_DEFAULT)
- if not _is_tensor_collection(val.__class__):
+ if is_leaf(val.__class__):
 yield val
 else:
 for k in self.keys():
 yield self._get_str(k, NO_DEFAULT)
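A short sketch of how the new `is_leaf` argument changes iteration; the lambda below, which declares every class a leaf, is a deliberately extreme assumption for illustration:

>>> import torch
>>> from tensordict import TensorDict
>>> data = TensorDict({"a": {"b": torch.zeros(())}}, batch_size=[])
>>> [key for key, _ in data.items(include_nested=True, leaves_only=True)]
[('a', 'b')]
>>> [key for key, _ in data.items(include_nested=True, leaves_only=True, is_leaf=lambda cls: True)]
['a']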

 @abc.abstractmethod
- def keys(self, include_nested: bool = False, leaves_only: bool = False):
- """Returns a generator of tensordict keys."""
+ def keys(
+ self,
+ include_nested: bool = False,
+ leaves_only: bool = False,
+ is_leaf: Callable[[Type], bool] = None,
+ ):
+ """Returns a generator of tensordict keys.
+
+ Args:
+ include_nested (bool, optional): if ``True``, nested values will be returned.
+ Defaults to ``False``.
+ leaves_only (bool, optional): if ``True``, only leaves will be
+ returned. Defaults to ``False``.
+ is_leaf: an optional callable that indicates if a class is to be considered a
+ leaf or not.
+
+ Examples:
+ >>> from tensordict import TensorDict
+ >>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[])
+ >>> list(data.keys())
+ ['0', '1']
+ >>> list(data.keys(leaves_only=True))
+ ['0']
+ >>> list(data.keys(include_nested=True, leaves_only=True))
+ ['0', ('1', '2')]
+ """
 ...

 def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType:
@@ -3511,7 +3704,22 @@ def _convert_to_tensor(self, array: np.ndarray) -> Tensor:
 array = array.item()
 if isinstance(array, list):
 array = np.asarray(array)
- return torch.as_tensor(array, device=self.device)
+ if not isinstance(array, np.ndarray) and hasattr(array, "numpy"):
+ # tf.Tensor with no shape can't be converted otherwise
+ array = array.numpy()
+ try:
+ return torch.as_tensor(array, device=self.device)
+ except Exception:
+ if hasattr(array, "shape"):
+ return torch.full(array.shape, float("NaN"))
+ from tensordict.tensorclass import NonTensorData
+
+ return NonTensorData(
+ array,
+ batch_size=self.batch_size,
+ device=self.device,
+ names=self.names if self._has_names() else None,
+ )

 @abc.abstractmethod
 def _convert_to_tensordict(self, dict_value: dict[str, Any]) -> T:
@@ -3923,7 +4131,12 @@ def contiguous(self) -> T:
 ...

 @cache # noqa: B019
- def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T:
+ def flatten_keys(
+ self,
+ separator: str = ".",
+ inplace: bool = False,
+ is_leaf: Callable[[Type], bool] | None = None,
+ ) -> T:
 """Converts a nested tensordict into a flat one, recursively.

 The TensorDict type will be lost and the result will be a simple TensorDict instance.
@@ -3933,6 +4146,8 @@ def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T: inplace (bool, optional): if ``True``, the resulting tensordict will have the same identity as the one where the call has been made. Defaults to ``False``. + is_leaf (callable, optional): a callable over a class type returning + a bool indicating if this class has to be considered as a leaf. Examples: >>> data = TensorDict({"a": 1, ("b", "c"): 2, ("e", "f", "g"): 3}, batch_size=[]) @@ -3989,7 +4204,11 @@ def flatten_keys(self, separator: str = ".", inplace: bool = False) -> T: is_shared=False) >>> model.load_state_dict(dict(model_state_dict.flatten_keys("."))) """ - all_leaves = list(self.keys(include_nested=True, leaves_only=True)) + if is_leaf is None: + is_leaf = _is_leaf_nontensor + all_leaves = list( + self.keys(include_nested=True, leaves_only=True, is_leaf=is_leaf) + ) all_leaves_flat = [ separator.join(key) if isinstance(key, tuple) else key for key in all_leaves ] @@ -4361,11 +4580,10 @@ def detach(self) -> T: return self._fast_apply(lambda x: x.detach()) -_ACCEPTED_CLASSES = [ +_ACCEPTED_CLASSES = { Tensor, TensorDictBase, -] -_ACCEPTED_CLASSES = set(_ACCEPTED_CLASSES) +} def _register_tensor_class(cls): @@ -4408,3 +4626,17 @@ def is_tensor_collection(datatype: type | Any) -> bool: if not isinstance(datatype, type): datatype = type(datatype) return _is_tensor_collection(datatype) + + +def _default_is_leaf(cls: Type) -> bool: + return not _is_tensor_collection(cls) + + +def _is_leaf_nontensor(cls: Type) -> bool: + from tensordict.tensorclass import NonTensorData + + if issubclass(cls, KeyedJaggedTensor): + return False + if _is_tensor_collection(cls): + return issubclass(cls, NonTensorData) + return issubclass(cls, torch.Tensor) diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 96bdf7ede..5840e875a 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -11,7 +11,7 @@ import weakref from copy import copy from functools import wraps -from typing import Any, Callable, Iterator, OrderedDict, Sequence +from typing import Any, Callable, Iterator, OrderedDict, Sequence, Type import torch from functorch import dim as ftdim @@ -21,6 +21,7 @@ from tensordict._torch_func import TD_HANDLED_FUNCTIONS from tensordict.base import ( + _default_is_leaf, _is_tensor_collection, _register_tensor_class, CompatibleType, @@ -840,6 +841,10 @@ def _legacy_unsqueeze(self, dim: int) -> TensorDictBase: def __xor__(self, other): ... + @_fallback + def __or__(self, other): + ... 
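A sketch of the `_is_leaf_nontensor` default in action: `flatten_keys` now treats `NonTensorData` entries as leaves, so they survive flattening (hypothetical keys, in-memory TensorDict assumed):

>>> import torch
>>> from tensordict import TensorDict
>>> data = TensorDict({"a": {"b": torch.zeros(())}}, batch_size=[])
>>> data.set_non_tensor(("a", "note"), "hi")
>>> flat = data.flatten_keys(".")
>>> flat.get_non_tensor("a.note")
'hi'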
+ _check_device = TensorDict._check_device _check_is_shared = TensorDict._check_is_shared @@ -897,10 +902,15 @@ def __repr__(self): return f"TensorDictParams(params={self._param_td})" def values( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> Iterator[CompatibleType]: + if is_leaf is None: + is_leaf = _default_is_leaf for v in self._param_td.values(include_nested, leaves_only): - if _is_tensor_collection(type(v)): + if not is_leaf(type(v)): yield v continue yield self._apply_get_post_hook(v) @@ -961,10 +971,15 @@ def _load_from_state_dict( self.data.load_state_dict(data) def items( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> Iterator[CompatibleType]: + if is_leaf is None: + is_leaf = _default_is_leaf for k, v in self._param_td.items(include_nested, leaves_only): - if _is_tensor_collection(type(v)): + if not is_leaf(type(v)): yield k, v continue yield k, self._apply_get_post_hook(v) diff --git a/tensordict/persistent.py b/tensordict/persistent.py index 2d7ac4ddd..793debdd4 100644 --- a/tensordict/persistent.py +++ b/tensordict/persistent.py @@ -9,7 +9,7 @@ import tempfile import warnings from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, Type from tensordict._td import _unravel_key_to_tuple from torch import multiprocessing as mp @@ -29,7 +29,7 @@ import numpy as np import torch from tensordict._td import _TensorDictKeysView, CompatibleType, NO_DEFAULT, TensorDict -from tensordict.base import is_tensor_collection, T, TensorDictBase +from tensordict.base import _default_is_leaf, is_tensor_collection, T, TensorDictBase from tensordict.memmap import MemoryMappedTensor from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( @@ -416,8 +416,15 @@ def _valid_keys(self): # @cache # noqa: B019 def keys( - self, include_nested: bool = False, leaves_only: bool = False + self, + include_nested: bool = False, + leaves_only: bool = False, + is_leaf: Callable[[Type], bool] | None = None, ) -> _PersistentTDKeysView: + if is_leaf not in (None, _default_is_leaf): + raise ValueError( + f"is_leaf {is_leaf} is not supported within tensordicts of type {type(self)}." + ) return _PersistentTDKeysView( tensordict=self, include_nested=include_nested, @@ -491,6 +498,22 @@ def detach_(self): def device(self): return self._device + def empty(self, recurse=False) -> T: + if recurse: + out = self.empty(recurse=False) + for key, val in self.items(): + if is_tensor_collection(val): + out._set_str( + key, val.empty(recurse=True), inplace=False, validated=True + ) + return out + return TensorDict( + {}, + device=self.device, + batch_size=self.batch_size, + names=self.names if self._has_names() else None, + ) + def entry_class(self, key: NestedKey) -> type: entry_class = self._get_metadata(key) is_array = entry_class.get("array", None) @@ -889,6 +912,11 @@ def _convert_inplace(self, inplace, key): inplace = has_key return inplace + def _set_non_tensor(self, key: NestedKey, value: Any): + raise NotImplementedError( + f"set_non_tensor is not compatible with the tensordict type {type(self)}." 
+ ) + def _set_str(self, key, value, *, inplace, validated): inplace = self._convert_inplace(inplace, key) return self._set(key, value, inplace=inplace, validated=validated) @@ -1005,6 +1033,7 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim): __eq__ = TensorDict.__eq__ __ne__ = TensorDict.__ne__ __xor__ = TensorDict.__xor__ + __or__ = TensorDict.__or__ _apply_nest = TensorDict._apply_nest _check_device = TensorDict._check_device _check_is_shared = TensorDict._check_is_shared diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 7c059308f..82cd171da 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -11,6 +11,7 @@ import json import numbers import os +import pickle import re import sys import warnings @@ -18,7 +19,7 @@ from dataclasses import dataclass from pathlib import Path from textwrap import indent -from typing import Any, Callable, Sequence, TypeVar +from typing import Any, Callable, List, Sequence, TypeVar import tensordict as tensordict_lib @@ -26,10 +27,12 @@ from tensordict._td import is_tensor_collection, NO_DEFAULT, TensorDict, TensorDictBase from tensordict._tensordict import _unravel_key_to_tuple from tensordict._torch_func import TD_HANDLED_FUNCTIONS +from tensordict.base import _register_tensor_class from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor from tensordict.utils import ( _get_repr, + _is_json_serializable, _LOCK_ERROR, DeviceType, IndexType, @@ -139,19 +142,21 @@ def __torch_function__( # get the output type from the arguments / keyword arguments if len(args) > 0: - tc = args[0] + tensorclass_instance = args[0] else: - tc = kwargs.get("input", kwargs["tensors"]) - if isinstance(tc, (tuple, list)): - tc = tc[0] - + tensorclass_instance = kwargs.get("input", kwargs["tensors"]) + if isinstance(tensorclass_instance, (tuple, list)): + tensorclass_instance = tensorclass_instance[0] args = tuple(_arg_to_tensordict(arg) for arg in args) kwargs = {key: _arg_to_tensordict(value) for key, value in kwargs.items()} - res = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) - if isinstance(res, (list, tuple)): - return res.__class__(_from_tensordict_with_copy(tc, td) for td in res) - return _from_tensordict_with_copy(tc, res) + result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + if isinstance(result, (list, tuple)): + return result.__class__( + _from_tensordict_with_copy(tensorclass_instance, tensordict_result) + for tensordict_result in result + ) + return _from_tensordict_with_copy(tensorclass_instance, result) cls = dataclass(cls) expected_keys = set(cls.__dataclass_fields__) @@ -165,7 +170,8 @@ def __torch_function__( cls.__init__ = _init_wrapper(cls.__init__) cls._from_tensordict = classmethod(_from_tensordict_wrapper(expected_keys)) cls.from_tensordict = cls._from_tensordict - cls.__torch_function__ = classmethod(__torch_function__) + if not hasattr(cls, "__torch_function__"): + cls.__torch_function__ = classmethod(__torch_function__) cls.__getstate__ = _getstate cls.__setstate__ = _setstate cls.__getattribute__ = _getattribute_wrapper(cls.__getattribute__) @@ -176,8 +182,10 @@ def __torch_function__( cls.__setitem__ = _setitem cls.__repr__ = _repr cls.__len__ = _len - cls.__eq__ = __eq__ - cls.__ne__ = __ne__ + cls.__eq__ = _eq + cls.__ne__ = _ne + cls.__or__ = _or + cls.__xor__ = _xor cls.set = _set cls.set_at_ = _set_at_ cls.del_ = _del_ @@ -192,8 +200,8 @@ def __torch_function__( cls.memmap_like = TensorDictBase.memmap_like cls.memmap_ = TensorDictBase.memmap_ cls.memmap = TensorDictBase.memmap - 
cls._load_memmap = classmethod(_load_memmap) cls.load_memmap = TensorDictBase.load_memmap + cls._load_memmap = classmethod(_load_memmap) for attr in TensorDict.__dict__.keys(): func = getattr(TensorDict, attr) @@ -208,10 +216,7 @@ def __torch_function__( cls.__doc__ = f"{cls.__name__}{inspect.signature(cls)}" - tensordict_lib.base._ACCEPTED_CLASSES = ( - *tensordict_lib.base._ACCEPTED_CLASSES, - cls, - ) + _register_tensor_class(cls) return cls @@ -254,6 +259,7 @@ def wrapper( *args: Any, batch_size: Sequence[int] | torch.Size | int, device: DeviceType | None = None, + names: List[str] | None = None, **kwargs, ): for value, key in zip(args, self.__dataclass_fields__): @@ -279,7 +285,11 @@ def wrapper( ) self._tensordict = TensorDict( - {}, batch_size=torch.Size(batch_size), device=device, _run_checks=False + {}, + batch_size=torch.Size(batch_size), + device=device, + names=names, + _run_checks=False, ) # To save non tensor data (Nested tensor classes also go here) self._non_tensordict = {} @@ -288,6 +298,7 @@ def wrapper( new_params = [ inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY), inspect.Parameter("device", inspect.Parameter.KEYWORD_ONLY, default=None), + inspect.Parameter("names", inspect.Parameter.KEYWORD_ONLY, default=None), ] wrapper.__signature__ = init_sig.replace(parameters=params + new_params) @@ -366,13 +377,16 @@ def _memmap_( def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): with open(prefix / "meta.json", "w") as f: metadata = {"_type": str(cls)} + to_pickle = {} for key, value in _non_tensordict.items(): - if ( - isinstance(value, (int, bool, str, float, dict, tuple, list)) - or value is None - ): + if _is_json_serializable(value): metadata[key] = value + else: + to_pickle[key] = value json.dump(metadata, f) + if to_pickle: + with open(prefix / "other.pickle", "wb") as pickle_file: + pickle.dump(to_pickle, pickle_file) if executor is None: save_metadata() @@ -400,6 +414,9 @@ def save_metadata(cls=cls, _non_tensordict=_non_tensordict, prefix=prefix): def _load_memmap(cls, prefix: Path, metadata: dict): del metadata["_type"] non_tensordict = copy(metadata) + if os.path.exists(prefix / "other.pickle"): + with open(prefix / "other.pickle", "rb") as pickle_file: + non_tensordict.update(pickle.load(pickle_file)) td = TensorDict.load_memmap(prefix / "_tensordict") return cls._from_tensordict(td, non_tensordict) @@ -865,7 +882,7 @@ def _load_state_dict( return self -def __eq__(self, other: object) -> bool: +def _eq(self, other: object) -> bool: """Compares the Tensor class object to another object for equality. However, the equality check for non-tensor data is not performed. Args: @@ -926,7 +943,7 @@ def __eq__(self, other: object) -> bool: return _from_tensordict_with_none(self, tensor) -def __ne__(self, other: object) -> bool: +def _ne(self, other: object) -> bool: """Compare the Tensor class object to another object for inequality. However, the equality check for non-tensor data is not performed. Args: @@ -983,6 +1000,54 @@ def __ne__(self, other: object) -> bool: return _from_tensordict_with_none(self, tensor) +def _or(self, other: object) -> bool: + """Compares the Tensor class object to another object for logical OR. However, the logical OR check for non-tensor data is not performed. + + Args: + other: object to compare to this object. Can be a tensorclass, a + tensordict or any compatible type (int, float or tensor), in + which case the equality check will be propagated to the leaves. 
+
+ Returns:
+ False if the objects are of different class types, Tensorclass of boolean
+ values for tensor attributes and None for non-tensor attributes
+
+ """
+ if not is_tensor_collection(other) and not isinstance(
+ other, (dict, numbers.Number, Tensor, _MemmapTensor)
+ ):
+ return False
+ if is_tensorclass(other):
+ tensor = self._tensordict | other._tensordict
+ else:
+ tensor = self._tensordict | other
+ return _from_tensordict_with_none(self, tensor)
+
+
+def _xor(self, other: object) -> bool:
+ """Compares the Tensor class object to another object for exclusive OR. However, the exclusive OR check for non-tensor data is not performed.
+
+ Args:
+ other: object to compare to this object. Can be a tensorclass, a
+ tensordict or any compatible type (int, float or tensor), in
+ which case the equality check will be propagated to the leaves.
+
+ Returns:
+ False if the objects are of different class types, Tensorclass of boolean
+ values for tensor attributes and None for non-tensor attributes
+
+ """
+ if not is_tensor_collection(other) and not isinstance(
+ other, (dict, numbers.Number, Tensor, _MemmapTensor)
+ ):
+ return False
+ if is_tensorclass(other):
+ tensor = self._tensordict ^ other._tensordict
+ else:
+ tensor = self._tensordict ^ other
+ return _from_tensordict_with_none(self, tensor)
+
+
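An illustrative sketch of these tensorclass-level operators; the `Flags` tensorclass is hypothetical:

>>> import torch
>>> from tensordict import tensorclass
>>> @tensorclass
... class Flags:
...     mask: torch.Tensor
>>> x = Flags(mask=torch.tensor([True, False]), batch_size=[2])
>>> (x | x).mask  # OR propagated to the tensor leaves
tensor([ True, False])
>>> (x ^ x).mask  # XOR propagated to the tensor leaves
tensor([False, False])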
 def _single_td_field_as_str(key, item, tensordict):
 """Returns a string as a key-value pair of tensordict.
@@ -1046,3 +1111,251 @@ def _unbind(self, dim: int):
 self._from_tensordict(td, non_tensordict=copy(self._non_tensordict))
 for td in self._tensordict.unbind(dim)
 )
+
+
+################
+# Custom classes
+# --------------
+
+NONTENSOR_HANDLED_FUNCTIONS = []
+
+
+@tensorclass
+class NonTensorData:
+ """A carrier for non-tensor data.
+
+ This class can be used whenever non-tensor data needs to be carried at
+ any level of a tensordict instance.
+
+ :class:`~tensordict.tensorclass.NonTensorData` instances can be created
+ explicitly or using :meth:`~tensordict.TensorDictBase.set_non_tensor`.
+
+ This class is serializable using :meth:`tensordict.TensorDictBase.memmap`
+ and related methods, and can be loaded through :meth:`~tensordict.TensorDictBase.load_memmap`.
+ If the content of the object is JSON-serializable, it will be serialized in
+ the `meta.json` file in the directory pointed to by the parent key of the `NonTensorData`
+ object. If it isn't, serialization will fall back on pickle. In other words,
+ we assume that the content of this class is either JSON-serializable or
+ picklable, and it is the user's responsibility to make sure that one of these
+ holds. We try to avoid pickling/unpickling objects for performance and security
+ reasons (as pickle can execute arbitrary code during loading).
+
+ .. note:: if the data passed to :class:`NonTensorData` is a :class:`NonTensorData`
+ itself, the data from the nested object will be gathered.
+
+ >>> non_tensor = NonTensorData("a string!", batch_size=[])
+ >>> non_tensor = NonTensorData(non_tensor, batch_size=[])
+ >>> assert non_tensor.data == "a string!"
+
+ .. note:: Unlike other tensorclass classes, :class:`NonTensorData` supports
+ comparisons of two non-tensor data through :meth:`~.__eq__`, :meth:`~.__ne__`,
+ :meth:`~.__xor__` or :meth:`~.__or__`. These operations return a tensor
+ of shape `batch_size`. For compatibility with ``==``,
+ comparison with a non-:class:`NonTensorData` will always return an empty
+ :class:`NonTensorData`.
+
+ >>> a = NonTensorData(True, batch_size=[])
+ >>> b = NonTensorData(True, batch_size=[])
+ >>> assert a == b
+ >>> assert not (a != b)
+ >>> assert not (a ^ b)
+ >>> assert a | b
+ >>> # The output is a tensor of shape batch-size
+ >>> a = NonTensorData(True, batch_size=[3])
+ >>> b = NonTensorData(True, batch_size=[3])
+ >>> print(a == b)
+ tensor([True, True, True])
+
+ .. note:: Stacking :class:`NonTensorData` instances results in either
+ a single :class:`NonTensorData` instance if all contents match, or a
+ :class:`~tensordict.LazyStackedTensorDict` object if the contents
+ differ. To get to this result, the contents of the :class:`NonTensorData`
+ instances must be compared, which can be computationally intensive
+ depending on what this content is.
+
+ >>> data = torch.stack([NonTensorData(1, batch_size=[]) for _ in range(10)])
+ >>> data
+ NonTensorData(
+ data=1,
+ batch_size=torch.Size([10]),
+ device=None,
+ is_shared=False)
+ >>> data = torch.stack([NonTensorData(i, batch_size=[3,]) for i in range(10)], 1)
+ >>> data[:, 0]
+ NonTensorData(
+ data=0,
+ batch_size=torch.Size([3]),
+ device=None,
+ is_shared=False)
+
+ .. note:: Non-tensor data can be filtered out from a tensordict using
+ :meth:`~tensordict.TensorDictBase.filter_non_tensor_data`.
+
+ Examples:
+ >>> # create an instance explicitly
+ >>> non_tensor = NonTensorData("a string!", batch_size=[]) # batch-size can be anything
+ >>> data = TensorDict({}, batch_size=[3])
+ >>> data.set_non_tensor(("nested", "key"), "a string!")
+ >>> assert isinstance(data.get(("nested", "key")), NonTensorData)
+ >>> assert data.get_non_tensor(("nested", "key")) == "a string!"
+ >>> # serialization
+ >>> class MyPicklableClass:
+ ... value = 10
+ >>> data.set_non_tensor("picklable", MyPicklableClass())
+ >>> import tempfile
+ >>> with tempfile.TemporaryDirectory() as tmpdir:
+ ... data.memmap(tmpdir)
+ ... loaded = TensorDict.load_memmap(tmpdir)
+ ... # print directory path
+ ... print_directory_tree(tmpdir)
+ Directory size: 511.00 B
+ tmp2cso9og_/
+ picklable/
+ _tensordict/
+ meta.json
+ other.pickle
+ meta.json
+ nested/
+ key/
+ _tensordict/
+ meta.json
+ meta.json
+ meta.json
+ meta.json
+ >>> assert loaded.get_non_tensor("picklable").value == 10
+
+ """
+
+ # Used to carry non-tensor data in a tensordict.
+ # The advantage of storing this in a tensorclass is that we don't need
+ # to patch tensordict with additional checks that will incur unwanted overhead
+ # and all the overhead falls back on this class.
+ data: Any
+
+ def __post_init__(self):
+ if isinstance(self.data, NonTensorData):
+ self.data = self.data.data
+
+ old_eq = self.__class__.__eq__
+ if old_eq is _eq:
+ global NONTENSOR_HANDLED_FUNCTIONS
+ NONTENSOR_HANDLED_FUNCTIONS.extend(TD_HANDLED_FUNCTIONS)
+
+ # Patch only the first time a class is created
+
+ @functools.wraps(_eq)
+ def __eq__(self, other):
+ if isinstance(other, NonTensorData):
+ return torch.full(
+ self.batch_size, self.data == other.data, device=self.device
+ )
+ return old_eq(self, other)
+
+ self.__class__.__eq__ = __eq__
+
+ _ne = self.__class__.__ne__
+
+ @functools.wraps(_ne)
+ def __ne__(self, other):
+ if isinstance(other, NonTensorData):
+ return torch.full(
+ self.batch_size, self.data != other.data, device=self.device
+ )
+ return _ne(self, other)
+
+ self.__class__.__ne__ = __ne__
+
+ _xor = self.__class__.__xor__
+
+ @functools.wraps(_xor)
+ def __xor__(self, other):
+ if isinstance(other, NonTensorData):
+ return torch.full(
+ self.batch_size, self.data ^ other.data, device=self.device
+ )
+ return _xor(self, other)
+
+ self.__class__.__xor__ = __xor__
+
+ _or = self.__class__.__or__
+
+ @functools.wraps(_or)
+ def __or__(self, other):
+ if isinstance(other, NonTensorData):
+ return torch.full(
+ self.batch_size, self.data | other.data, device=self.device
+ )
+ return _or(self, other)
+
+ self.__class__.__or__ = __or__
+
+ def empty(self, recurse=False):
+ return NonTensorData(
+ data=self.data,
+ batch_size=self.batch_size,
+ names=self.names if self._has_names() else None,
+ device=self.device,
+ )
+
+ def to_dict(self):
+ # override to_dict to return just the data
+ return self.data
+
+ @classmethod
+ def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
+ # checks have been performed previously, so we're sure the list is non-empty
+ first = list_of_non_tensor[0]
+ if all(data.data == first.data for data in list_of_non_tensor[1:]):
+ batch_size = list(first.batch_size)
+ batch_size.insert(dim, len(list_of_non_tensor))
+ return NonTensorData(
+ data=first.data,
+ batch_size=batch_size,
+ names=first.names if first._has_names() else None,
+ device=first.device,
+ )
+
+ from tensordict._lazy import LazyStackedTensorDict
+
+ return LazyStackedTensorDict(*list_of_non_tensor, stack_dim=dim)
+
+ @classmethod
+ def __torch_function__(
+ cls,
+ func: Callable,
+ types: tuple[type, ...],
+ args: tuple[Any, ...] = (),
+ kwargs: dict[str, Any] | None = None,
+ ) -> Callable:
+ # A modified version of __torch_function__ to account for the different behaviour
+ # of stack, which should return a lazy stack of the data if the contents do not match.
+ if func not in _TD_PASS_THROUGH or not all( + issubclass(t, (Tensor, cls)) for t in types + ): + return NotImplemented + + escape_conversion = func in (torch.stack,) + + if kwargs is None: + kwargs = {} + + # get the output type from the arguments / keyword arguments + if len(args) > 0: + tensorclass_instance = args[0] + else: + tensorclass_instance = kwargs.get("input", kwargs["tensors"]) + if isinstance(tensorclass_instance, (tuple, list)): + tensorclass_instance = tensorclass_instance[0] + if not escape_conversion: + args = tuple(_arg_to_tensordict(arg) for arg in args) + kwargs = {key: _arg_to_tensordict(value) for key, value in kwargs.items()} + + result = TD_HANDLED_FUNCTIONS[func](*args, **kwargs) + if isinstance(result, (list, tuple)): + return result.__class__( + _from_tensordict_with_copy(tensorclass_instance, tensordict_result) + for tensordict_result in result + ) + if not escape_conversion: + return _from_tensordict_with_copy(tensorclass_instance, result) + return result diff --git a/tensordict/utils.py b/tensordict/utils.py index 726dfa133..f5cbee7a7 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1430,12 +1430,9 @@ def _expand_to_match_shape( ) else: # tensordict - from tensordict import TensorDict - - out = TensorDict( - {}, - [*parent_batch_size, *_shape(tensor)[self_batch_dims:]], - device=self_device, + out = tensor.empty() + out.batch_size = torch.Size( + [*parent_batch_size, *_shape(tensor)[self_batch_dims:]] ) return out @@ -1771,3 +1768,21 @@ def result(self): """Wait and returns the resulting tensordict.""" concurrent.futures.wait(self.futures) return self.resulting_td + + +def _is_json_serializable(item): + if isinstance(item, dict): + for key, val in item.items(): + # Per se, int, float and bool are serializable but not recoverable + # as such + if not isinstance(key, (str,)) or not _is_json_serializable(val): + return False + else: + return True + if isinstance(item, (list, tuple, set)): + for val in item: + if not _is_json_serializable(val): + return False + else: + return True + return isinstance(item, (str, int, float, bool)) or item is None diff --git a/test/_utils_internal.py b/test/_utils_internal.py index ae4fc1053..c85acbfc5 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -10,7 +10,7 @@ import numpy as np import torch -from tensordict import PersistentTensorDict, tensorclass, TensorDict +from tensordict import NonTensorData, PersistentTensorDict, tensorclass, TensorDict from tensordict._lazy import LazyStackedTensorDict from tensordict._torch_func import _stack as stack_td from tensordict.base import is_tensor_collection @@ -297,6 +297,23 @@ def td_params(self, device): TYPES_DEVICES += [["td_params", device]] TYPES_DEVICES_NOLAZY += [["td_params", device]] + def td_with_non_tensor(self, device): + td = self.td(device) + return td.set_non_tensor( + ("data", "non_tensor"), + # this is allowed since nested NonTensorData are automatically unwrapped + NonTensorData( + "some text data", + batch_size=td.batch_size, + device=td.device, + names=td.names if td._has_names() else None, + ), + ) + + for device in get_available_devices(): + TYPES_DEVICES += [["td_with_non_tensor", device]] + TYPES_DEVICES_NOLAZY += [["td_with_non_tensor", device]] + def expand_list(list_of_tensors, *dims): n = len(list_of_tensors) @@ -315,3 +332,11 @@ def decompose(td): yield from decompose(v) else: yield v + + +class DummyPicklableClass: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return self.value == 
other.value diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index aff7a71a5..28797971c 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -88,7 +88,7 @@ def test_type(): def test_signature(): sig = inspect.signature(MyData) - assert list(sig.parameters) == ["X", "y", "z", "batch_size", "device"] + assert list(sig.parameters) == ["X", "y", "z", "batch_size", "device", "names"] with pytest.raises(TypeError, match="missing 3 required positional arguments"): MyData(batch_size=[10]) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 3482088dd..3a1acc158 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -15,6 +15,7 @@ import torch from tensordict.nn import TensorDictParams +from tensordict.tensorclass import NonTensorData try: import torchsnapshot @@ -35,7 +36,13 @@ import contextlib import platform -from _utils_internal import decompose, get_available_devices, prod, TestTensorDictsBase +from _utils_internal import ( + decompose, + DummyPicklableClass, + get_available_devices, + prod, + TestTensorDictsBase, +) from functorch import dim as ftdim from tensordict import LazyStackedTensorDict, make_tensordict, MemmapTensor, TensorDict @@ -2204,7 +2211,7 @@ def test_delitem(self, td_name, device): def test_to_dict_nested(self, td_name, device): def recursive_checker(cur_dict): for _, value in cur_dict.items(): - if isinstance(value, TensorDict): + if is_tensor_collection(value): return False elif isinstance(value, dict) and not recursive_checker(value): return False @@ -2222,6 +2229,9 @@ def recursive_checker(cur_dict): # Convert into dictionary and recursively check if the values are TensorDicts td_dict = td.to_dict() assert recursive_checker(td_dict) + if td_name == "td_with_non_tensor": + assert td_dict["data"]["non_tensor"] == "some text data" + assert (TensorDict.from_dict(td_dict) == td).all() @pytest.mark.parametrize( "index", ["tensor1", "mask", "int", "range", "tensor2", "slice_tensor"] @@ -3014,20 +3024,38 @@ def test_setitem_slice(self, td_name, device): def test_casts(self, td_name, device): td = getattr(self, td_name)(device) + # exclude non-tensor data + is_leaf = lambda cls: issubclass(cls, torch.Tensor) tdfloat = td.float() - assert all(value.dtype is torch.float for value in tdfloat.values(True, True)) + assert all( + value.dtype is torch.float + for value in tdfloat.values(True, True, is_leaf=is_leaf) + ) tddouble = td.double() - assert all(value.dtype is torch.double for value in tddouble.values(True, True)) + assert all( + value.dtype is torch.double + for value in tddouble.values(True, True, is_leaf=is_leaf) + ) tdbfloat16 = td.bfloat16() assert all( - value.dtype is torch.bfloat16 for value in tdbfloat16.values(True, True) + value.dtype is torch.bfloat16 + for value in tdbfloat16.values(True, True, is_leaf=is_leaf) ) tdhalf = td.half() - assert all(value.dtype is torch.half for value in tdhalf.values(True, True)) + assert all( + value.dtype is torch.half + for value in tdhalf.values(True, True, is_leaf=is_leaf) + ) tdint = td.int() - assert all(value.dtype is torch.int for value in tdint.values(True, True)) + assert all( + value.dtype is torch.int + for value in tdint.values(True, True, is_leaf=is_leaf) + ) tdint = td.type(torch.int) - assert all(value.dtype is torch.int for value in tdint.values(True, True)) + assert all( + value.dtype is torch.int + for value in tdint.values(True, True, is_leaf=is_leaf) + ) def test_empty_like(self, td_name, device): if "sub_td" in td_name: @@ -3045,7 +3073,10 @@ def 
test_empty_like(self, td_name, device):
 td.apply_(lambda x: x + 1.0)
 assert type(td) is type(td_empty)
- assert all(val.any() for val in (td != td_empty).values(True, True))
+ # exclude non-tensor data
+ comp = td.filter_non_tensor_data() != td_empty.filter_non_tensor_data()
+ assert all(val.any() for val in comp.values(True, True))

 @pytest.mark.parametrize("nested", [False, True])
 def test_add_batch_dim_cache(self, td_name, device, nested):
@@ -3199,6 +3230,67 @@ def test_update_select(self, td_name, device):
 td.update(other_td, keys_to_update=(("My", ("father",), "was"),))
 assert (td["My", "father", "was"] == 1).all()

+ def test_non_tensor_data(self, td_name, device):
+ td = getattr(self, td_name)(device)
+ # check lock
+ if td_name not in ("sub_td", "sub_td2"):
+ with td.lock_(), pytest.raises(RuntimeError, match=re.escape(_LOCK_ERROR)):
+ td.set_non_tensor(("this", "will"), "fail")
+ # check set
+ with td.unlock_():
+ td.set(("this", "tensor"), torch.zeros(td.shape))
+ reached = False
+ with pytest.raises(
+ RuntimeError,
+ match="set_non_tensor is not compatible with the tensordict type",
+ ) if td_name in ("td_h5",) else contextlib.nullcontext():
+ td.set_non_tensor(("this", "will"), "succeed")
+ reached = True
+ if not reached:
+ return
+ # check get (for tensor)
+ assert (td.get_non_tensor(("this", "tensor")) == 0).all()
+ # check get (for non-tensor)
+ assert td.get_non_tensor(("this", "will")) == "succeed"
+ assert isinstance(td.get(("this", "will")), NonTensorData)
+
+ def test_non_tensor_data_flatten_keys(self, td_name, device):
+ td = getattr(self, td_name)(device)
+ with td.unlock_():
+ td.set(("this", "tensor"), torch.zeros(td.shape))
+ reached = False
+ with pytest.raises(
+ RuntimeError,
+ match="set_non_tensor is not compatible with the tensordict type",
+ ) if td_name in ("td_h5",) else contextlib.nullcontext():
+ td.set_non_tensor(("this", "will"), "succeed")
+ reached = True
+ if not reached:
+ return
+ td_flat = td.flatten_keys()
+ assert (td_flat.get("this.tensor") == 0).all()
+ assert td_flat.get_non_tensor("this.will") == "succeed"
+
+ def test_non_tensor_data_pickle(self, td_name, device, tmpdir):
+ td = getattr(self, td_name)(device)
+ with td.unlock_():
+ td.set(("this", "tensor"), torch.zeros(td.shape))
+ reached = False
+ with pytest.raises(
+ RuntimeError,
+ match="set_non_tensor is not compatible with the tensordict type",
+ ) if td_name in ("td_h5",) else contextlib.nullcontext():
+ td.set_non_tensor(("this", "will"), "succeed")
+ reached = True
+ if not reached:
+ return
+ td.set_non_tensor(("non", "json", "serializable"), DummyPicklableClass(10))
+ td.memmap(prefix=tmpdir, copy_existing=True)
+ loaded = TensorDict.load_memmap(tmpdir)
+ assert isinstance(loaded.get(("non", "json", "serializable")), NonTensorData)
+ assert loaded.get_non_tensor(("non", "json", "serializable")).value == 10
+ assert loaded.get_non_tensor(("this", "will")) == "succeed"
+

 @pytest.mark.parametrize("device", [None, *get_available_devices()])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
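The pickle fallback exercised by the test above can be sketched in isolation, assuming `DummyPicklableClass` from `test/_utils_internal.py` is importable and an in-memory `TensorDict` is used:

>>> import tempfile
>>> from tensordict import TensorDict
>>> data = TensorDict({}, batch_size=[])
>>> data.set_non_tensor("obj", DummyPicklableClass(10))
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     _ = data.memmap(prefix=tmpdir)  # non-JSON data lands in other.pickle
...     loaded = TensorDict.load_memmap(tmpdir)
...     assert loaded.get_non_tensor("obj").value == 10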
@@ -6809,6 +6901,71 @@ def test_map_unbind(self):
 assert (td_out["2"] == 2).all()

+
+class TestNonTensorData:
+ @pytest.fixture
+ def non_tensor_data(self):
+ return TensorDict(
+ {
+ "1": 1,
+ "nested": {
+ "int": NonTensorData(3, batch_size=[]),
+ "str": NonTensorData("a string!", batch_size=[]),
+ "bool": NonTensorData(True, batch_size=[]),
+ },
+ },
+ batch_size=[],
+ )
+
+ def test_nontensor_dict(self, non_tensor_data):
+ assert (
+ TensorDict.from_dict(non_tensor_data.to_dict()) == non_tensor_data
+ ).all()
+
+ def test_set(self, non_tensor_data):
+ non_tensor_data.set(("nested", "another_string"), "another string!")
+ assert (
+ non_tensor_data.get(("nested", "another_string")).data == "another string!"
+ )
+ assert (
+ non_tensor_data.get_non_tensor(("nested", "another_string"))
+ == "another string!"
+ )
+
+ def test_stack(self, non_tensor_data):
+ assert (
+ torch.stack([non_tensor_data, non_tensor_data], 0).get(("nested", "int"))
+ == NonTensorData(3, batch_size=[2])
+ ).all()
+ assert (
+ torch.stack([non_tensor_data, non_tensor_data], 0).get_non_tensor(
+ ("nested", "int")
+ )
+ == 3
+ )
+ assert isinstance(
+ torch.stack([non_tensor_data, non_tensor_data], 0).get(("nested", "int")),
+ NonTensorData,
+ )
+ non_tensor_copy = non_tensor_data.clone()
+ non_tensor_copy.get(("nested", "int")).data = 4
+ assert isinstance(
+ torch.stack([non_tensor_data, non_tensor_copy], 0).get(("nested", "int")),
+ LazyStackedTensorDict,
+ )
+
+ def test_comparison(self, non_tensor_data):
+ non_tensor_data = non_tensor_data.exclude(("nested", "str"))
+ assert (non_tensor_data | non_tensor_data).get_non_tensor(("nested", "bool"))
+ assert not (non_tensor_data ^ non_tensor_data).get_non_tensor(
+ ("nested", "bool")
+ )
+ assert (non_tensor_data == non_tensor_data).get_non_tensor(("nested", "bool"))
+ assert not (non_tensor_data != non_tensor_data).get_non_tensor(
+ ("nested", "bool")
+ )
+

 if __name__ == "__main__":
 args, unknown = argparse.ArgumentParser().parse_known_args()
 pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
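Finally, a sketch of the stacking semantics the suite above verifies: matching contents collapse to a single `NonTensorData`, while mismatching contents fall back to a lazy stack:

>>> import torch
>>> from tensordict import LazyStackedTensorDict, NonTensorData
>>> a = NonTensorData("same", batch_size=[])
>>> b = NonTensorData("same", batch_size=[])
>>> torch.stack([a, b], 0).data  # identical contents: one NonTensorData
'same'
>>> c = NonTensorData("different", batch_size=[])
>>> isinstance(torch.stack([a, c], 0), LazyStackedTensorDict)
True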