From 68dec1952ee688cc35b28e22d0ec104bfa05e0d3 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Mon, 2 Sep 2024 18:50:26 +0200 Subject: [PATCH] Moved subbuffer-related functionality from Collector to Buffer Added several tests for the new functionality Also better names, documentation and input-validation in ReplayBuffer.add --- docs/spelling_wordlist.txt | 2 + pyproject.toml | 1 + test/base/test_buffer.py | 43 +++++++++ tianshou/data/buffer/base.py | 170 +++++++++++++++++++++++++++++++---- tianshou/data/collector.py | 123 +++---------------------- 5 files changed, 211 insertions(+), 128 deletions(-) diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b49094d03..40aa69970 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -290,3 +290,5 @@ subclass subclassing dist dists +subbuffer +subbuffers diff --git a/pyproject.toml b/pyproject.toml index 3d336d9cf..5640f948b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -181,6 +181,7 @@ ignore = [ "PLW2901", # overwrite vars in loop "B027", # empty and non-abstract method in abstract class "D404", # It's fine to start with "This" in docstrings + "D407", "D408", "D409", # Ruff rules for underlines under 'Example:' and so clash with Sphinx ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index fe4301bea..48a1da90c 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -1503,3 +1503,46 @@ def test_buffer_dropnull() -> None: buf.dropnull() assert len(buf[:3]) == 3 assert not buf.hasnull() + + +@pytest.fixture +def dummy_rollout_batch() -> RolloutBatchProtocol: + return cast( + RolloutBatchProtocol, + Batch( + obs=np.arange(2), + obs_next=np.arange(2), + act=np.arange(5), + rew=1, + terminated=False, + truncated=False, + done=False, + info={}, + ), + ) + + +def test_get_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) -> None: + buffer = ReplayBuffer(5) + for _ in range(5): + buffer.add(dummy_rollout_batch) + assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2]) + assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1]) + + +def test_get_vector_replay_buffer_indices(dummy_rollout_batch: RolloutBatchProtocol) -> None: + stacked_batch = Batch.stack([dummy_rollout_batch, dummy_rollout_batch]) + buffer = VectorReplayBuffer(10, 2) + for _ in range(5): + buffer.add(stacked_batch) + + assert np.array_equal(buffer.get_buffer_indices(0, 3), [0, 1, 2]) + assert np.array_equal(buffer.get_buffer_indices(3, 2), [3, 4, 0, 1]) + + assert np.array_equal(buffer.get_buffer_indices(6, 9), [6, 7, 8]) + assert np.array_equal(buffer.get_buffer_indices(8, 7), [8, 9, 5, 6]) + + with pytest.raises(ValueError): + buffer.get_buffer_indices(3, 6) + with pytest.raises(ValueError): + buffer.get_buffer_indices(6, 3) diff --git a/tianshou/data/buffer/base.py b/tianshou/data/buffer/base.py index 1ccf956e0..f7b60054d 100644 --- a/tianshou/data/buffer/base.py +++ b/tianshou/data/buffer/base.py @@ -1,4 +1,5 @@ -from typing import Any, Self, TypeVar, cast +from collections.abc import Sequence +from typing import Any, ClassVar, Self, TypeVar, cast import h5py import numpy as np @@ -60,6 +61,14 @@ class ReplayBuffer: "info", "policy", ) + _required_keys_for_add: ClassVar[set[str]] = { + "obs", + "act", + "rew", + "terminated", + "truncated", + "done", + } def __init__( self, @@ -103,6 +112,111 @@ def subbuffer_edges(self) -> np.ndarray: """ return np.array([0, self.maxsize], dtype=int) + def _get_start_stop_tuples_for_edge_crossing_interval( + self, + start: int, + stop: int, + ) -> tuple[tuple[int, int], tuple[int, int]]: + """Assumes that stop < start and retrieves tuples corresponding to the two + slices that determine the interval within the buffer. + + Example: + ------- + >>> list(self.subbuffer_edges) == [0, 5, 10] + >>> start = 4 + >>> stop = 2 + >>> self._get_start_stop_tuples_for_edge_crossing_interval(start, stop) + ((4, 5), (0, 2)) + + The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions + corresponding to the provided start and stop values. + """ + if stop >= start: + raise ValueError( + f"Expected stop < start, but got {start=}, {stop=}. " + f"For stop larger than start this method should never be called, " + f"and stop=start should never occur. This can occur either due to an implementation error, " + f"or due a bad configuration of the buffer that resulted in a single episode being so long that " + f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). " + f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the " + f"degree of vectorization.", + ) + subbuffer_edges = cast(Sequence[int], self.subbuffer_edges) + + edge_after_start_idx = int(np.searchsorted(subbuffer_edges, start, side="left")) + """This is the crossed edge""" + + if edge_after_start_idx == 0: + raise ValueError( + f"The start value should be larger than the first edge, but got {start=}, {subbuffer_edges[1]=}.", + ) + edge_after_start = subbuffer_edges[edge_after_start_idx] + edge_before_stop = subbuffer_edges[edge_after_start_idx - 1] + """It's the edge before the crossed edge""" + + if edge_before_stop >= stop: + raise ValueError( + f"The edge before the crossed edge should be smaller than the stop, but got {edge_before_stop=}, {stop=}.", + ) + return (start, edge_after_start), (edge_before_stop, stop) + + def get_buffer_indices(self, start: int, stop: int) -> np.ndarray: + """Get the indices of the transitions in the buffer between start and stop. + + The special thing about this is that stop may actually be smaller than start, + since one often is interested in a sequence of transitions that goes over a subbuffer edge. + + The main use case for this method is to retrieve an episode from the buffer, in which case + start is the index of the first transition in the episode and stop is the index where `done` is True + 1. + This can be done with the following code: + + .. code-block:: python + + episode_indices = buffer.get_buffer_indices(episode_start_index, episode_done_index + 1) + episode = buffer[episode_indices] + + Even when `start` is smaller than `stop`, it will be validated that they are in the same subbuffer. + + Example: + -------- + >>> list(buffer.subbuffer_edges) == [0, 5, 10] + >>> buffer.get_buffer_indices(start=2, stop=4) + [2, 3] + >>> buffer.get_buffer_indices(start=4, stop=2) + [4, 0, 1] + >>> buffer.get_buffer_indices(start=8, stop=7) + [8, 9, 5, 6] + >>> buffer.get_buffer_indices(start=1, stop=6) + ValueError: Start and stop indices must be within the same subbuffer. + >>> buffer.get_buffer_indices(start=8, stop=1) + ValueError: Start and stop indices must be within the same subbuffer. + + :param start: The start index of the interval. + :param stop: The stop index of the interval. + :return: The indices of the transitions in the buffer between start and stop. + """ + start_left_edge = np.searchsorted(self.subbuffer_edges, start, side="right") - 1 + stop_left_edge = np.searchsorted(self.subbuffer_edges, stop, side="right") - 1 + if start_left_edge != stop_left_edge: + raise ValueError( + f"Start and stop indices must be within the same subbuffer. " + f"Got {start=} in subbuffer edge {start_left_edge} and {stop=} in subbuffer edge {stop_left_edge}.", + ) + if stop > start: + return np.arange(start, stop, dtype=int) + else: + (start, upper_edge), ( + lower_edge, + stop, + ) = self._get_start_stop_tuples_for_edge_crossing_interval( + start, + stop, + ) + log.debug(f"{start=}, {upper_edge=}, {lower_edge=}, {stop=}") + return np.concatenate( + (np.arange(start, upper_edge, dtype=int), np.arange(lower_edge, stop, dtype=int)), + ) + def __len__(self) -> int: return self._size @@ -297,43 +411,69 @@ def add( :param batch: the input data batch. "obs", "act", "rew", "terminated", "truncated" are required keys. - :param buffer_ids: to make consistent with other buffer's add function; if it - is not None, we assume the input batch's first dimension is always 1. + :param buffer_ids: id's of subbuffers, allowed here to be consistent with classes similar to + :class:`~tianshou.data.buffer.vecbuf.VectorReplayBuffer`. Since the `ReplayBuffer` + has a single subbuffer, if this is not None, it must be a single element with value 0. + In that case, the batch is expected to have the shape (1, len(data)). + Failure to adhere to this will result in a `ValueError`. - Return (current_index, episode_return, episode_length, episode_start_index). If + Return `(current_index, episode_return, episode_length, episode_start_index)`. If the episode is not finished, the return value of episode_length and episode_reward is 0. """ - # preprocess batch + # preprocess and copy batch into a new Batch object to avoid mutating the input + # TODO: can't we just copy? Why do we need to rely on setting inside __dict__? new_batch = Batch() for key in batch.get_keys(): new_batch.__dict__[key] = batch[key] batch = new_batch batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) - assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset( + + # has to be done after preprocess batch + if not self._required_keys_for_add.issubset( batch.get_keys(), - ) # important to do after preprocess batch - stacked_batch = buffer_ids is not None - if stacked_batch: - assert len(batch) == 1 + ): + raise ValueError( + f"Input batch must have the following keys: {self._required_keys_for_add}", + ) + + batch_is_stacked = False + """True when instead of passing a batch of shape (len(data)), a batch of shape (1, len(data)) is passed.""" + + if buffer_ids is not None: + if len(buffer_ids) != 1 and buffer_ids[0] != 0: + raise ValueError( + "If `buffer_ids` is not None, it must be a single element with value 0 for the non-vectorized `ReplayBuffer`. " + f"Got {buffer_ids=}.", + ) + if len(batch) != 1: + raise ValueError( + f"If `buffer_ids` is not None, the batch must have the shape (1, len(data)) but got {len(batch)=}.", + ) + batch_is_stacked = True + + # block dealing with exotic options that are currently only used for atari, see various TODOs about that + # These options have interactions with the case when buffer_ids is not None if self._save_only_last_obs: - batch.obs = batch.obs[:, -1] if stacked_batch else batch.obs[-1] + batch.obs = batch.obs[:, -1] if batch_is_stacked else batch.obs[-1] if not self._save_obs_next: batch.pop("obs_next", None) elif self._save_only_last_obs: - batch.obs_next = batch.obs_next[:, -1] if stacked_batch else batch.obs_next[-1] - # get ptr - if stacked_batch: + batch.obs_next = batch.obs_next[:, -1] if batch_is_stacked else batch.obs_next[-1] + + if batch_is_stacked: rew, done = batch.rew[0], batch.done[0] else: rew, done = batch.rew, batch.done insertion_idx, ep_return, ep_len, ep_start_idx = ( np.array([x]) for x in self._update_state_pre_add(rew, done) ) + + # TODO: improve this, don'r rely on try-except, instead process the batch if needed try: self._meta[insertion_idx] = batch except ValueError: - stack = not stacked_batch + stack = not batch_is_stacked batch.rew = batch.rew.astype(float) batch.done = batch.done.astype(bool) batch.terminated = batch.terminated.astype(bool) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 125fe0e90..5ab7f67fb 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -2,7 +2,6 @@ import time import warnings from abc import ABC, abstractmethod -from collections.abc import Sequence from copy import copy from dataclasses import dataclass, field from typing import Any, Generic, Optional, Protocol, Self, TypedDict, TypeVar, cast @@ -339,58 +338,6 @@ def __init__( self._validate_buffer() self.collect_stats_class = collect_stats_class - @property - def _subbuffer_edges(self) -> np.ndarray: - return self.buffer.subbuffer_edges - - def _get_start_stop_tuples_for_edge_crossing_interval( - self, - start: int, - stop: int, - ) -> tuple[tuple[int, int], tuple[int, int]]: - """Assumes that stop < start and retrieves tuples corresponding to the two - slices that determine the interval within the buffer. - - Example: - ------- - >>> list(self._subbuffer_edges) == [0, 5, 10] - >>> start = 4 - >>> stop = 2 - >>> self._get_start_stop_tuples_for_edge_crossing_interval(start, stop) - ((4, 5), (0, 2)) - - The buffer sliced from 4 to 5 and then from 0 to 2 will contain the transitions - corresponding to the provided start and stop values. - """ - if stop >= start: - raise ValueError( - f"Expected stop < start, but got {start=}, {stop=}. " - f"For stop larger than start this method should never be called, " - f"and stop=start should never occur. This can occur either due to an implementation error, " - f"or due a bad configuration of the buffer that resulted in a single episode being so long that " - f"it completely filled a subbuffer (of size len(buffer)/degree_of_vectorization). " - f"Consider either shortening the episode, increasing the size of the buffer, or decreasing the " - f"degree of vectorization.", - ) - subbuffer_edges = cast(Sequence[int], self._subbuffer_edges) - - edge_after_start_idx = int(np.searchsorted(subbuffer_edges, start, side="left")) - """This is the crossed edge""" - - if edge_after_start_idx == 0: - raise ValueError( - f"The start value should be larger than the first edge, but got {start=}, {subbuffer_edges[1]=}.", - ) - edge_after_start = subbuffer_edges[edge_after_start_idx] - edge_before_stop = subbuffer_edges[edge_after_start_idx - 1] - """It's the edge before the crossed edge""" - - if edge_before_stop >= stop: - raise ValueError( - f"The edge before the crossed edge should be smaller than the stop, but got {edge_before_stop=}, {stop=}.", - ) - return (start, edge_after_start), (edge_before_stop, stop) - def _validate_buffer(self) -> None: buf = self.buffer # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. @@ -999,21 +946,17 @@ def _collect( # noqa: C901 episode_returns_D, strict=True, ): - cur_ep_index_slice = slice( - ep_start_idx_R[local_done_idx], - insertion_idx_R[local_done_idx] + 1, + # retrieve the episode batch from the buffer using the episode start and stop indices + ep_start_idx, ep_stop_idx = ( + int(ep_start_idx_R[local_done_idx]), + int(insertion_idx_R[local_done_idx] + 1), ) - ( - cur_ep_index_array, - cur_ep_batch, - ) = self._get_buffer_index_and_entries_for_episode_from_slice( - cur_ep_index_slice, - ) - cur_ep_batch = cast(EpisodeBatchProtocol, cur_ep_batch) + ep_index_array = self.buffer.get_buffer_indices(ep_start_idx, ep_stop_idx) + ep_batch = cast(EpisodeBatchProtocol, self.buffer[ep_index_array]) # Step 10 - episode_hook_additions = self.run_on_episode_done(cur_ep_batch) + episode_hook_additions = self.run_on_episode_done(ep_batch) if episode_hook_additions is not None: if n_episode is None: raise ValueError( @@ -1026,18 +969,18 @@ def _collect( # noqa: C901 self.buffer.set_array_at_key( episode_addition, key, - index=cur_ep_index_array, + index=ep_index_array, ) # executing the same logic in the episode-batch since stats computation # may depend on the presence of additional fields - cur_ep_batch.set_array_at_key( + ep_batch.set_array_at_key( episode_addition, key, ) # Step 11 # Finally, update the stats collect_stats.update_at_episode_done( - episode_batch=cur_ep_batch, + episode_batch=ep_batch, episode_return=cur_ep_return, ) @@ -1115,52 +1058,6 @@ def _collect( # noqa: C901 self.reset_env(gym_reset_kwargs) # todo still necessary? return collect_stats - # TODO: move to buffer - def _get_buffer_index_and_entries_for_episode_from_slice( - self, - entries_slice: slice, - ) -> tuple[np.ndarray, RolloutBatchProtocol]: - """ - :param entries_slice: a slice object that selects the entries from the buffer. - `stop` can be smaller than `start`, meaning that a sub-buffer edge is to be crossed - :return: The indices of the entries in the buffer and the corresponding batch of entries. - """ - start, stop = entries_slice.start, entries_slice.stop - - # if isinstance(self.buffer, CachedReplayBuffer): - # # Accounting for the very special behavior of the CachedReplayBuffer, where once an episode is - # # finished, it is moved to the main buffer and the corresponding subbuffer is reset. - # # This means, that retrieving a slice corresponding to a finished episode should always happen - # # from the main buffer, whereas slices for unfinished episodes should always be retrieved from - # # the corresponding subbuffer - # # TODO: fix this behavior in CachedReplayBuffer, remove the special sauce here - # start = start % self.buffer.main_buffer.maxsize - # stop = stop % self.buffer.main_buffer.maxsize - - if stop > start: - cur_ep_index_array = np.arange( - entries_slice.start, - entries_slice.stop, - dtype=int, - ) - else: - # stop < start means that to retrieve the slice we have to cross an edge of the buffer - # We have to split the slice into two parts and concatenate the results - log.debug(f"Received an edge-crossing slice with {stop=} < {start=}") - (start, upper_edge), ( - lower_edge, - stop, - ) = self._get_start_stop_tuples_for_edge_crossing_interval( - start, - stop, - ) - cur_ep_index_array = np.concatenate( - (np.arange(start, upper_edge, dtype=int), np.arange(lower_edge, stop, dtype=int)), - ) - log.debug(f"{start=}, {upper_edge=}, {lower_edge=}, {stop=}") - ep_rollout_batch = self.buffer[cur_ep_index_array] - return cur_ep_index_array, ep_rollout_batch - @staticmethod def _reset_hidden_state_based_on_type( env_ind_local_D: np.ndarray,