From 070ca618bcc6b7fadcb4955a1016ca31babd6e36 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 16 Jan 2025 12:57:31 +0000
Subject: [PATCH] [Feature] Subclass conservation in td ops

ghstack-source-id: 83e79abda6a4bb6839d99240052323380981855c
Pull Request resolved: https://github.com/pytorch/tensordict/pull/1186
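The net effect, mirrored by the new test_subclassing test below, is that
ops on a TensorDict subclass now return an instance of that subclass
rather than a plain TensorDict. A minimal sketch (SubTD is the
user-defined subclass used in the test):

    import torch
    from tensordict import TensorDict

    class SubTD(TensorDict): ...

    t = SubTD(a=torch.randn(3))
    # these used to come back as plain TensorDict; they now preserve SubTD
    assert isinstance(t.clone(), SubTD)
    assert isinstance(t + t, SubTD)
    assert isinstance(torch.stack([t, t]), SubTD)

Subclasses of TensorDictBase that do not inherit from TensorDict fall back
to plain TensorDict results through the new TensorDictBase._new_unsafe
fallback.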
---
 tensordict/_td.py         | 28 ++++++++++++++--------------
 tensordict/_torch_func.py | 32 +++++++++++++++++++++-----------
 tensordict/base.py        | 10 ++++++++++
 tensordict/nn/params.py   | 22 +++++++++++++---------
 test/test_tensordict.py   | 32 ++++++++++++++++++++++++++++++++
 5 files changed, 90 insertions(+), 34 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index 0f66988bf..77f806ee2 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -344,7 +344,7 @@ def _new_unsafe(
         if source:  # faster than calling items
             for key, value in source.items():
                 if nested and isinstance(value, dict):
-                    value = TensorDict._new_unsafe(
+                    value = cls._new_unsafe(
                         source=value,
                         batch_size=self._batch_size,
                         device=self._device,
@@ -374,7 +374,7 @@ def from_module(
             filter_empty=filter_empty,
         )
         if result is None:
-            result = TensorDict._new_unsafe({}, batch_size=torch.Size(()))
+            result = cls._new_unsafe({}, batch_size=torch.Size(()))
         if lock:
             result.lock_()
         return result
@@ -419,7 +419,7 @@ def _from_module(
                 destination = hook_result
         if not filter_empty or destination:
             destination_set = True
-            destination = TensorDict._new_unsafe(destination, batch_size=torch.Size(()))
+            destination = cls._new_unsafe(destination, batch_size=torch.Size(()))
         else:
             destination_set = False
         for name, submodule in module._modules.items():
@@ -433,7 +433,7 @@ def _from_module(
             )
             if subtd is not None:
                 if not destination_set:
-                    destination = TensorDict._new_unsafe(batch_size=torch.Size(()))
+                    destination = cls._new_unsafe(batch_size=torch.Size(()))
                     destination_set = True
                 destination._set_str(
                     name, subtd, validated=True, inplace=False, non_blocking=False
@@ -610,7 +610,7 @@ def _quick_set(swap_dict, swap_td):
             _quick_set(_swap, swap_dest)
             return swap_dest
         else:
-            return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))
+            return self._new_unsafe(_swap, batch_size=torch.Size(()))

     @_maybe_broadcast_other("__ne__")
     def __ne__(self, other: Any) -> T | bool:
@@ -1479,7 +1479,7 @@ def _add_batch_dim_wrapper(key, value):
                 return value
             return _add_batch_dim(value, in_dim, vmap_level)

-        out = TensorDict._new_unsafe(
+        out = self._new_unsafe(
             {key: _add_batch_dim_wrapper(key, value) for key, value in td.items()},
             batch_size=torch.Size(
                 [b for i, b in enumerate(td.batch_size) if i != in_dim]
@@ -1613,7 +1613,7 @@ def _check_for_invalid_index(index):
                 )
             else:
                 source[key] = _get_item(item, index)
-        result = TensorDict._new_unsafe(
+        result = self._new_unsafe(
             source=source,
             batch_size=batch_size,
             device=self.device,
@@ -1694,7 +1694,7 @@ def empty(
             is_shared=is_shared,
             is_memmap=is_memmap,
         ):
-            result = TensorDict._new_unsafe(
+            result = self._new_unsafe(
                 {}, batch_size=batch_size, names=names, device=device
             )
             result._is_shared = is_shared
@@ -3231,7 +3231,7 @@ def _clone(self, recurse: bool = True) -> T:
         if recurse and self.device is not None:
             return self._clone_recurse()

-        result = TensorDict._new_unsafe(
+        result = self._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
             device=self.device,
@@ -3248,7 +3248,7 @@ def contiguous(self) -> T:
         source = {key: value.contiguous() for key, value in self.items()}
         batch_size = self.batch_size
         device = self.device
-        out = TensorDict._new_unsafe(
+        out = self._new_unsafe(
             source=source,
             batch_size=batch_size,
             device=device,
@@ -3260,7 +3260,7 @@ def empty(
         self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=NO_DEFAULT
     ) -> T:
         if not recurse:
-            return TensorDict._new_unsafe(
+            return self._new_unsafe(
                 device=self._device if device is NO_DEFAULT else device,
                 batch_size=(
                     self._batch_size if batch_size is None else torch.Size(batch_size)
@@ -3309,7 +3309,7 @@ def _select(
                 *val, strict=strict, inplace=inplace, set_shared=set_shared
             )

-        result = TensorDict._new_unsafe(
+        result = self._new_unsafe(
             device=self.device,
             batch_size=self.batch_size,
             source=source,
@@ -3358,7 +3358,7 @@ def _exclude(
                 _tensordict[key] = val
         if inplace:
             return self
-        result = TensorDict._new_unsafe(
+        result = self._new_unsafe(
             _tensordict,
             batch_size=self.batch_size,
             device=self.device,
@@ -4059,7 +4059,7 @@ def is_contiguous(self) -> bool:
         return all(value.is_contiguous() for value in self.values())

     def contiguous(self) -> T:
-        return TensorDict._new_unsafe(
+        return self._new_unsafe(
             batch_size=self.batch_size,
             source={key: value.contiguous() for key, value in self.items()},
             device=self.device,
diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py
index be504b2b5..a69047e81 100644
--- a/tensordict/_torch_func.py
+++ b/tensordict/_torch_func.py
@@ -23,11 +23,11 @@
 from tensordict.utils import (
     _check_keys,
     _ErrorInteceptor,
+    _is_tensorclass,
     _pass_through,
     _shape,
     _zip_strict,
     DeviceType,
-    is_tensorclass,
     lazy_legacy,
     set_lazy_legacy,
 )
@@ -138,12 +138,12 @@ def _gather_tensor(tensor, dest_container=None, dest_key=None):
         return out

     if out is None:
-        if len(index.shape) == input.ndim and input._has_names():
-            names = input.names
+        if len(index.shape) == input.ndim:
+            names = input._maybe_names()
         else:
             names = None
         device = input.device
-        return TensorDict(
+        return type(input)._new_unsafe(
             {
                 key: _gather_tensor(value)
                 for key, value in input.items(is_leaf=_is_leaf_nontensor)
@@ -300,6 +300,7 @@ def _cat(
         raise RuntimeError("list_of_tensordicts cannot be empty")

     batch_size = list(list_of_tensordicts[0].batch_size)
+    tdtype = type(list_of_tensordicts[0])
     if dim < 0:
         dim = len(batch_size) + dim
     if dim >= len(batch_size):
@@ -334,9 +335,13 @@ def _cat(
         names = None
         if list_of_tensordicts[0]._has_names():
             names = list_of_tensordicts[0].names
-        return TensorDict._new_unsafe(
-            out, device=device, batch_size=batch_size, names=names
-        )
+        # if we have a TD subclass, use _new_unsafe because we know it exists.
+        # Otherwise, use TensorDict's one.
+        if issubclass(tdtype, TensorDict):
+            clz = tdtype
+        else:
+            clz = TensorDict
+        return clz._new_unsafe(out, device=device, batch_size=batch_size, names=names)
     else:
         if out.batch_size != batch_size:
             raise RuntimeError(
@@ -453,14 +458,19 @@ def _stack(
         raise RuntimeError("list_of_tensordicts cannot be empty")
     if maybe_dense_stack is None:
         maybe_dense_stack = lazy_legacy()
-    is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
+    td_types = [type(td) for td in list_of_tensordicts]
+    is_tc = any(_is_tensorclass(td_type) for td_type in td_types)
     if all(_pass_through(td) for td in list_of_tensordicts):
         return type(list_of_tensordicts[0])._stack_non_tensor(
             list_of_tensordicts, dim=dim
         )
     if is_tc:
-        tc_type = type(list_of_tensordicts[0])
         list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]
+        clz = type(list_of_tensordicts[0])
+    elif issubclass(td_types[0], TensorDict):
+        clz = td_types[0]
+    else:
+        clz = TensorDict

     batch_size = list_of_tensordicts[0].batch_size
     if dim < 0:
@@ -617,7 +627,7 @@ def stack_fn(key, values, is_not_init, is_tensor):
             for key, (values, is_not_init, is_tensor) in out.items()
         }

-        result = TensorDict._new_unsafe(
+        result = clz._new_unsafe(
             out,
             batch_size=LazyStackedTensorDict._compute_batch_size(
                 batch_size, dim, len(list_of_tensordicts)
@@ -625,7 +635,7 @@ def stack_fn(key, values, is_not_init, is_tensor):
             device=device,
         )
         if is_tc:
-            return tc_type._from_tensordict(result)
+            return td_types[0]._from_tensordict(result)
         return result
     else:
         out = LazyStackedTensorDict(
diff --git a/tensordict/base.py b/tensordict/base.py
index a73e4163f..055308d5d 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -273,6 +273,16 @@ class TensorDictBase(MutableMapping):
     _memmap_prefix = None
     _stream: torch.cuda.Stream | None = None

+    @classmethod
+    def _new_unsafe(cls, *args, **kwargs):
+        # This is to make sure all TensorDictBase subclasses have a proper fallback if they
+        # don't have a _new_unsafe. In other words, only TensorDict subclasses will have their
+        # type preserved; others will become TensorDict instances (note that TensorDictBase
+        # should not be directly subclassed outside of this codebase, as it is highly abstract).
+        from tensordict._td import TensorDict
+
+        return TensorDict._new_unsafe(*args, **kwargs)
+
     def __bool__(self) -> bool:
         raise RuntimeError("Converting a tensordict to boolean value is not permitted")

diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index 03c9632a1..0a5107e9e 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -390,7 +390,7 @@ def _new_unsafe(
         cls,
         parameters: TensorDictBase,
         *,
-        no_convert=False,
+        no_convert=None,
         lock: bool = False,
         params: dict | None = None,
         buffers: dict | None = None,
@@ -399,24 +399,28 @@ def _new_unsafe(
         if is_compiling():
             return TensorDictParams(parameters, no_convert="skip", lock=lock)

-        self = TensorDictParams.__new__(cls)
-        nn.Module.__init__(self)
-
         if parameters is None:
             parameters = kwargs
-        elif kwargs:
-            raise TypeError(
-                f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
-            )
         if isinstance(parameters, dict):
-            parameters = TensorDict._new_unsafe(parameters)
+            parameters = TensorDict._new_unsafe(parameters, **kwargs)
+            if no_convert is None:
+                # Then _new_unsafe is called from somewhere that doesn't know
+                # that it's a TDParams and we return a TensorDict (e.g., torch.gather)
+                return parameters
         elif isinstance(parameters, TensorDictParams):
+            if kwargs:
+                raise TypeError(
+                    f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+                )
             params = dict(parameters._parameters)
             buffers = dict(parameters._buffers)
             parameters = parameters._param_td
             no_convert = "skip"

+        self = TensorDictParams.__new__(cls)
+        nn.Module.__init__(self)
+
         self._param_td = parameters
         self.no_convert = no_convert
         if no_convert != "skip":
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index ebf1f76bd..838ad35d0 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2709,6 +2709,38 @@ def test_reduction_feature_full(self, reduction):
         assert getattr(td, reduction)(reduce=True, dim="feature").shape == (3, 4)
         assert getattr(td, reduction)(reduce=True, dim=1).shape == (3, 5)

+    def test_subclassing(self):
+        class SubTD(TensorDict): ...
+
+        t = SubTD(a=torch.randn(3))
+        assert isinstance(t + t, SubTD)
+        assert isinstance(t / 2, SubTD)
+        assert isinstance(2 / t, SubTD)
+        assert isinstance(t.to(torch.float), SubTD)
+        assert isinstance(t.to("cpu"), SubTD)
+        assert isinstance(torch.zeros_like(t), SubTD)
+        assert isinstance(t.copy(), SubTD)
+        assert isinstance(t.clone(), SubTD)
+        assert isinstance(t.empty(), SubTD)
+        assert isinstance(t.select(), SubTD)
+        assert isinstance(t.exclude("a"), SubTD)
+        assert isinstance(t.split_keys({"a"})[0], SubTD)
+        assert isinstance(t.flatten_keys(), SubTD)
+        assert isinstance(t.unflatten_keys(), SubTD)
+        stack = torch.stack([t, t])
+        assert isinstance(stack, SubTD)
+        assert isinstance(stack[0], SubTD)
+        assert isinstance(stack.unbind(0)[0], SubTD)
+        assert isinstance(stack.split(1)[0], SubTD)
+        assert isinstance(stack.gather(0, torch.ones((1,), dtype=torch.long)), SubTD)
+        unsqueeze = stack.unsqueeze(0)
+        assert isinstance(unsqueeze, SubTD)
+        assert isinstance(unsqueeze.transpose(1, 0), SubTD)
+        assert isinstance(unsqueeze.permute(1, 0), SubTD)
+        assert isinstance(unsqueeze.squeeze(), SubTD)
+        assert isinstance(unsqueeze.reshape(-1), SubTD)
+        assert isinstance(unsqueeze.view(-1), SubTD)
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_subtensordict_construction(self, device):
         torch.manual_seed(1)