amend
vmoens committed Feb 16, 2024
1 parent 94a9b61 commit 358a15f
Showing 5 changed files with 64 additions and 35 deletions.
5 changes: 3 additions & 2 deletions test/test_rb.py
@@ -2514,14 +2514,15 @@ def _make_data(self, datatype, datadim):
 
     @pytest.mark.parametrize("datatype,rbtype", datatype_rb_pairs)
     @pytest.mark.parametrize("datadim", [1, 2])
-    def test_rb_multidim(self, datatype, datadim, rbtype):
+    @pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
+    def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
         data = self._make_data(datatype, datadim)
         if rbtype not in (PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer):
             rbtype = functools.partial(rbtype, sampler=RandomSampler())
         else:
             rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1)
 
-        rb = rbtype(storage=LazyMemmapStorage(100, ndim=datadim), batch_size=4)
+        rb = rbtype(storage=storage_cls(100, ndim=datadim), batch_size=4)
         rb.extend(data)
         assert len(rb) == 12
         s = rb.sample()
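
Note: as a standalone illustration of what this parametrization exercises, here is a hedged usage sketch (not part of the commit). Both LazyMemmapStorage and LazyTensorStorage accept an `ndim` argument, and a buffer built on an `ndim=2` storage counts every element across the leading dims. The `[3, 4]` batch shape is an assumed stand-in for whatever `_make_data` produces.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer

# Hypothetical 2-dim batch of data standing in for _make_data's output.
data = TensorDict({"obs": torch.randn(3, 4, 5)}, batch_size=[3, 4])

rb = ReplayBuffer(
    storage=LazyTensorStorage(100, ndim=2),  # first two dims index stored elements
    sampler=RandomSampler(),
    batch_size=4,
)
rb.extend(data)
assert len(rb) == 12  # 3 * 4 elements, mirroring the test's expectation
sample = rb.sample()
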
35 changes: 15 additions & 20 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -253,35 +253,30 @@ def __init__(
         self._batch_size = batch_size
         if dim_extend is not None and dim_extend < 0:
             raise ValueError("dim_extend must be a positive value.")
-        self._dim_extend = dim_extend
-        if self.dim_extend > 0:
-            from torchrl.envs.transforms.transforms import _TransposeTransform
-
-            if self._storage.ndim <= self.dim_extend:
-                raise ValueError(
-                    "The storage `ndim` attribute must be greater "
-                    "than the `dim_extend` attribute of the buffer."
-                )
-            self.append_transform(_TransposeTransform(self.dim_extend))
+        self.dim_extend = dim_extend
 
     @property
     def dim_extend(self):
-        dim_extend = self._dim_extend
-        if dim_extend is None:
-            if self._storage is not None:
-                ndim = self._storage.ndim
-                dim_extend = ndim - 1
-            else:
-                dim_extend = 1
-            self.dim_extend = dim_extend
-        return dim_extend
+        return self._dim_extend
 
     @dim_extend.setter
     def dim_extend(self, value):
-        if self._dim_extend is not None and self._dim_extend != value:
+        if (
+            hasattr(self, "_dim_extend")
+            and self._dim_extend is not None
+            and self._dim_extend != value
+        ):
             raise RuntimeError(
                 "dim_extend cannot be reset. Please create a new replay buffer."
             )
+
+        if value is None:
+            if self._storage is not None:
+                ndim = self._storage.ndim
+                value = ndim - 1
+            else:
+                value = 1
+
         self._dim_extend = value
         if value is not None and value > 0:
             from torchrl.envs.transforms.transforms import _TransposeTransform
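
Note: a minimal sketch of the setter behaviour introduced here (the constructor call mirrors the test file above and is an assumption, not part of this diff). When `dim_extend` is left as `None`, the setter resolves it to `storage.ndim - 1` at construction time, and a later conflicting assignment raises.

from torchrl.data import LazyTensorStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyTensorStorage(100, ndim=3))
print(rb.dim_extend)  # 2, i.e. storage.ndim - 1, resolved by the setter

try:
    rb.dim_extend = 0  # already resolved to 2, so resetting raises
except RuntimeError as err:
    print(err)  # dim_extend cannot be reset. Please create a new replay buffer.
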
3 changes: 2 additions & 1 deletion torchrl/data/replay_buffers/samplers.py
@@ -910,7 +910,8 @@ def _sample_slices(
         if (lengths < seq_length).any():
             if self.strict_length:
                 raise RuntimeError(
-                    "Some stored trajectories have a length shorter than the slice that was asked for. "
+                    "Some stored trajectories have a length shorter than the slice that was asked for ("
+                    f"min length={lengths.min()}). "
                     "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
                     "in you batch."
                 )
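
Note: the improved message points at `strict_length`; a hedged construction sketch follows (the sampler class is assumed to be `SliceSampler` from this module, which is not visible in the hunk itself).

from torchrl.data.replay_buffers.samplers import SliceSampler

# With strict_length=False, trajectories shorter than slice_len may appear
# in the sampled batch instead of triggering the RuntimeError above.
sampler = SliceSampler(slice_len=8, strict_length=False)
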
41 changes: 35 additions & 6 deletions torchrl/data/replay_buffers/storages.py
@@ -147,6 +147,14 @@ def shape(self):
             f"Please report this exception as well as the use case (incl. buffer construction) on github."
         )
 
+    def _max_size_along_dim0(self, *, single_data, batched_data):
+        if self.ndim == 1:
+            return self.max_size
+        raise RuntimeError(
+            f"storage._max_size_along_dim0 is not supported for storages of type {type(self)} when ndim > 1."
+            f"Please report this exception as well as the use case (incl. buffer construction) on github."
+        )
+
     def flatten(self):
         if self.ndim == 1:
             return self
@@ -475,7 +483,7 @@ def _len(self, value):
     def _total_shape(self):
         # Total shape, irrespective of how full the storage is
         _total_shape = self.__dict__.get("_total_shape_value", None)
-        if _total_shape is None:
+        if _total_shape is None and self.initialized:
             if is_tensor_collection(self._storage):
                 _total_shape = self._storage.shape[: self.ndim]
             else:
@@ -497,12 +505,29 @@ def _len_along_dim0(self):
             len_along_dim = len_along_dim // self._total_shape[1:].numel()
         return len_along_dim
 
-    @property
-    def _max_size_along_dim0(self):
+    def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
         # returns the max_size of the buffer along dim0
         max_size = self.max_size
         if self.ndim:
-            max_size = max_size // self._total_shape[1:].numel()
+            shape = self.shape
+            if shape is None:
+                if single_data is not None:
+                    data = single_data
+                elif batched_data is not None:
+                    data = batched_data
+                else:
+                    raise ValueError("single_data or batched_data must be passed.")
+                if is_tensor_collection(data):
+                    datashape = data.shape[: self.ndim]
+                else:
+                    for leaf in torch.utils._pytree.tree_leaves(data):
+                        datashape = leaf.shape[: self.ndim]
+                        break
+                if batched_data is not None:
+                    datashape = datashape[1:]
+                max_size = max_size // datashape.numel()
+            else:
+                max_size = max_size // self._total_shape[1:].numel()
         return max_size
 
     @property
@@ -511,7 +536,8 @@ def shape(self):
         if self._is_full:
             return self._total_shape
         _total_shape = self._total_shape
-        return torch.Size([self._len_along_dim0] + list(_total_shape[1:]))
+        if _total_shape is not None:
+            return torch.Size([self._len_along_dim0] + list(_total_shape[1:]))
 
     def _rand_given_ndim(self, batch_size):
         if self.ndim == 1:
@@ -815,7 +841,10 @@ def _init(
 
         def max_size_along_dim0(data_shape):
             if self.ndim > 1:
-                return (-(self.max_size // data_shape[: self.ndim - 1]), *data_shape)
+                return (
+                    -(self.max_size // -data_shape[: self.ndim - 1].numel()),
+                    *data_shape,
+                )
             return (self.max_size, *data_shape)
 
         if is_tensor_collection(data):
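
Note: the `_init` fix above swaps plain floor division for the `-(a // -b)` idiom, which is ceiling division on positive integers, and adds the missing `.numel()`. A small self-contained illustration with made-up numbers:

max_size = 100
leading_numel = 3 * 4  # stands in for data_shape[: ndim - 1].numel()

rows_floor = max_size // leading_numel     # 8: would under-allocate dim 0
rows_ceil = -(max_size // -leading_numel)  # 9: ceil(100 / 12), as in the fix
assert (rows_floor, rows_ceil) == (8, 9)
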
15 changes: 9 additions & 6 deletions torchrl/data/replay_buffers/writers.py
@@ -134,7 +134,9 @@ def add(self, data: Any) -> int | torch.Tensor:
         index = self._cursor
         _cursor = self._cursor
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (self._cursor + 1) % self._storage.max_size
+        self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
+            single_data=data
+        )
         self._storage[_cursor] = data
         return self._replicate_index(index)
 
@@ -149,14 +151,15 @@ def extend(self, data: Sequence) -> torch.Tensor:
         if batch_size == 0:
             raise RuntimeError("Expected at least one element in extend.")
         device = data.device if hasattr(data, "device") else None
+        max_size_along0 = self._storage._max_size_along_dim0(batched_data=data)
         index = (
             torch.arange(
                 cur_size, batch_size + cur_size, dtype=torch.long, device=device
             )
-            % self._storage.max_size
+            % max_size_along0
         )
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (batch_size + cur_size) % self._storage.max_size
+        self._cursor = (batch_size + cur_size) % max_size_along0
         self._storage[index] = data
         return self._replicate_index(index)
 
@@ -205,7 +208,7 @@ class TensorDictRoundRobinWriter(RoundRobinWriter):
     def add(self, data: Any) -> int | torch.Tensor:
         index = self._cursor
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (index + 1) % self._storage.max_size
+        self._cursor = (index + 1) % self._storage.shape[0]
         if not is_tensorclass(data):
             data.set(
                 "index",
@@ -224,10 +227,10 @@ def extend(self, data: Sequence) -> torch.Tensor:
             torch.arange(
                 cur_size, batch_size + cur_size, dtype=torch.long, device=device
             )
-            % self._storage.max_size
+            % self._storage.shape[0]
         )
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (batch_size + cur_size) % self._storage.max_size
+        self._cursor = (batch_size + cur_size) % self._storage.shape[0]
         # storage must convert the data to the appropriate format if needed
         if not is_tensorclass(data):
             data.set(
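
Note: both writers now wrap their cursor with the per-dim0 capacity (`_max_size_along_dim0` or `storage.shape[0]`) instead of the raw `max_size`. A minimal sketch of that index arithmetic with made-up numbers:

import torch

max_size_along0 = 10  # e.g. max_size=100 with 10 elements stored per row
cur_size, batch_size = 8, 4

# Same modulo pattern as RoundRobinWriter.extend: indices wrap past capacity.
index = torch.arange(cur_size, cur_size + batch_size) % max_size_along0
cursor = (cur_size + batch_size) % max_size_along0

assert index.tolist() == [8, 9, 0, 1]
assert cursor == 2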
