From 726e95955009c73dc0242424182222e59a9056d7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 7 Jun 2024 11:29:31 +0100 Subject: [PATCH] [BugFix] Fix done/terminated computation in slice samplers (#2213) --- test/test_rb.py | 21 ++++++++---- torchrl/data/replay_buffers/samplers.py | 43 +++++++++++++++++-------- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index af732df1568..9f1e198fcdf 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -32,6 +32,8 @@ ) from torch import multiprocessing as mp from torch.utils._pytree import tree_flatten, tree_map + +from torchrl._utils import _replace_last from torchrl.collectors import RandomPolicy, SyncDataCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( @@ -1974,7 +1976,7 @@ def test_slice_sampler( "count": torch.arange(100), "other": torch.randn((20, 50)).expand(100, 20, 50), done_key: done, - "terminated": done, + _replace_last(done_key, "terminated"): done, }, [100], device=device, @@ -1997,6 +1999,7 @@ def test_slice_sampler( end_key=done_key, slice_len=slice_len, strict_length=strict_length, + truncated_key=_replace_last(done_key, "truncated"), ) index = torch.arange(0, num_steps, 1) sampler.extend(index) @@ -2007,6 +2010,7 @@ def test_slice_sampler( end_key=done_key, slice_len=slice_len, strict_length=strict_length, + truncated_key=_replace_last(done_key, "truncated"), ) if slice_len is not None: num_slices = batch_size // slice_len @@ -2037,11 +2041,14 @@ def test_slice_sampler( ) count_unique = count_unique.union(samples.get("count").view(-1).tolist()) - truncated = info[("next", "truncated")] - terminated = info[("next", "terminated")] + truncated = info[_replace_last(done_key, "truncated")] + terminated = info[_replace_last(done_key, "terminated")] assert (truncated | terminated).view(num_slices, -1)[:, -1].all() - assert (terminated == samples["terminated"].view_as(terminated)).all() - done = info[("next", "done")] + assert ( + terminated + == samples[_replace_last(done_key, "terminated")].view_as(terminated) + ).all() + done = info[done_key] assert done.view(num_slices, -1)[:, -1].all() if len(count_unique) == 100: @@ -2197,10 +2204,10 @@ def test_slice_sampler_without_replacement( trajs_unique_id = trajs_unique_id.union( cur_episodes, ) - done = info[("next", "done")] - assert done.view(num_slices, -1)[:, -1].all() done_recon = info[("next", "truncated")] | info[("next", "terminated")] assert done_recon.view(num_slices, -1)[:, -1].all() + done = info[("next", "done")] + assert done.view(num_slices, -1)[:, -1].all() def test_slice_sampler_left_right(self): torch.manual_seed(0) diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 0b83c71ddb6..8d24c9c1409 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -1054,6 +1054,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] seq_length, num_slices, storage_length=storage_length, + storage=storage, ) def _sample_slices( @@ -1065,6 +1066,8 @@ def _sample_slices( num_slices: int, storage_length: int, traj_idx: torch.Tensor | None = None, + *, + storage, ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]: # start_idx and stop_idx are 2d tensors organized like a non-zero @@ -1130,6 +1133,7 @@ def get_traj_idx(maxval): seq_length=seq_length, storage_length=storage_length, traj_idx=traj_idx, + storage=storage, ) def _get_index( @@ -1141,6 +1145,8 @@ def _get_index( num_slices: int, storage_length: int, traj_idx: torch.Tensor | None = None, + *, + storage, ) -> Tuple[torch.Tensor, dict]: # end_point is the last possible index for start last_indexable_start = lengths[traj_idx] - seq_length + 1 @@ -1209,13 +1215,17 @@ def _get_index( truncated.view(num_slices, -1)[:, -1] = 1 else: truncated[seq_length.cumsum(0) - 1] = 1 - terminated = ( - (index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1)) - .any(0) - .unsqueeze(1) - ) - done = terminated | truncated - return index.to(torch.long).unbind(-1), { + index = index.to(torch.long).unbind(-1) + st_index = storage[index] + try: + done = st_index[done_key] | truncated + except KeyError: + done = truncated.clone() + try: + terminated = st_index[terminated_key] + except KeyError: + terminated = torch.zeros_like(truncated) + return index, { truncated_key: truncated, done_key: done, terminated_key: terminated, @@ -1454,6 +1464,7 @@ def tuple_to_tensor(traj_idx, lengths=lengths): num_slices, storage_length, traj_idx=tuple_to_tensor(indices), + storage=storage, ) return idx, info @@ -1763,12 +1774,16 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] truncated.view(num_slices, -1)[:, -1] = 1 else: truncated[seq_length.cumsum(0) - 1] = 1 - terminated = ( - (index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1)) - .any(0) - .unsqueeze(1) - ) - done = terminated | truncated + index = index.to(torch.long).unbind(-1) + st_index = storage[index] + try: + done = st_index[done_key] | truncated + except KeyError: + done = truncated.clone() + try: + terminated = st_index[terminated_key] + except KeyError: + terminated = torch.zeros_like(truncated) info.update( { truncated_key: truncated, @@ -1776,7 +1791,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict] terminated_key: terminated, } ) - return index.to(torch.long).unbind(-1), info + return index, info return index.to(torch.long).unbind(-1), info def _empty(self):