diff --git a/test/test_rb.py b/test/test_rb.py
index 5a849224117..35fc37b400c 100644
--- a/test/test_rb.py
+++ b/test/test_rb.py
@@ -2514,14 +2514,15 @@ def _make_data(self, datatype, datadim):
 
     @pytest.mark.parametrize("datatype,rbtype", datatype_rb_pairs)
     @pytest.mark.parametrize("datadim", [1, 2])
-    def test_rb_multidim(self, datatype, datadim, rbtype):
+    @pytest.mark.parametrize("storage_cls", [LazyMemmapStorage, LazyTensorStorage])
+    def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
         data = self._make_data(datatype, datadim)
         if rbtype not in (PrioritizedReplayBuffer, TensorDictPrioritizedReplayBuffer):
             rbtype = functools.partial(rbtype, sampler=RandomSampler())
         else:
             rbtype = functools.partial(rbtype, alpha=0.9, beta=1.1)
-        rb = rbtype(storage=LazyMemmapStorage(100, ndim=datadim), batch_size=4)
+        rb = rbtype(storage=storage_cls(100, ndim=datadim), batch_size=4)
         rb.extend(data)
         assert len(rb) == 12
         s = rb.sample()
 
diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index d507e02b24b..af43771ac82 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -253,35 +253,30 @@ def __init__(
         self._batch_size = batch_size
         if dim_extend is not None and dim_extend < 0:
             raise ValueError("dim_extend must be a positive value.")
-        self._dim_extend = dim_extend
-        if self.dim_extend > 0:
-            from torchrl.envs.transforms.transforms import _TransposeTransform
-
-            if self._storage.ndim <= self.dim_extend:
-                raise ValueError(
-                    "The storage `ndim` attribute must be greater "
-                    "than the `dim_extend` attribute of the buffer."
-                )
-            self.append_transform(_TransposeTransform(self.dim_extend))
+        self.dim_extend = dim_extend
 
     @property
     def dim_extend(self):
-        dim_extend = self._dim_extend
-        if dim_extend is None:
-            if self._storage is not None:
-                ndim = self._storage.ndim
-                dim_extend = ndim - 1
-            else:
-                dim_extend = 1
-            self.dim_extend = dim_extend
-        return dim_extend
+        return self._dim_extend
 
     @dim_extend.setter
    def dim_extend(self, value):
-        if self._dim_extend is not None and self._dim_extend != value:
+        if (
+            hasattr(self, "_dim_extend")
+            and self._dim_extend is not None
+            and self._dim_extend != value
+        ):
             raise RuntimeError(
                 "dim_extend cannot be reset. Please create a new replay buffer."
             )
+
+        if value is None:
+            if self._storage is not None:
+                ndim = self._storage.ndim
+                value = ndim - 1
+            else:
+                value = 1
+        self._dim_extend = value
         if value is not None and value > 0:
             from torchrl.envs.transforms.transforms import _TransposeTransform
 
diff --git a/torchrl/data/replay_buffers/samplers.py b/torchrl/data/replay_buffers/samplers.py
index c40921e3c99..690a96e5bf7 100644
--- a/torchrl/data/replay_buffers/samplers.py
+++ b/torchrl/data/replay_buffers/samplers.py
@@ -910,7 +910,8 @@ def _sample_slices(
         if (lengths < seq_length).any():
             if self.strict_length:
                 raise RuntimeError(
-                    "Some stored trajectories have a length shorter than the slice that was asked for. "
+                    "Some stored trajectories have a length shorter than the slice that was asked for ("
+                    f"min length={lengths.min()}). "
                     "Create the sampler with `strict_length=False` to allow shorter trajectories to appear "
                     "in you batch."
                 )
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index b1dda223f17..186810315c3 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -147,6 +147,14 @@ def shape(self):
             f"Please report this exception as well as the use case (incl. buffer construction) on github."
         )
 
+    def _max_size_along_dim0(self, *, single_data, batched_data):
+        if self.ndim == 1:
+            return self.max_size
+        raise RuntimeError(
+            f"storage._max_size_along_dim0 is not supported for storages of type {type(self)} when ndim > 1."
+            f"Please report this exception as well as the use case (incl. buffer construction) on github."
+        )
+
     def flatten(self):
         if self.ndim == 1:
             return self
@@ -475,7 +483,7 @@ def _len(self, value):
     @property
     def _total_shape(self):
         # Total shape, irrespective of how full the storage is
         _total_shape = self.__dict__.get("_total_shape_value", None)
-        if _total_shape is None:
+        if _total_shape is None and self.initialized:
             if is_tensor_collection(self._storage):
                 _total_shape = self._storage.shape[: self.ndim]
             else:
@@ -497,12 +505,29 @@ def _len_along_dim0(self):
             len_along_dim = len_along_dim // self._total_shape[1:].numel()
         return len_along_dim
 
-    @property
-    def _max_size_along_dim0(self):
+    def _max_size_along_dim0(self, *, single_data=None, batched_data=None):
         # returns the max_size of the buffer along dim0
         max_size = self.max_size
         if self.ndim:
-            max_size = max_size // self._total_shape[1:].numel()
+            shape = self.shape
+            if shape is None:
+                if single_data is not None:
+                    data = single_data
+                elif batched_data is not None:
+                    data = batched_data
+                else:
+                    raise ValueError("single_data or batched_data must be passed.")
+                if is_tensor_collection(data):
+                    datashape = data.shape[: self.ndim]
+                else:
+                    for leaf in torch.utils._pytree.tree_leaves(data):
+                        datashape = leaf.shape[: self.ndim]
+                        break
+                if batched_data is not None:
+                    datashape = datashape[1:]
+                max_size = max_size // datashape.numel()
+            else:
+                max_size = max_size // self._total_shape[1:].numel()
         return max_size
 
     @property
@@ -511,7 +536,8 @@ def shape(self):
         if self._is_full:
             return self._total_shape
         _total_shape = self._total_shape
-        return torch.Size([self._len_along_dim0] + list(_total_shape[1:]))
+        if _total_shape is not None:
+            return torch.Size([self._len_along_dim0] + list(_total_shape[1:]))
 
     def _rand_given_ndim(self, batch_size):
         if self.ndim == 1:
@@ -815,7 +841,10 @@ def _init(
 
         def max_size_along_dim0(data_shape):
             if self.ndim > 1:
-                return (-(self.max_size // data_shape[: self.ndim - 1]), *data_shape)
+                return (
+                    -(self.max_size // -data_shape[: self.ndim - 1].numel()),
+                    *data_shape,
+                )
             return (self.max_size, *data_shape)
 
         if is_tensor_collection(data):
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 9c7ba6f4fb7..bdaa5d573fb 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -134,7 +134,9 @@ def add(self, data: Any) -> int | torch.Tensor:
         index = self._cursor
         _cursor = self._cursor
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (self._cursor + 1) % self._storage.max_size
+        self._cursor = (self._cursor + 1) % self._storage._max_size_along_dim0(
+            single_data=data
+        )
         self._storage[_cursor] = data
         return self._replicate_index(index)
 
@@ -149,14 +151,15 @@ def extend(self, data: Sequence) -> torch.Tensor:
         if batch_size == 0:
             raise RuntimeError("Expected at least one element in extend.")
         device = data.device if hasattr(data, "device") else None
+        max_size_along0 = self._storage._max_size_along_dim0(batched_data=data)
         index = (
             torch.arange(
                 cur_size, batch_size + cur_size, dtype=torch.long, device=device
             )
-            % self._storage.max_size
+            % max_size_along0
         )
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (batch_size + cur_size) % self._storage.max_size
+        self._cursor = (batch_size + cur_size) % max_size_along0
         self._storage[index] = data
         return self._replicate_index(index)
 
@@ -205,7 +208,7 @@ class TensorDictRoundRobinWriter(RoundRobinWriter):
     def add(self, data: Any) -> int | torch.Tensor:
         index = self._cursor
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (index + 1) % self._storage.max_size
+        self._cursor = (index + 1) % self._storage.shape[0]
         if not is_tensorclass(data):
             data.set(
                 "index",
@@ -224,10 +227,10 @@ def extend(self, data: Sequence) -> torch.Tensor:
             torch.arange(
                 cur_size, batch_size + cur_size, dtype=torch.long, device=device
             )
-            % self._storage.max_size
+            % self._storage.shape[0]
         )
         # we need to update the cursor first to avoid race conditions between workers
-        self._cursor = (batch_size + cur_size) % self._storage.max_size
+        self._cursor = (batch_size + cur_size) % self._storage.shape[0]
         # storage must convert the data to the appropriate format if needed
         if not is_tensorclass(data):
             data.set(
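
Below is a minimal, hypothetical sketch of the usage pattern the parametrized test above exercises: a multidimensional storage (ndim=2) wrapped in a plain ReplayBuffer and extended with a 2D batch of data. The [4, 3] batch shape and the field names are assumptions for illustration only; the real test builds its payload via _make_data.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, RandomSampler, ReplayBuffer

# Hypothetical 2D payload: 4 trajectories x 3 time steps (shape assumed for illustration).
data = TensorDict(
    {"obs": torch.randn(4, 3, 8), "reward": torch.randn(4, 3, 1)},
    batch_size=[4, 3],
)

# ndim=2 marks the two leading dims of the data as batch dims of the storage,
# so max_size=100 is a budget on the total number of stored elements.
rb = ReplayBuffer(
    storage=LazyTensorStorage(100, ndim=2),
    sampler=RandomSampler(),
    batch_size=4,
)
rb.extend(data)
assert len(rb) == 12  # 4 * 3 elements, mirroring the test's assertion above
sample = rb.sample()  # a batch of 4 elements drawn across both batch dims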