Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into openx
Browse files Browse the repository at this point in the history
# Conflicts:
#	test/test_rb.py
  • Loading branch information
vmoens committed Jan 8, 2024
2 parents e39638b + 781a5b2 commit fc2b1ab
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
12 changes: 11 additions & 1 deletion test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,6 +1595,7 @@ def test_slice_sampler(
"obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5),
"act": torch.randn((20,)).expand(100, 20),
"steps": steps,
"count": torch.arange(100),
"other": torch.randn((20, 50)).expand(100, 20, 50),
done_key: done,
},
Expand All @@ -1621,7 +1622,8 @@ def test_slice_sampler(
num_slices = batch_size // slice_len
trajs_unique_id = set()
too_short = False
for _ in range(20):
count_unique = set()
for _ in range(10):
index, info = sampler.sample(storage, batch_size=batch_size)
if _data_prefix:
samples = storage._storage["_data"][index]
Expand All @@ -1640,6 +1642,14 @@ def test_slice_sampler(
trajs_unique_id = trajs_unique_id.union(
samples["another_episode"].view(-1).tolist()
)
count_unique = count_unique.union(samples.get("count").view(-1).tolist())
if len(count_unique) == 100:
# all items have been sampled
break
else:
raise AssertionError(
f"Not all items can be sampled: {set(range(100))-count_unique} are missing"
)
if strict_length:
assert not too_short
else:
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def _sample_slices(
relative_starts = (
(
torch.rand(num_slices, device=lengths.device)
* (lengths[traj_idx] - seq_length)
* (lengths[traj_idx] - seq_length + 1)
)
.floor()
.to(start_idx.dtype)
Expand Down

0 comments on commit fc2b1ab

Please sign in to comment.