diff --git a/test/test_env.py b/test/test_env.py
index 74b941de580..ed93315ef41 100644
--- a/test/test_env.py
+++ b/test/test_env.py
@@ -4058,6 +4058,38 @@ def test_chess_tokenized(self):
         assert "fen" in ftd["next"]
         env.check_env_specs()
 
+    @pytest.mark.parametrize("stateful", [False, True])
+    @pytest.mark.parametrize("include_san", [False, True])
+    def test_env_reset_with_hash(self, stateful, include_san):
+        """Resetting from ``fen_hash`` must reproduce the reset from ``fen``."""
+        chess_env = ChessEnv(
+            include_fen=True,
+            include_hash=True,
+            include_hash_inv=True,
+            stateful=stateful,
+            include_san=include_san,
+        )
+        # Pairs of (fen, number of legal moves expected in that position).
+        positions = (
+            ("5R1k/8/8/8/6R1/8/8/5K2 b - - 0 1", 1),
+            ("8/8/2kq4/4K3/1R3Q2/8/8/8 w - - 0 1", 2),
+            ("6R1/8/8/4rq2/3pPk2/5n2/8/2B1R2K b - e3 0 1", 2),
+        )
+        for fen, expected_moves in positions:
+            # Start from the position described by the FEN string.
+            from_fen = chess_env.reset(TensorDict({"fen": fen}))
+            assert from_fen["fen"] == fen
+            assert from_fen["action_mask"].sum() == expected_moves
+            # Return to the default starting position so that the following
+            # reset has to change the state to succeed.
+            default_td = chess_env.reset()
+            assert default_td["action_mask"].sum() == 20
+            # Reset once more from the FEN hash alone; the result must match
+            # the FEN-based reset.
+            hash_only = from_fen.select("fen_hash")
+            from_hash = chess_env.reset(hash_only)
+            assert (from_hash == from_fen).all()
+
 
 class TestCustomEnvs:
     def test_tictactoe_env(self):
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index e1016338a95..77d456c005e 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -1018,6 +1018,10 @@ def _reset(self, tensordict: Optional[TensorDictBase] = None, **kwargs):
             tensordict = tensordict.select(
                 *self.reset_keys, *self.state_spec.keys(True, True), strict=False
             )
+            # Inputs might be transformed, so need to apply inverse transform
+            # before passing to the env reset function.
+            with _set_missing_tolerance(self.transform, True):
+                tensordict = self.transform.inv(tensordict)
         tensordict_reset = self.base_env._reset(tensordict, **kwargs)
         if tensordict is None:
             # make sure all transforms see a source tensordict