Skip to content

Commit

Permalink
[BugFix] Fix MultiAction reset
Browse files Browse the repository at this point in the history
ghstack-source-id: a2f7bfdd7522a214430182dac65687a977b1a10d
Pull Request resolved: #2789
  • Loading branch information
kurtamohler authored and vmoens committed Feb 18, 2025
1 parent 03d6586 commit 76aa9bc
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 24 deletions.
40 changes: 20 additions & 20 deletions torchrl/data/map/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,26 +86,26 @@ class QueryModule(TensorDictModuleBase):
providing the ``clone`` argument to the forward method.
Defaults to ``False``.
Examples:
>>> query_module = QueryModule(
... in_keys=["key1", "key2"],
... index_key="index",
... hash_module=SipHash(),
... )
>>> query = TensorDict(
... {
... "key1": torch.Tensor([[1], [1], [1], [2]]),
... "key2": torch.Tensor([[3], [3], [2], [3]]),
... "other": torch.randn(4),
... },
... batch_size=(4,),
... )
>>> res = query_module(query)
>>> # The first two pairs of key1 and key2 match
>>> assert res["index"][0] == res["index"][1]
>>> # The last three pairs of key1 and key2 have at least one mismatching value
>>> assert res["index"][1] != res["index"][2]
>>> assert res["index"][2] != res["index"][3]
Examples:
>>> query_module = QueryModule(
... in_keys=["key1", "key2"],
... index_key="index",
... hash_module=SipHash(),
... )
>>> query = TensorDict(
... {
... "key1": torch.Tensor([[1], [1], [1], [2]]),
... "key2": torch.Tensor([[3], [3], [2], [3]]),
... "other": torch.randn(4),
... },
... batch_size=(4,),
... )
>>> res = query_module(query)
>>> # The first two pairs of key1 and key2 match
>>> assert res["index"][0] == res["index"][1]
>>> # The last three pairs of key1 and key2 have at least one mismatching value
>>> assert res["index"][1] != res["index"][2]
>>> assert res["index"][2] != res["index"][3]
"""

Expand Down
19 changes: 15 additions & 4 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class Transform(nn.Module):
"""

invertible = False
enable_inv_on_reset = False

def __init__(
self,
Expand Down Expand Up @@ -293,6 +294,13 @@ def _reset(
"""Resets a transform if it is stateful."""
return tensordict_reset

def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Inverts the input to :meth:`TransformedEnv._reset`, if needed."""
if self.enable_inv_on_reset:
with _set_missing_tolerance(self, True):
tensordict = self.inv(tensordict)
return tensordict

def init(self, tensordict) -> None:
    """No-op initialisation hook; this implementation ignores ``tensordict``."""
    return None

Expand Down Expand Up @@ -1018,10 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
tensordict = tensordict.select(
*self.reset_keys, *self.state_spec.keys(True, True), strict=False
)
# Inputs might be transformed, so need to apply inverse transform
# before passing to the env reset function.
with _set_missing_tolerance(self.transform, True):
tensordict = self.transform.inv(tensordict)
tensordict = self.transform._reset_env_preprocess(tensordict)
tensordict_reset = self.base_env._reset(tensordict, **kwargs)
if tensordict is None:
# make sure all transforms see a source tensordict
Expand Down Expand Up @@ -1369,6 +1374,11 @@ def _reset(
tensordict_reset = t._reset(tensordict, tensordict_reset)
return tensordict_reset

def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase:
for t in reversed(self.transforms):
tensordict = t._reset_env_preprocess(tensordict)
return tensordict

def init(self, tensordict: TensorDictBase) -> None:
    """Forward ``init`` to each child transform, in registration order.

    Args:
        tensordict: passed unchanged to every transform's ``init``.
    """
    for transform in self.transforms:
        transform.init(tensordict)
Expand Down Expand Up @@ -4725,6 +4735,7 @@ class UnaryTransform(Transform):
[torchrl][INFO] check_env_specs succeeded!
"""
enable_inv_on_reset = True

def __init__(
self,
Expand Down

0 comments on commit 76aa9bc

Please sign in to comment.