From 781a5b29dad88cabd2ed761392eef141e8a96d7f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Jan 2024 08:50:56 +0000 Subject: [PATCH] [BugFix] Fix sampling of last item in SliceSampler (#1774) --- test/test_rb.py | 12 +++++++++++- torchrl/data/replay_buffers/samplers.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 8a5e191e5e4..206a54c5da8 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -1595,6 +1595,7 @@ def test_slice_sampler( "obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5), "act": torch.randn((20,)).expand(100, 20), "steps": steps, + "count": torch.arange(100), "other": torch.randn((20, 50)).expand(100, 20, 50), done_key: done, }, @@ -1621,7 +1622,8 @@ def test_slice_sampler( num_slices = batch_size // slice_len trajs_unique_id = set() too_short = False - for _ in range(5): + count_unique = set() + for _ in range(10): index, info = sampler.sample(storage, batch_size=batch_size) if _data_prefix: samples = storage._storage["_data"][index] @@ -1640,6 +1642,14 @@ def test_slice_sampler( trajs_unique_id = trajs_unique_id.union( samples["another_episode"].view(-1).tolist() ) + count_unique = count_unique.union(samples.get("count").view(-1).tolist()) + if len(count_unique) == 100: + # all items have been sampled + break + else: + raise AssertionError( + f"Not all items can be sampled: {set(range(100))-count_unique} are missing" + ) if strict_length: assert not too_short else: diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py index 83bc6a632c6..cfa7900b819 100644 --- a/torchrl/data/replay_buffers/samplers.py +++ b/torchrl/data/replay_buffers/samplers.py @@ -760,7 +760,7 @@ def _sample_slices( relative_starts = ( ( torch.rand(num_slices, device=lengths.device) - * (lengths[traj_idx] - seq_length) + * (lengths[traj_idx] - seq_length + 1) ) .floor() .to(start_idx.dtype)