From 8ae33aaa4e01574bd24ecf9e5bca68d2e2e54d9a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 24 Feb 2025 18:22:36 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- torchrl/collectors/collectors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index a83d1419122..c5d54566c08 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -846,7 +846,8 @@ def _maybe_make_final_rollout(self, make_rollout: bool): if key in self._final_rollout.keys(True): continue self._final_rollout.set(key, spec.zero()) - + elif not make_rollout and hasattr(self.policy, "out_keys") and self.policy.out_keys: + self._policy_output_keys = list(self.policy.out_keys) else: if make_rollout: # otherwise, we perform a small number of steps with the policy to