amend
vmoens committed May 26, 2024
1 parent 1e5d8ed commit 6c7a6d0
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions torchrl/envs/batched_envs.py
@@ -301,7 +301,7 @@ def __init__(
         serial_for_single: bool = False,
         non_blocking: bool = False,
         mp_start_method: str = None,
-        use_buffers: bool=None,
+        use_buffers: bool = None,
     ):
         super().__init__(device=device)
         self.serial_for_single = serial_for_single
@@ -477,8 +477,10 @@ def _get_metadata(
         if self._use_buffers is not False:
             _use_buffers = not self.meta_data.has_dynamic_specs
             if self._use_buffers and not _use_buffers:
-                warn("A value of use_buffers=True was passed but this is incompatible "
-                     "with the list of environments provided. Turning use_buffers to False.")
+                warn(
+                    "A value of use_buffers=True was passed but this is incompatible "
+                    "with the list of environments provided. Turning use_buffers to False."
+                )
             self._use_buffers = _use_buffers
         if self.share_individual_td is None:
             self.share_individual_td = False
@@ -503,8 +505,10 @@ def _get_metadata(
                 not metadata.has_dynamic_specs for metadata in self.meta_data
             )
             if self._use_buffers and not _use_buffers:
-                warn("A value of use_buffers=True was passed but this is incompatible "
-                     "with the list of environments provided. Turning use_buffers to False.")
+                warn(
+                    "A value of use_buffers=True was passed but this is incompatible "
+                    "with the list of environments provided. Turning use_buffers to False."
+                )
             self._use_buffers = _use_buffers

         self._set_properties()
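
Both hunks above apply the same fallback: an explicit use_buffers=True is downgraded
to False, with a warning, whenever a sub-environment reports dynamic specs. A minimal
standalone sketch of that decision logic (the helper name resolve_use_buffers and its
signature are illustrative, not part of torchrl's API; requested is True, False, or
None for "auto"):

from warnings import warn

def resolve_use_buffers(requested, has_dynamic_specs):
    # Illustrative reduction of the coercion performed in _get_metadata.
    if requested is False:
        # An explicit False is always honored.
        return False
    allowed = not has_dynamic_specs  # preallocated buffers need static shapes
    if requested and not allowed:
        warn(
            "A value of use_buffers=True was passed but this is incompatible "
            "with the list of environments provided. Turning use_buffers to False."
        )
    # Both True and None (auto) resolve to whatever the specs allow.
    return allowed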
@@ -1383,26 +1387,27 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
     def _step_and_maybe_reset_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
-
-        for i, _data in enumerate(tensordict.unbind(0)):
-            self.parent_channels[i].send(("step_and_maybe_reset", _data))
-
-        results = [None] * self.num_workers
-
-        consumed_indices = []
-        events = set(range(self.num_workers))
-        while len(consumed_indices) < self.num_workers:
-            for i in list(events):
-                if self._events[i].is_set():
-                    results[i] = self.parent_channels[i].recv()
-                    self._events[i].clear()
-                    consumed_indices.append(i)
-                    events.discard(i)
-
-        out_next, out_root = zip(*(future for future in results))
-        return TensorDict.maybe_dense_stack(
-            out_next
-        ), TensorDict.maybe_dense_stack(out_root)
+        return super().step_and_maybe_reset(tensordict)
+
+        # for i, _data in enumerate(tensordict.unbind(0)):
+        #     self.parent_channels[i].send(("step_and_maybe_reset", _data))
+        #
+        # results = [None] * self.num_workers
+        #
+        # consumed_indices = []
+        # events = set(range(self.num_workers))
+        # while len(consumed_indices) < self.num_workers:
+        #     for i in list(events):
+        #         if self._events[i].is_set():
+        #             results[i] = self.parent_channels[i].recv()
+        #             self._events[i].clear()
+        #             consumed_indices.append(i)
+        #             events.discard(i)
+        #
+        # out_next, out_root = zip(*(future for future in results))
+        # return TensorDict.maybe_dense_stack(
+        #     out_next
+        # ), TensorDict.maybe_dense_stack(out_root)

     @torch.no_grad()
     @_check_start
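
The replaced body, kept above as comments, fanned the input out over the worker pipes
and then busy-polled one multiprocessing.Event per worker, draining the matching pipe
once its event fired; the new one-liner defers to the base class implementation
instead. A self-contained sketch of that gather pattern (the function and argument
names are illustrative; events holds multiprocessing.Event objects and channels the
parent ends of the pipes):

def gather_results(events, channels):
    # Busy-poll per-worker events; receive each worker's payload exactly once.
    results = [None] * len(events)
    pending = set(range(len(events)))
    while pending:
        for i in list(pending):
            if events[i].is_set():
                results[i] = channels[i].recv()
                events[i].clear()  # re-arm the event for the next round
                pending.discard(i)
    return results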
@@ -1578,6 +1583,7 @@ def _reset_no_buffers(
         for i, channel in enumerate(self.parent_channels):
             if not needs_resetting[i]:
                 out_tds.append(None)
+                continue
             self._events[i].wait()
             td = channel.recv()
             out_tds.append(td)
@@ -2115,6 +2121,10 @@ def _run_worker_pipe_direct(
             # we use 'data' to pass the keys that we need to pass to reset,
             # because passing the entire buffer may have unwanted consequences
             data, reset_kwargs = data
+            if data is not None:
+                data._fast_apply(
+                    lambda x: x.clone() if x.device.type == "cuda" else x, out=data
+                )
             cur_td = env.reset(
                 tensordict=data,
                 **reset_kwargs,
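
On the parent side, the message this branch consumes is a command/payload pair,
mirroring the ("step_and_maybe_reset", _data) send visible earlier in this diff. A
hedged sketch of the sending end (send_reset, channel, td_root, and the selected key
are illustrative; the exact keys torchrl forwards to reset are not shown in this
hunk):

def send_reset(channel, td_root, reset_kwargs):
    # Pass only the keys reset needs rather than the entire buffer, per the
    # "unwanted consequences" noted in the worker-side comment above.
    payload = td_root.select("_reset", strict=False)  # illustrative key choice
    channel.send(("reset", (payload, reset_kwargs)))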
@@ -2130,7 +2140,6 @@ def _run_worker_pipe_direct(
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
-            # No need to copy here since we don't write in-place
             next_td = env._step(data)
             if event is not None:
                 event.record()
@@ -2143,15 +2152,10 @@ def _run_worker_pipe_direct(
             if not initialized:
                 raise RuntimeError("called 'init' before step")
             i += 1
-            # We must copy the root shared td here, or at least get rid of done:
-            # if we don't `td is root_shared_tensordict`
-            # which means that root_shared_tensordict will carry the content of next
-            # in the next iteration. When using StepCounter, it will look for an
-            # existing done state, find it and consider the env as done by input (not
-            # by output) of the step!
-            # Caveat: for RNN we may need some keys of the "next" TD so we pass the list
-            # through data
-            td, root_next_td = env.step_and_maybe_reset(data.clone())
+            data._fast_apply(
+                lambda x: x.clone() if x.device.type == "cuda" else x, out=data
+            )
+            td, root_next_td = env.step_and_maybe_reset(data)

             if event is not None:
                 event.record()
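
Both worker-side insertions in this commit replace a blanket data.clone() with an
in-place pass that clones only CUDA leaves, presumably so GPU tensors are detached
from storage another process may still write to, while CPU tensors in shared memory
pass through untouched. A hedged sketch of the pattern on a standalone TensorDict
(keys and shapes are illustrative; _fast_apply is the private tensordict helper the
diff itself uses, with the same out= in-place convention):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.zeros(4, 3), "done": torch.zeros(4, 1, dtype=torch.bool)},
    batch_size=[4],
)
if torch.cuda.is_available():
    td = td.to("cuda")
# Clone CUDA leaves in place so later writes cannot alias what was handed over;
# CPU tensors are returned as-is.
td._fast_apply(lambda x: x.clone() if x.device.type == "cuda" else x, out=td)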