[Feature] Subclass conservation in td ops
ghstack-source-id: 83e79abda6a4bb6839d99240052323380981855c
Pull Request resolved: #1186
vmoens committed Jan 20, 2025
1 parent bbf773b commit 070ca61
Showing 5 changed files with 90 additions and 34 deletions.
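As a quick illustration before the per-file diffs: the change routes construction through the calling class's _new_unsafe, so TensorDict subclasses keep their type across common ops. A minimal sketch mirroring the new test_subclassing test added below (SubTD is a hypothetical subclass used only for illustration):

    import torch
    from tensordict import TensorDict

    class SubTD(TensorDict):  # hypothetical subclass, illustration only
        ...

    t = SubTD(a=torch.randn(3))
    # With this change, ops that rebuild the tensordict are expected to
    # return SubTD rather than a plain TensorDict.
    assert isinstance(t + t, SubTD)
    assert isinstance(t.clone(), SubTD)
    assert isinstance(t.exclude("a"), SubTD)
    assert isinstance(torch.stack([t, t]), SubTD)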
28 changes: 14 additions & 14 deletions tensordict/_td.py
@@ -344,7 +344,7 @@ def _new_unsafe(
if source: # faster than calling items
for key, value in source.items():
if nested and isinstance(value, dict):
- value = TensorDict._new_unsafe(
+ value = cls._new_unsafe(
source=value,
batch_size=self._batch_size,
device=self._device,
@@ -374,7 +374,7 @@ def from_module(
filter_empty=filter_empty,
)
if result is None:
- result = TensorDict._new_unsafe({}, batch_size=torch.Size(()))
+ result = cls._new_unsafe({}, batch_size=torch.Size(()))
if lock:
result.lock_()
return result
@@ -419,7 +419,7 @@ def _from_module(
destination = hook_result
if not filter_empty or destination:
destination_set = True
- destination = TensorDict._new_unsafe(destination, batch_size=torch.Size(()))
+ destination = cls._new_unsafe(destination, batch_size=torch.Size(()))
else:
destination_set = False
for name, submodule in module._modules.items():
@@ -433,7 +433,7 @@
)
if subtd is not None:
if not destination_set:
- destination = TensorDict._new_unsafe(batch_size=torch.Size(()))
+ destination = cls._new_unsafe(batch_size=torch.Size(()))
destination_set = True
destination._set_str(
name, subtd, validated=True, inplace=False, non_blocking=False
@@ -610,7 +610,7 @@ def _quick_set(swap_dict, swap_td):
_quick_set(_swap, swap_dest)
return swap_dest
else:
- return TensorDict._new_unsafe(_swap, batch_size=torch.Size(()))
+ return self._new_unsafe(_swap, batch_size=torch.Size(()))

@_maybe_broadcast_other("__ne__")
def __ne__(self, other: Any) -> T | bool:
@@ -1479,7 +1479,7 @@ def _add_batch_dim_wrapper(key, value):
return value
return _add_batch_dim(value, in_dim, vmap_level)

- out = TensorDict._new_unsafe(
+ out = self._new_unsafe(
{key: _add_batch_dim_wrapper(key, value) for key, value in td.items()},
batch_size=torch.Size(
[b for i, b in enumerate(td.batch_size) if i != in_dim]
@@ -1613,7 +1613,7 @@ def _check_for_invalid_index(index):
)
else:
source[key] = _get_item(item, index)
- result = TensorDict._new_unsafe(
+ result = self._new_unsafe(
source=source,
batch_size=batch_size,
device=self.device,
@@ -1694,7 +1694,7 @@ def empty(
is_shared=is_shared,
is_memmap=is_memmap,
):
- result = TensorDict._new_unsafe(
+ result = self._new_unsafe(
{}, batch_size=batch_size, names=names, device=device
)
result._is_shared = is_shared
@@ -3231,7 +3231,7 @@ def _clone(self, recurse: bool = True) -> T:
if recurse and self.device is not None:
return self._clone_recurse()

- result = TensorDict._new_unsafe(
+ result = self._new_unsafe(
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
device=self.device,
@@ -3248,7 +3248,7 @@ def contiguous(self) -> T:
source = {key: value.contiguous() for key, value in self.items()}
batch_size = self.batch_size
device = self.device
- out = TensorDict._new_unsafe(
+ out = self._new_unsafe(
source=source,
batch_size=batch_size,
device=device,
@@ -3260,7 +3260,7 @@ def empty(
self, recurse=False, *, batch_size=None, device=NO_DEFAULT, names=NO_DEFAULT
) -> T:
if not recurse:
- return TensorDict._new_unsafe(
+ return self._new_unsafe(
device=self._device if device is NO_DEFAULT else device,
batch_size=(
self._batch_size if batch_size is None else torch.Size(batch_size)
@@ -3309,7 +3309,7 @@ def _select(
*val, strict=strict, inplace=inplace, set_shared=set_shared
)

- result = TensorDict._new_unsafe(
+ result = self._new_unsafe(
device=self.device,
batch_size=self.batch_size,
source=source,
@@ -3358,7 +3358,7 @@ def _exclude(
_tensordict[key] = val
if inplace:
return self
- result = TensorDict._new_unsafe(
+ result = self._new_unsafe(
_tensordict,
batch_size=self.batch_size,
device=self.device,
@@ -4059,7 +4059,7 @@ def is_contiguous(self) -> bool:
return all(value.is_contiguous() for value in self.values())

def contiguous(self) -> T:
- return TensorDict._new_unsafe(
+ return self._new_unsafe(
batch_size=self.batch_size,
source={key: value.contiguous() for key, value in self.items()},
device=self.device,
32 changes: 21 additions & 11 deletions tensordict/_torch_func.py
@@ -23,11 +23,11 @@
from tensordict.utils import (
_check_keys,
_ErrorInteceptor,
+ _is_tensorclass,
_pass_through,
_shape,
_zip_strict,
DeviceType,
- is_tensorclass,
lazy_legacy,
set_lazy_legacy,
)
@@ -138,12 +138,12 @@ def _gather_tensor(tensor, dest_container=None, dest_key=None):
return out

if out is None:
- if len(index.shape) == input.ndim and input._has_names():
- names = input.names
+ if len(index.shape) == input.ndim:
+ names = input._maybe_names()
else:
names = None
device = input.device
- return TensorDict(
+ return type(input)._new_unsafe(
{
key: _gather_tensor(value)
for key, value in input.items(is_leaf=_is_leaf_nontensor)
@@ -300,6 +300,7 @@ def _cat(
raise RuntimeError("list_of_tensordicts cannot be empty")

batch_size = list(list_of_tensordicts[0].batch_size)
+ tdtype = type(list_of_tensordicts[0])
if dim < 0:
dim = len(batch_size) + dim
if dim >= len(batch_size):
@@ -334,9 +335,13 @@
names = None
if list_of_tensordicts[0]._has_names():
names = list_of_tensordicts[0].names
- return TensorDict._new_unsafe(
- out, device=device, batch_size=batch_size, names=names
- )
+ # if we have a TD subclass, use _new_unsafe bc we know it exists. Otherwise, use
+ # TensorDict's one
+ if issubclass(tdtype, TensorDict):
+ clz = tdtype
+ else:
+ clz = TensorDict
+ return clz._new_unsafe(out, device=device, batch_size=batch_size, names=names)
else:
if out.batch_size != batch_size:
raise RuntimeError(
@@ -453,14 +458,19 @@ def _stack(
raise RuntimeError("list_of_tensordicts cannot be empty")
if maybe_dense_stack is None:
maybe_dense_stack = lazy_legacy()
- is_tc = any(is_tensorclass(td) for td in list_of_tensordicts)
+ td_types = [type(td) for td in list_of_tensordicts]
+ is_tc = any(_is_tensorclass(td_type) for td_type in td_types)
if all(_pass_through(td) for td in list_of_tensordicts):
return type(list_of_tensordicts[0])._stack_non_tensor(
list_of_tensordicts, dim=dim
)
if is_tc:
- tc_type = type(list_of_tensordicts[0])
list_of_tensordicts = [tc._tensordict for tc in list_of_tensordicts]
+ clz = type(list_of_tensordicts[0])
+ elif issubclass(td_types[0], TensorDict):
+ clz = td_types[0]
+ else:
+ clz = TensorDict

batch_size = list_of_tensordicts[0].batch_size
if dim < 0:
@@ -617,15 +627,15 @@ def stack_fn(key, values, is_not_init, is_tensor):
for key, (values, is_not_init, is_tensor) in out.items()
}

- result = TensorDict._new_unsafe(
+ result = clz._new_unsafe(
out,
batch_size=LazyStackedTensorDict._compute_batch_size(
batch_size, dim, len(list_of_tensordicts)
),
device=device,
)
if is_tc:
- return tc_type._from_tensordict(result)
+ return td_types[0]._from_tensordict(result)
return result
else:
out = LazyStackedTensorDict(
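The _cat, _stack and _gather changes above pick the result class from the inputs: a TensorDict subclass is rebuilt with its own _new_unsafe, anything else falls back to TensorDict, and tensorclasses are re-wrapped through _from_tensordict. A rough sketch of the intended effect, again using a hypothetical SubTD subclass:

    import torch
    from tensordict import TensorDict

    class SubTD(TensorDict):  # hypothetical subclass, illustration only
        ...

    a = SubTD(x=torch.randn(2, 3), batch_size=[2])
    b = SubTD(x=torch.randn(2, 3), batch_size=[2])
    # cat/stack now look up the first input's type; since it is a TensorDict
    # subclass, the result is built with SubTD._new_unsafe.
    assert isinstance(torch.cat([a, b], dim=0), SubTD)
    assert isinstance(torch.stack([a, b], dim=0), SubTD)
    # gather rebuilds through type(input)._new_unsafe as well.
    assert isinstance(torch.gather(a, 0, torch.tensor([1, 0])), SubTD)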
10 changes: 10 additions & 0 deletions tensordict/base.py
@@ -273,6 +273,16 @@ class TensorDictBase(MutableMapping):
_memmap_prefix = None
_stream: torch.cuda.Stream | None = None

+ @classmethod
+ def _new_unsafe(cls, *args, **kwargs):
+ # This to make sure all TensorDictBase subclasses have a proper fallback if they don't have a _new_unsafe
+ # In other words, only TensorDict subclasses will have their type preserved, others will become TensorDict
+ # instances (note that TensorDictBase should not be directly subclassed outside of this codebase, as it is
+ # highly abstract).
+ from tensordict._td import TensorDict
+
+ return TensorDict._new_unsafe(*args, **kwargs)
+
def __bool__(self) -> bool:
raise RuntimeError("Converting a tensordict to boolean value is not permitted")

22 changes: 13 additions & 9 deletions tensordict/nn/params.py
@@ -390,7 +390,7 @@ def _new_unsafe(
cls,
parameters: TensorDictBase,
*,
- no_convert=False,
+ no_convert=None,
lock: bool = False,
params: dict | None = None,
buffers: dict | None = None,
@@ -399,24 +399,28 @@
if is_compiling():
return TensorDictParams(parameters, no_convert="skip", lock=lock)

- self = TensorDictParams.__new__(cls)
- nn.Module.__init__(self)
-
if parameters is None:
parameters = kwargs
- elif kwargs:
- raise TypeError(
- f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
- )

if isinstance(parameters, dict):
- parameters = TensorDict._new_unsafe(parameters)
+ parameters = TensorDict._new_unsafe(parameters, **kwargs)
+ if no_convert is None:
+ # Then _new_unsafe is called from somewhere that doesn't know
+ # that it's a TDParams and we return a TensorDict (eg, torch.gather)
+ return parameters
elif isinstance(parameters, TensorDictParams):
+ if kwargs:
+ raise TypeError(
+ f"parameters cannot be passed along with extra keyword arguments, but got {kwargs.keys()} extra args."
+ )
params = dict(parameters._parameters)
buffers = dict(parameters._buffers)
parameters = parameters._param_td
no_convert = "skip"

+ self = TensorDictParams.__new__(cls)
+ nn.Module.__init__(self)
+
self._param_td = parameters
self.no_convert = no_convert
if no_convert != "skip":
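The no_convert default switch above (False to None) lets _new_unsafe detect when it is reached from a generic code path that does not know it started from a TensorDictParams; per the inline comment, such calls (e.g. torch.gather) now hand back a plain TensorDict. A hypothetical sketch of that behavior, not taken from the test suite:

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams

    params = TensorDictParams(TensorDict(w=torch.randn(4, 3), batch_size=[4]))
    # torch.gather rebuilds the result through _new_unsafe without knowing it
    # started from a TensorDictParams; with no_convert left at None, the call
    # is expected to return a plain TensorDict rather than a TensorDictParams.
    out = torch.gather(params, 0, torch.tensor([0, 2]))
    assert isinstance(out, TensorDict)
    assert not isinstance(out, TensorDictParams)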
32 changes: 32 additions & 0 deletions test/test_tensordict.py
@@ -2709,6 +2709,38 @@ def test_reduction_feature_full(self, reduction):
assert getattr(td, reduction)(reduce=True, dim="feature").shape == (3, 4)
assert getattr(td, reduction)(reduce=True, dim=1).shape == (3, 5)

+ def test_subclassing(self):
+ class SubTD(TensorDict): ...
+
+ t = SubTD(a=torch.randn(3))
+ assert isinstance(t + t, SubTD)
+ assert isinstance(t / 2, SubTD)
+ assert isinstance(2 / t, SubTD)
+ assert isinstance(t.to(torch.float), SubTD)
+ assert isinstance(t.to("cpu"), SubTD)
+ assert isinstance(torch.zeros_like(t), SubTD)
+ assert isinstance(t.copy(), SubTD)
+ assert isinstance(t.clone(), SubTD)
+ assert isinstance(t.empty(), SubTD)
+ assert isinstance(t.select(), SubTD)
+ assert isinstance(t.exclude("a"), SubTD)
+ assert isinstance(t.split_keys({"a"})[0], SubTD)
+ assert isinstance(t.flatten_keys(), SubTD)
+ assert isinstance(t.unflatten_keys(), SubTD)
+ stack = torch.stack([t, t])
+ assert isinstance(stack, SubTD)
+ assert isinstance(stack[0], SubTD)
+ assert isinstance(stack.unbind(0)[0], SubTD)
+ assert isinstance(stack.split(1)[0], SubTD)
+ assert isinstance(stack.gather(0, torch.ones((1,), dtype=torch.long)), SubTD)
+ unsqueeze = stack.unsqueeze(0)
+ assert isinstance(unsqueeze, SubTD)
+ assert isinstance(unsqueeze.transpose(1, 0), SubTD)
+ assert isinstance(unsqueeze.permute(1, 0), SubTD)
+ assert isinstance(unsqueeze.squeeze(), SubTD)
+ assert isinstance(unsqueeze.reshape(-1), SubTD)
+ assert isinstance(unsqueeze.view(-1), SubTD)
+
@pytest.mark.parametrize("device", get_available_devices())
def test_subtensordict_construction(self, device):
torch.manual_seed(1)
