Skip to content

Commit

Permalink
[Feature] flexible return type when indexing prob sequences
Browse files Browse the repository at this point in the history
ghstack-source-id: 74d28ee84d965c11c527c60b20d9123ef30007f6
Pull Request resolved: #1189
  • Loading branch information
vmoens committed Jan 21, 2025
1 parent 259c941 commit 790bef6
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
6 changes: 5 additions & 1 deletion tensordict/_contextlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,11 @@ def _reverse_squeeze(self, args, kwargs, out):

def _reverse_to_module(self, args, kwargs, out):
try:
with out.unlock_() if not is_compiling() else contextlib.nullcontext():
with (
out.unlock_()
if not is_compiling() and out is not None
else contextlib.nullcontext()
):
return self.to_module(*args, **kwargs, swap_dest=out)
except AttributeError:
# This is a bit unsafe but we assume that out won't have an unlock_() if it's not a TD
Expand Down
21 changes: 21 additions & 0 deletions tensordict/nn/probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,27 @@ def __init__(
super().__init__(*modules, partial_tolerant=partial_tolerant)
self.return_composite = return_composite

def __getitem__(self, index: int | slice | str) -> TensorDictModuleBase:
if isinstance(index, (int, str)):
return self.module.__getitem__(index)
else:
mods = self.module.__getitem__(index)
if self.return_composite and any(
isinstance(
item,
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
)
for item in mods
):
return type(self)(*mods, return_composite=self.return_composite)
elif isinstance(
mods[-1],
(ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential),
):
return type(self)(*mods)
else:
return TensorDictSequential(*mods)

_dist_sample = ProbabilisticTensorDictModule._dist_sample

@property
Expand Down
19 changes: 19 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,6 +2251,25 @@ def test_nested_keys_probabilistic_normal(self, log_prob_key):
else:
assert td_out[module.log_prob_key].shape == (3, 4, 1)

def test_index_prob_seq(self):
m0 = ProbabilisticTensorDictModule(
in_keys=["loc"], out_keys=["sample"], distribution_class=Normal
)
m1 = TensorDictModule(lambda x: x, in_keys=["other"], out_keys=["something"])
m2 = ProbabilisticTensorDictModule(
in_keys=["scale"], out_keys=["sample2"], distribution_class=Normal
)
seq = ProbabilisticTensorDictSequential(m0, m1, m2)
assert isinstance(seq[0], ProbabilisticTensorDictModule)
assert isinstance(seq[:2], TensorDictSequential)
assert not isinstance(seq[:2], ProbabilisticTensorDictSequential)
assert isinstance(seq[-2:], ProbabilisticTensorDictSequential)

seq = ProbabilisticTensorDictSequential(m0, m1, m2, return_composite=True)
assert isinstance(seq[0], ProbabilisticTensorDictModule)
assert isinstance(seq[:2], ProbabilisticTensorDictSequential)
assert isinstance(seq[-2:], ProbabilisticTensorDictSequential)


class TestEnsembleModule:
def test_init(self):
Expand Down

0 comments on commit 790bef6

Please sign in to comment.