Skip to content

Commit

Permalink
Improves MinariDataset load speed to speed up sampling, fixes total_s…
Browse files Browse the repository at this point in the history
…teps bug and adds test coverage (#129)

* attempted to fix list_remote_datasets slowdown

* tester.py draft

* patch to speed up sampling from a minari dataset, MinariStorage total_steps tests coverage and bugfix

* some polish, removed tester file

* removed print statements

* reverted hosting.py, change pending in seperate PR
  • Loading branch information
balisujohn authored Aug 10, 2023
1 parent e1d5658 commit c0669fc
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 18 deletions.
15 changes: 8 additions & 7 deletions minari/dataset/minari_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,19 @@ def __init__(
self._additional_data_id = 0
if episode_indices is None:
episode_indices = np.arange(self._data.total_episodes)
total_steps = self._data.total_steps
else:
total_steps = sum(
self._data.apply(
lambda episode: episode["total_timesteps"],
episode_indices=episode_indices,
)
)

self._episode_indices = episode_indices

assert self._episode_indices is not None

total_steps = sum(
self._data.apply(
lambda episode: episode["total_timesteps"],
episode_indices=self._episode_indices,
)
)

self.spec = MinariDatasetSpec(
env_spec=self._data.env_spec,
total_episodes=self._episode_indices.size,
Expand Down
15 changes: 8 additions & 7 deletions minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,13 @@ def update_from_collector_env(
"id", last_episode_id + id
)

self._total_steps = file.attrs["total_steps"] + new_data_total_steps

# Update metadata of minari dataset
file.attrs.modify(
"total_episodes", last_episode_id + new_data_total_episodes
)
file.attrs.modify(
"total_steps", file.attrs["total_steps"] + new_data_total_steps
)
file.attrs.modify("total_steps", self._total_steps)
self._total_episodes = int(file.attrs["total_episodes"].item())

def update_from_buffer(self, buffer: List[dict], data_path: str):
Expand Down Expand Up @@ -247,10 +247,11 @@ def update_from_buffer(self, buffer: List[dict], data_path: str):

# TODO: save EpisodeMetadataCallback callback in MinariDataset and update new episode group metadata

file.attrs.modify("total_episodes", last_episode_id + len(buffer))
file.attrs.modify(
"total_steps", file.attrs["total_steps"] + additional_steps
)
self._total_steps = file.attrs["total_steps"] + additional_steps
self._total_episodes = last_episode_id + len(buffer)

file.attrs.modify("total_episodes", self._total_episodes)
file.attrs.modify("total_steps", self._total_steps)

self._total_episodes = int(file.attrs["total_episodes"].item())

Expand Down
6 changes: 3 additions & 3 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,9 +467,9 @@ def create_dataset_from_buffers(
)

eps_group.attrs["id"] = i
total_steps = len(eps_buff["actions"])
eps_group.attrs["total_steps"] = total_steps
total_steps += total_steps
episode_total_steps = len(eps_buff["actions"])
eps_group.attrs["total_steps"] = episode_total_steps
total_steps += episode_total_steps

if seed is None:
eps_group.attrs["seed"] = str(None)
Expand Down
4 changes: 3 additions & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,11 +458,12 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
episode_indices (Iterable[int]): the list of episode indices expected
"""
episodes = data.get_episodes(episode_indices)
print([episode["id"] for episode in episodes])
# verify we have the right number of episodes, available at the right indices
assert data.total_episodes == len(episodes)
total_steps = 0
# verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct
for episode in episodes:
total_steps += episode["total_timesteps"]
_check_space_elem(
episode["observations"],
data.observation_space,
Expand All @@ -484,6 +485,7 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]):
assert episode["total_timesteps"] == len(episode["rewards"])
assert episode["total_timesteps"] == len(episode["terminations"])
assert episode["total_timesteps"] == len(episode["truncations"])
assert total_steps == data.total_steps


def _reconstuct_obs_or_action_at_index_recursive(
Expand Down

0 comments on commit c0669fc

Please sign in to comment.