[BugFix] Fix strict length in PRB+SliceSampler #2202

Merged · 17 commits · Jun 7, 2024
163 changes: 158 additions & 5 deletions test/test_rb.py
@@ -2000,6 +2000,7 @@ def test_slice_sampler(
)
index = torch.arange(0, num_steps, 1)
sampler.extend(index)
sampler.update_priority(index, 1)
else:
sampler = SliceSampler(
num_slices=num_slices,
@@ -2013,16 +2014,20 @@
trajs_unique_id = set()
too_short = False
count_unique = set()
for _ in range(30):
for _ in range(50):
index, info = sampler.sample(storage, batch_size=batch_size)
samples = storage._storage[index]
if strict_length:
# check that trajs are ok
samples = samples.view(num_slices, -1)

assert samples["another_episode"].unique(
dim=1
).squeeze().shape == torch.Size([num_slices])
unique_another_episode = (
samples["another_episode"].unique(dim=1).squeeze()
)
assert unique_another_episode.shape == torch.Size([num_slices]), (
num_slices,
samples,
)
assert (
samples["steps"][..., 1:] - 1 == samples["steps"][..., :-1]
).all()
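
The `sampler.update_priority(index, 1)` line added above is the pattern the rest of the PR relies on: after `extend` registers new indices with the sampler, their priorities are written explicitly rather than being filled in implicitly. A minimal standalone sketch of that pattern, assuming torchrl's `PrioritizedSliceSampler` and `LazyTensorStorage` with the keys used in this test (import paths, data layout and keyword choices are illustrative assumptions, not taken from the diff):

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, PrioritizedSliceSampler

# Two short trajectories stored back to back; keys mirror the test above.
data = TensorDict(
    {
        "traj": torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1]),
        "steps": torch.tensor([0, 1, 2, 3, 0, 1, 2, 3, 4]),
        "done": torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 1], dtype=torch.bool).unsqueeze(-1),
    },
    batch_size=[9],
)
storage = LazyTensorStorage(9)
storage.set(torch.arange(9), data)  # write the data at explicit indices

sampler = PrioritizedSliceSampler(
    max_capacity=9, alpha=1.0, beta=1.0, end_key="done", slice_len=3
)
index = torch.arange(9)
sampler.extend(index)              # register the new indices with the sampler
sampler.update_priority(index, 1)  # then write their priorities explicitly
sample_index, info = sampler.sample(storage, batch_size=6)
```
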
@@ -2255,7 +2260,7 @@ def test_slice_sampler_left_right_ndim(self):
curr_eps = curr_eps[curr_eps != 0]
assert curr_eps.unique().numel() == 1

def test_slicesampler_strictlength(self):
def test_slice_sampler_strictlength(self):

torch.manual_seed(0)

@@ -2299,6 +2304,154 @@ def test_slicesampler_strictlength(self):
else:
assert len(sample["traj"].unique()) == 1

@pytest.mark.parametrize("ndim", [1, 2])
@pytest.mark.parametrize("strict_length", [True, False])
@pytest.mark.parametrize("circ", [False, True])
def test_slice_sampler_prioritized(self, ndim, strict_length, circ):
torch.manual_seed(0)
out = []
for t in range(5):
length = (t + 1) * 5
done = torch.zeros(length, 1, dtype=torch.bool)
done[-1] = 1
priority = 10 if t == 0 else 1
traj = TensorDict(
{
"traj": torch.full((length,), t),
"step_count": torch.arange(length),
"done": done,
"priority": torch.full((length,), priority),
},
batch_size=length,
)
out.append(traj)
data = torch.cat(out)
if ndim == 2:
data = torch.stack([data, data])
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(data.numel(), ndim=ndim),
sampler=PrioritizedSliceSampler(
max_capacity=data.numel(),
alpha=1.0,
beta=1.0,
end_key="done",
slice_len=10,
strict_length=strict_length,
cache_values=True,
),
batch_size=50,
)
if not circ:
# Simplest case: the buffer is full but no overlap
index = rb.extend(data)
else:
# Wrap-around case: pre-fill a third of the buffer, then extend with the full data so the last writes overlap the start
rb.extend(data[..., : data.shape[-1] // 3])
index = rb.extend(data)
rb.update_priority(index, data["priority"])
samples = []
found_shorter_batch = False
for _ in range(100):
samples.append(rb.sample())
if samples[-1].numel() < 50:
found_shorter_batch = True
samples = torch.cat(samples)
if strict_length:
assert not found_shorter_batch
else:
assert found_shorter_batch
# the first trajectory has a very high priority, but should only appear
# if strict_length=False.
if strict_length:
assert (samples["traj"] != 0).all(), samples["traj"].unique()
else:
assert (samples["traj"] == 0).any()
# Check that all samples of the first traj contain all elements (since it's too short to provide 10 elements)
sc = samples[samples["traj"] == 0]["step_count"]
assert (sc == 0).sum() == (sc == 1).sum()
assert (sc == 0).sum() == (sc == 4).sum()
assert rb._sampler._cache
rb.extend(data)
assert not rb._sampler._cache
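
The test above pins down the contract this PR fixes: with `strict_length=True`, trajectories shorter than `slice_len` are never sampled and every batch holds exactly `batch_size` elements, while with `strict_length=False` a too-short trajectory may be returned whole and the batch can come back smaller. A hedged sketch of that behaviour at the buffer level (data layout and keyword arguments mirror the test; treat names and import paths as assumptions):

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, PrioritizedSliceSampler, TensorDictReplayBuffer

def make_rb(strict_length):
    # Trajectory 0 has 5 steps (< slice_len), trajectory 1 has 20 steps.
    out = []
    for t, length in enumerate((5, 20)):
        done = torch.zeros(length, 1, dtype=torch.bool)
        done[-1] = 1
        out.append(
            TensorDict(
                {
                    "traj": torch.full((length,), t),
                    "step_count": torch.arange(length),
                    "done": done,
                },
                batch_size=length,
            )
        )
    data = torch.cat(out)
    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(data.numel()),
        sampler=PrioritizedSliceSampler(
            max_capacity=data.numel(),
            alpha=1.0,
            beta=1.0,
            end_key="done",
            slice_len=10,
            strict_length=strict_length,
        ),
        batch_size=20,
    )
    index = rb.extend(data)
    rb.update_priority(index, torch.ones(data.numel()))  # explicit, as in the test
    return rb

sample = make_rb(strict_length=True).sample()
assert (sample["traj"] != 0).all()  # the short trajectory never shows up
assert sample.numel() == 20         # and the batch is always full

sample = make_rb(strict_length=False).sample()
# here trajectory 0 may be returned in full, and the batch may hold fewer than 20 elements
```
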

@pytest.mark.parametrize("ndim", [1, 2])
@pytest.mark.parametrize("strict_length", [True, False])
@pytest.mark.parametrize("circ", [False, True])
@pytest.mark.parametrize(
"span", [False, [False, False], [False, True], 3, [False, 3]]
)
def test_slice_sampler_prioritized_span(self, ndim, strict_length, circ, span):
torch.manual_seed(0)
out = []
# 5 trajs of length 3, 6, 9, 12 and 15
for t in range(5):
length = (t + 1) * 3
done = torch.zeros(length, 1, dtype=torch.bool)
done[-1] = 1
priority = 1
traj = TensorDict(
{
"traj": torch.full((length,), t),
"step_count": torch.arange(length),
"done": done,
"priority": torch.full((length,), priority),
},
batch_size=length,
)
out.append(traj)
data = torch.cat(out)
if ndim == 2:
data = torch.stack([data, data])
rb = TensorDictReplayBuffer(
storage=LazyTensorStorage(data.numel(), ndim=ndim),
sampler=PrioritizedSliceSampler(
max_capacity=data.numel(),
alpha=1.0,
beta=1.0,
end_key="done",
slice_len=5,
strict_length=strict_length,
cache_values=True,
span=span,
),
batch_size=5,
)
if not circ:
# Simplest case: the buffer is full but no overlap
index = rb.extend(data)
else:
# Wrap-around case: pre-fill a third of the buffer, then extend with the full data so the last writes overlap the start
rb.extend(data[..., : data.shape[-1] // 3])
index = rb.extend(data)
rb.update_priority(index, data["priority"])
found_traj_0 = False
found_traj_4_truncated_left = False
found_traj_4_truncated_right = False
for i, s in enumerate(rb):
t = s["traj"].unique().tolist()
assert len(t) == 1
t = t[0]
if t == 0:
found_traj_0 = True
if t == 4 and s.numel() < 5:
if s["step_count"][0] > 10:
found_traj_4_truncated_right = True
if s["step_count"][0] == 0:
found_traj_4_truncated_left = True
if i == 1000:
break
assert not rb._sampler.span[0]
# if rb._sampler.span[0]:
# assert found_traj_4_truncated_left
if rb._sampler.span[1]:
assert found_traj_4_truncated_right
else:
assert not found_traj_4_truncated_right
if strict_length and not rb._sampler.span[1]:
assert not found_traj_0
else:
assert found_traj_0
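
On my reading of the assertions above, the `span` argument relaxes the `strict_length` constraint at trajectory boundaries: a truthy right entry, e.g. `span=[False, True]`, lets a sampled slice stop short of `slice_len` when it reaches the end of a trajectory, which is also why the short trajectory 0 can then show up even with `strict_length=True`. A sketch of that reading with the same data layout as the test (illustrative assumptions, not part of the diff):

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, PrioritizedSliceSampler, TensorDictReplayBuffer

# Five trajectories of lengths 3, 6, 9, 12 and 15, as in the test.
out = []
for t in range(5):
    length = (t + 1) * 3
    done = torch.zeros(length, 1, dtype=torch.bool)
    done[-1] = 1
    out.append(
        TensorDict(
            {"traj": torch.full((length,), t), "step_count": torch.arange(length), "done": done},
            batch_size=length,
        )
    )
data = torch.cat(out)

rb = TensorDictReplayBuffer(
    storage=LazyTensorStorage(data.numel()),
    sampler=PrioritizedSliceSampler(
        max_capacity=data.numel(),
        alpha=1.0,
        beta=1.0,
        end_key="done",
        slice_len=5,
        strict_length=True,
        cache_values=True,
        span=[False, True],  # no left span, allow right-truncated slices
    ),
    batch_size=5,
)
index = rb.extend(data)
rb.update_priority(index, torch.ones(data.numel()))

for i, s in enumerate(rb):
    if s.numel() < 5:
        # a slice that hit the right edge of its trajectory (or a too-short trajectory)
        print("truncated slice:", s["traj"].unique().item(), s["step_count"].tolist())
        break
    if i == 100:
        break
```
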


def test_prioritized_slice_sampler_doc_example():
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)
16 changes: 11 additions & 5 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -564,17 +564,18 @@ def update_priority(
index: Union[int, torch.Tensor],
priority: Union[int, torch.Tensor],
) -> None:
if self.dim_extend > 0 and priority.ndim > 1:
priority = self._transpose(priority).flatten()
# priority = priority.flatten()
with self._replay_lock:
self._sampler.update_priority(index, priority)
self._sampler.update_priority(index, priority, storage=self.storage)

@pin_memory_output
def _sample(self, batch_size: int) -> Tuple[Any, dict]:
with self._replay_lock:
index, info = self._sampler.sample(self._storage, batch_size)
info["index"] = index
data = self._storage.get(index)
# if self.dim_extend > 0:
# data = self._transpose(data)
if not isinstance(index, INT_CLASSES):
data = self._collate_fn(data)
if self._transform is not None and len(self._transform):
@@ -643,7 +644,7 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
return ret[0]

def mark_update(self, index: Union[int, torch.Tensor]) -> None:
self._sampler.mark_update(index)
self._sampler.mark_update(index, storage=self._storage)

def append_transform(
self, transform: "Transform", *, invert: bool = False # noqa-F821
@@ -1105,8 +1106,13 @@ def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
return torch.zeros((0, self._storage.ndim), dtype=torch.long)

index = super()._extend(tensordicts)

# TODO: to be usable directly, the indices should be flipped, but the issue
# is that just doing this results in indices that are not sorted like the original data,
# so the actual indices will have to be used on the _storage directly (not on the buffer)
self._set_index_in_td(tensordicts, index)
self.update_tensordict_priority(tensordicts)
# TODO: in principle this is a good idea, but currently it doesn't work and it re-writes a priority that has just been written
# self.update_tensordict_priority(tensordicts)
return index

def _set_index_in_td(self, tensordict, index):
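
With `update_tensordict_priority` commented out above, `extend` no longer writes priorities on its own; the calling pattern the updated tests use is to pass the returned indices to `update_priority` explicitly. A two-line sketch, with `rb` and `data` built as in the tests above (illustrative, not a new API):

```python
index = rb.extend(data)                      # extend() returns the storage indices it wrote to
rb.update_priority(index, data["priority"])  # priorities are then set for exactly those indices
```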