Commit 348558c

Update

[ghstack-poisoned]

vmoens committed Feb 24, 2025
1 parent 0fd3ccc commit 348558c

Showing 1 changed file with 32 additions and 30 deletions.

torchrl/collectors/collectors.py (32 additions, 30 deletions)
@@ -768,8 +768,7 @@ def __init__(
         self.set_truncated = set_truncated
 
         self._make_shuttle()
-        if self._use_buffers:
-            self._make_final_rollout()
+        self._maybe_make_final_rollout(make_rollout=self._use_buffers)
         self._set_truncated_keys()
 
         if split_trajs is None:
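As a sketch of the pattern this hunk applies (the setup_* wrappers below are invented for illustration; only _use_buffers, _make_final_rollout and _maybe_make_final_rollout come from the diff): the buffer check moves from the call site into the method, so the method is always invoked and the make_rollout flag decides which parts of it run.

# Illustrative sketch only; not the torchrl implementation.
def setup_before(collector):
    # old call site: the whole method is skipped when buffers are disabled
    if collector._use_buffers:
        collector._make_final_rollout()


def setup_after(collector):
    # new call site: always called, the condition is pushed into the method
    collector._maybe_make_final_rollout(make_rollout=collector._use_buffers)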
@@ -806,28 +805,29 @@ def _make_shuttle(self):
             traj_ids,
         )
 
-    def _make_final_rollout(self):
-        with torch.no_grad():
-            self._final_rollout = self.env.fake_tensordict()
-
-        # If storing device is not None, we use this to cast the storage.
-        # If it is None and the env and policy are on the same device,
-        # the storing device is already the same as those, so we don't need
-        # to consider this use case.
-        # In all other cases, we can't really put a device on the storage,
-        # since at least one data source has a device that is not clear.
-        if self.storing_device:
-            self._final_rollout = self._final_rollout.to(
-                self.storing_device, non_blocking=True
-            )
-        else:
-            # erase all devices
-            self._final_rollout.clear_device_()
+    def _maybe_make_final_rollout(self, make_rollout: bool):
+        if make_rollout:
+            with torch.no_grad():
+                self._final_rollout = self.env.fake_tensordict()
+
+            # If storing device is not None, we use this to cast the storage.
+            # If it is None and the env and policy are on the same device,
+            # the storing device is already the same as those, so we don't need
+            # to consider this use case.
+            # In all other cases, we can't really put a device on the storage,
+            # since at least one data source has a device that is not clear.
+            if self.storing_device:
+                self._final_rollout = self._final_rollout.to(
+                    self.storing_device, non_blocking=True
+                )
+            else:
+                # erase all devices
+                self._final_rollout.clear_device_()
 
         # If the policy has a valid spec, we use it
         self._policy_output_keys = set()
-        if (
-            hasattr(self.policy, "spec")
+        if (make_rollout
+            and hasattr(self.policy, "spec")
             and self.policy.spec is not None
             and all(v is not None for v in self.policy.spec.values(True, True))
         ):
@@ -848,12 +848,13 @@ def _make_final_rollout(self):
                     self._final_rollout.set(key, spec.zero())
 
         else:
-            # otherwise, we perform a small number of steps with the policy to
-            # determine the relevant keys with which to pre-populate _final_rollout.
-            # This is the safest thing to do if the spec has None fields or if there is
-            # no spec at all.
-            # See #505 for additional context.
-            self._final_rollout.update(self._shuttle.copy())
+            if make_rollout:
+                # otherwise, we perform a small number of steps with the policy to
+                # determine the relevant keys with which to pre-populate _final_rollout.
+                # This is the safest thing to do if the spec has None fields or if there is
+                # no spec at all.
+                # See #505 for additional context.
+                self._final_rollout.update(self._shuttle.copy())
             with torch.no_grad():
                 policy_input = self._shuttle.copy()
                 if self.policy_device:
Expand Down Expand Up @@ -911,9 +912,10 @@ def filter_policy(name, value_output, value_input, value_input_clone):
set(filtered_policy_output.keys(True, True))
)
)
self._final_rollout.update(
policy_output.select(*self._policy_output_keys)
)
if make_rollout:
self._final_rollout.update(
policy_output.select(*self._policy_output_keys)
)
del filtered_policy_output, policy_output, policy_input

_env_output_keys = []
Expand Down
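Read end to end, the hunks above only gate the _final_rollout writes behind make_rollout; the _policy_output_keys bookkeeping now runs in both cases. A minimal toy sketch of that behavior, assuming invented key names and a plain dict standing in for the fake tensordict:

# Toy illustration of the gating introduced above; not torchrl code.
class ToyCollector:
    def __init__(self, use_buffers: bool):
        self._final_rollout = None
        self._maybe_make_final_rollout(make_rollout=use_buffers)

    def _maybe_make_final_rollout(self, make_rollout: bool):
        if make_rollout:
            # stands in for env.fake_tensordict() plus device handling
            self._final_rollout = {}
        # output-key bookkeeping now runs whether or not a buffer is built
        self._policy_output_keys = {"action", "log_prob"}
        if make_rollout:
            self._final_rollout.update({k: None for k in self._policy_output_keys})


print(ToyCollector(use_buffers=False)._policy_output_keys)  # keys exist, no buffer
print(ToyCollector(use_buffers=True)._final_rollout)        # buffer pre-populated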
