From 76aa9bc0c3cf1eb3fd2ab85bf156b9dee445dd8c Mon Sep 17 00:00:00 2001 From: Kurt Mohler Date: Tue, 18 Feb 2025 09:23:03 +0000 Subject: [PATCH] [BugFix] Fix MultiAction reset ghstack-source-id: a2f7bfdd7522a214430182dac65687a977b1a10d Pull Request resolved: https://github.com/pytorch/rl/pull/2789 --- torchrl/data/map/query.py | 40 +++++++++++++-------------- torchrl/envs/transforms/transforms.py | 19 ++++++++++--- 2 files changed, 35 insertions(+), 24 deletions(-) diff --git a/torchrl/data/map/query.py b/torchrl/data/map/query.py index 6c4c2f9e0e2..3eca179cf56 100644 --- a/torchrl/data/map/query.py +++ b/torchrl/data/map/query.py @@ -86,26 +86,26 @@ class QueryModule(TensorDictModuleBase): providing the ``clone`` argument to the forward method. Defaults to ``False``. - Examples: - >>> query_module = QueryModule( - ... in_keys=["key1", "key2"], - ... index_key="index", - ... hash_module=SipHash(), - ... ) - >>> query = TensorDict( - ... { - ... "key1": torch.Tensor([[1], [1], [1], [2]]), - ... "key2": torch.Tensor([[3], [3], [2], [3]]), - ... "other": torch.randn(4), - ... }, - ... batch_size=(4,), - ... ) - >>> res = query_module(query) - >>> # The first two pairs of key1 and key2 match - >>> assert res["index"][0] == res["index"][1] - >>> # The last three pairs of key1 and key2 have at least one mismatching value - >>> assert res["index"][1] != res["index"][2] - >>> assert res["index"][2] != res["index"][3] + Examples: + >>> query_module = QueryModule( + ... in_keys=["key1", "key2"], + ... index_key="index", + ... hash_module=SipHash(), + ... ) + >>> query = TensorDict( + ... { + ... "key1": torch.Tensor([[1], [1], [1], [2]]), + ... "key2": torch.Tensor([[3], [3], [2], [3]]), + ... "other": torch.randn(4), + ... }, + ... batch_size=(4,), + ... ) + >>> res = query_module(query) + >>> # The first two pairs of key1 and key2 match + >>> assert res["index"][0] == res["index"][1] + >>> # The last three pairs of key1 and key2 have at least one mismatching value + >>> assert res["index"][1] != res["index"][2] + >>> assert res["index"][2] != res["index"][3] """ diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 77d456c005e..84eb873cda0 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -209,6 +209,7 @@ class Transform(nn.Module): """ invertible = False + enable_inv_on_reset = False def __init__( self, @@ -293,6 +294,13 @@ def _reset( """Resets a transform if it is stateful.""" return tensordict_reset + def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: + """Inverts the input to :meth:`TransformedEnv._reset`, if needed.""" + if self.enable_inv_on_reset: + with _set_missing_tolerance(self, True): + tensordict = self.inv(tensordict) + return tensordict + def init(self, tensordict) -> None: pass @@ -1018,10 +1026,7 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs): tensordict = tensordict.select( *self.reset_keys, *self.state_spec.keys(True, True), strict=False ) - # Inputs might be transformed, so need to apply inverse transform - # before passing to the env reset function. - with _set_missing_tolerance(self.transform, True): - tensordict = self.transform.inv(tensordict) + tensordict = self.transform._reset_env_preprocess(tensordict) tensordict_reset = self.base_env._reset(tensordict, **kwargs) if tensordict is None: # make sure all transforms see a source tensordict @@ -1369,6 +1374,11 @@ def _reset( tensordict_reset = t._reset(tensordict, tensordict_reset) return tensordict_reset + def _reset_env_preprocess(self, tensordict: TensorDictBase) -> TensorDictBase: + for t in reversed(self.transforms): + tensordict = t._reset_env_preprocess(tensordict) + return tensordict + def init(self, tensordict: TensorDictBase) -> None: for t in self.transforms: t.init(tensordict) @@ -4725,6 +4735,7 @@ class UnaryTransform(Transform): [torchrl][INFO] check_env_specs succeeded! """ + enable_inv_on_reset = True def __init__( self,