From c0669fc3a8829dec4a7a1fbee198a6be4f668ea1 Mon Sep 17 00:00:00 2001 From: John Balis Date: Thu, 10 Aug 2023 15:58:43 -0400 Subject: [PATCH] Improves MinariDataset load speed to speed up sampling, fixes total_steps bug and adds test coverage (#129) * attempted to fix list_remote_datasets slowdown * tester.py draft * patch to speed up sampling from a minari dataset, MinariStorage total_steps tests coverage and bugfix * some polish, removed tester file * removed print statements * reverted hosting.py, change pending in seperate PR --- minari/dataset/minari_dataset.py | 15 ++++++++------- minari/dataset/minari_storage.py | 15 ++++++++------- minari/utils.py | 6 +++--- tests/common.py | 4 +++- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 2cd443c2..8705a71b 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -141,18 +141,19 @@ def __init__( self._additional_data_id = 0 if episode_indices is None: episode_indices = np.arange(self._data.total_episodes) + total_steps = self._data.total_steps + else: + total_steps = sum( + self._data.apply( + lambda episode: episode["total_timesteps"], + episode_indices=episode_indices, + ) + ) self._episode_indices = episode_indices assert self._episode_indices is not None - total_steps = sum( - self._data.apply( - lambda episode: episode["total_timesteps"], - episode_indices=self._episode_indices, - ) - ) - self.spec = MinariDatasetSpec( env_spec=self._data.env_spec, total_episodes=self._episode_indices.size, diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index c8fc2bb3..3d3cae39 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -205,13 +205,13 @@ def update_from_collector_env( "id", last_episode_id + id ) + self._total_steps = file.attrs["total_steps"] + new_data_total_steps + # Update metadata of minari dataset file.attrs.modify( "total_episodes", last_episode_id + new_data_total_episodes ) - file.attrs.modify( - "total_steps", file.attrs["total_steps"] + new_data_total_steps - ) + file.attrs.modify("total_steps", self._total_steps) self._total_episodes = int(file.attrs["total_episodes"].item()) def update_from_buffer(self, buffer: List[dict], data_path: str): @@ -247,10 +247,11 @@ def update_from_buffer(self, buffer: List[dict], data_path: str): # TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata - file.attrs.modify("total_episodes", last_episode_id + len(buffer)) - file.attrs.modify( - "total_steps", file.attrs["total_steps"] + additional_steps - ) + self._total_steps = file.attrs["total_steps"] + additional_steps + self._total_episodes = last_episode_id + len(buffer) + + file.attrs.modify("total_episodes", self._total_episodes) + file.attrs.modify("total_steps", self._total_steps) self._total_episodes = int(file.attrs["total_episodes"].item()) diff --git a/minari/utils.py b/minari/utils.py index 025335a8..0945c705 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -467,9 +467,9 @@ def create_dataset_from_buffers( ) eps_group.attrs["id"] = i - total_steps = len(eps_buff["actions"]) - eps_group.attrs["total_steps"] = total_steps - total_steps += total_steps + episode_total_steps = len(eps_buff["actions"]) + eps_group.attrs["total_steps"] = episode_total_steps + total_steps += episode_total_steps if seed is None: eps_group.attrs["seed"] = str(None) diff --git a/tests/common.py b/tests/common.py index e4cdc1dc..bdf73727 100644 --- a/tests/common.py +++ b/tests/common.py @@ -458,11 +458,12 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): episode_indices (Iterable[int]): the list of episode indices expected """ episodes = data.get_episodes(episode_indices) - print([episode["id"] for episode in episodes]) # verify we have the right number of episodes, available at the right indices assert data.total_episodes == len(episodes) + total_steps = 0 # verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct for episode in episodes: + total_steps += episode["total_timesteps"] _check_space_elem( episode["observations"], data.observation_space, @@ -484,6 +485,7 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert episode["total_timesteps"] == len(episode["rewards"]) assert episode["total_timesteps"] == len(episode["terminations"]) assert episode["total_timesteps"] == len(episode["truncations"]) + assert total_steps == data.total_steps def _reconstuct_obs_or_action_at_index_recursive(