[BugFix] Fix done/terminated computation in slice samplers (#2213)
vmoens authored Jun 7, 2024
1 parent d934153 commit 726e959
Showing 2 changed files with 43 additions and 21 deletions.
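In short: both slice samplers previously inferred `terminated` from slice geometry alone, flagging a step as terminated whenever it landed on a stored trajectory's stop index, and derived `done` from that. A trajectory that simply ran out of data was therefore reported as terminated rather than truncated. The fix reads the flags recorded in storage instead. A condensed before/after sketch (variable names follow the diff below; `st_index` stands for the batch of storage entries at the sampled indices):

    # Before: terminated inferred from whether a sampled step coincides
    # with a trajectory's stop index; truncation and termination conflated.
    terminated = (
        (index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1))
        .any(0)
        .unsqueeze(1)
    )
    done = terminated | truncated

    # After: read the stored flags, with fallbacks when the keys are absent.
    st_index = storage[index]
    try:
        done = st_index[done_key] | truncated
    except KeyError:
        done = truncated.clone()
    try:
        terminated = st_index[terminated_key]
    except KeyError:
        terminated = torch.zeros_like(truncated)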
test/test_rb.py (21 changes: 14 additions & 7 deletions)
@@ -32,6 +32,8 @@
 )
 from torch import multiprocessing as mp
 from torch.utils._pytree import tree_flatten, tree_map
+
+from torchrl._utils import _replace_last
 from torchrl.collectors import RandomPolicy, SyncDataCollector
 from torchrl.collectors.utils import split_trajectories
 from torchrl.data import (
@@ -1974,7 +1976,7 @@ def test_slice_sampler(
                 "count": torch.arange(100),
                 "other": torch.randn((20, 50)).expand(100, 20, 50),
                 done_key: done,
-                "terminated": done,
+                _replace_last(done_key, "terminated"): done,
             },
             [100],
             device=device,
@@ -1997,6 +1999,7 @@
             end_key=done_key,
             slice_len=slice_len,
             strict_length=strict_length,
+            truncated_key=_replace_last(done_key, "truncated"),
         )
         index = torch.arange(0, num_steps, 1)
         sampler.extend(index)
@@ -2007,6 +2010,7 @@
             end_key=done_key,
             slice_len=slice_len,
             strict_length=strict_length,
+            truncated_key=_replace_last(done_key, "truncated"),
         )
         if slice_len is not None:
             num_slices = batch_size // slice_len
@@ -2037,11 +2041,14 @@
             )
             count_unique = count_unique.union(samples.get("count").view(-1).tolist())

-            truncated = info[("next", "truncated")]
-            terminated = info[("next", "terminated")]
+            truncated = info[_replace_last(done_key, "truncated")]
+            terminated = info[_replace_last(done_key, "terminated")]
             assert (truncated | terminated).view(num_slices, -1)[:, -1].all()
-            assert (terminated == samples["terminated"].view_as(terminated)).all()
-            done = info[("next", "done")]
+            assert (
+                terminated
+                == samples[_replace_last(done_key, "terminated")].view_as(terminated)
+            ).all()
+            done = info[done_key]
             assert done.view(num_slices, -1)[:, -1].all()

             if len(count_unique) == 100:
@@ -2197,10 +2204,10 @@ def test_slice_sampler_without_replacement(
             trajs_unique_id = trajs_unique_id.union(
                 cur_episodes,
             )
-            done = info[("next", "done")]
-            assert done.view(num_slices, -1)[:, -1].all()
             done_recon = info[("next", "truncated")] | info[("next", "terminated")]
             assert done_recon.view(num_slices, -1)[:, -1].all()
+            done = info[("next", "done")]
+            assert done.view(num_slices, -1)[:, -1].all()

     def test_slice_sampler_left_right(self):
         torch.manual_seed(0)
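The tests now derive the companion "truncated"/"terminated" keys from `done_key` rather than hard-coding them, so nested done keys such as ("next", "done") are exercised too. Judging from its use here, `_replace_last` swaps the final element of a (possibly nested) key; a small illustration under that assumption:

    from torchrl._utils import _replace_last

    _replace_last("done", "truncated")            # "truncated"
    _replace_last(("next", "done"), "truncated")  # ("next", "truncated")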
torchrl/data/replay_buffers/samplers.py (43 changes: 29 additions & 14 deletions)
@@ -1054,6 +1054,7 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
             seq_length,
             num_slices,
             storage_length=storage_length,
+            storage=storage,
         )

     def _sample_slices(
@@ -1065,6 +1066,8 @@ def _sample_slices(
         num_slices: int,
         storage_length: int,
         traj_idx: torch.Tensor | None = None,
+        *,
+        storage,
     ) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
         # start_idx and stop_idx are 2d tensors organized like a non-zero

@@ -1130,6 +1133,7 @@ def get_traj_idx(maxval):
             seq_length=seq_length,
             storage_length=storage_length,
             traj_idx=traj_idx,
+            storage=storage,
         )

     def _get_index(
@@ -1141,6 +1145,8 @@ def _get_index(
         num_slices: int,
         storage_length: int,
         traj_idx: torch.Tensor | None = None,
+        *,
+        storage,
     ) -> Tuple[torch.Tensor, dict]:
         # end_point is the last possible index for start
         last_indexable_start = lengths[traj_idx] - seq_length + 1
@@ -1209,13 +1215,17 @@
                 truncated.view(num_slices, -1)[:, -1] = 1
             else:
                 truncated[seq_length.cumsum(0) - 1] = 1
-            terminated = (
-                (index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1))
-                .any(0)
-                .unsqueeze(1)
-            )
-            done = terminated | truncated
-            return index.to(torch.long).unbind(-1), {
+            index = index.to(torch.long).unbind(-1)
+            st_index = storage[index]
+            try:
+                done = st_index[done_key] | truncated
+            except KeyError:
+                done = truncated.clone()
+            try:
+                terminated = st_index[terminated_key]
+            except KeyError:
+                terminated = torch.zeros_like(truncated)
+            return index, {
                 truncated_key: truncated,
                 done_key: done,
                 terminated_key: terminated,
@@ -1454,6 +1464,7 @@ def tuple_to_tensor(traj_idx, lengths=lengths):
             num_slices,
             storage_length,
             traj_idx=tuple_to_tensor(indices),
+            storage=storage,
         )
         return idx, info

@@ -1763,20 +1774,24 @@ def sample(self, storage: Storage, batch_size: int) -> Tuple[torch.Tensor, dict]
                 truncated.view(num_slices, -1)[:, -1] = 1
             else:
                 truncated[seq_length.cumsum(0) - 1] = 1
-            terminated = (
-                (index[:, 0].unsqueeze(0) == stop_idx[:, 0].unsqueeze(1))
-                .any(0)
-                .unsqueeze(1)
-            )
-            done = terminated | truncated
+            index = index.to(torch.long).unbind(-1)
+            st_index = storage[index]
+            try:
+                done = st_index[done_key] | truncated
+            except KeyError:
+                done = truncated.clone()
+            try:
+                terminated = st_index[terminated_key]
+            except KeyError:
+                terminated = torch.zeros_like(truncated)
             info.update(
                 {
                     truncated_key: truncated,
                     done_key: done,
                     terminated_key: terminated,
                 }
             )
-            return index.to(torch.long).unbind(-1), info
+            return index, info
         return index.to(torch.long).unbind(-1), info

     def _empty(self):
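The practical effect: a slice cut short by the sampler is reported as truncated and done, but only counts as terminated if the storage itself recorded a termination. A toy check of that invariant (standalone sketch with illustrative tensors, not torchrl code):

    import torch

    # A 3-step slice whose last step is cut by the sampler.
    truncated = torch.tensor([[False], [False], [True]])
    # Storage recorded no termination for these steps: the episode goes on.
    stored_terminated = torch.tensor([[False], [False], [False]])
    stored_done = stored_terminated.clone()

    done = stored_done | truncated  # done is set at the cut point...
    terminated = stored_terminated  # ...but terminated stays False
    assert bool(done[-1]) and not bool(terminated[-1])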
