diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index a83d1419122..c5d54566c08 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -846,7 +846,8 @@ def _maybe_make_final_rollout(self, make_rollout: bool): if key in self._final_rollout.keys(True): continue self._final_rollout.set(key, spec.zero()) - + elif not make_rollout and hasattr(self.policy, "out_keys") and self.policy.out_keys: + self._policy_output_keys = list(self.policy.out_keys) else: if make_rollout: # otherwise, we perform a small number of steps with the policy to