Skip to content

Commit

Permalink
auto update dataset size (Farama-Foundation#251)
Browse files Browse the repository at this point in the history
  • Loading branch information
younik authored Oct 14, 2024
1 parent 4e6d9f7 commit 2d77f62
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 8 deletions.
5 changes: 1 addition & 4 deletions minari/data_collector/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,7 @@ def create_dataset(
)

self._save_to_disk(dataset_path, metadata)

dataset = MinariDataset(dataset_path)
dataset.storage.update_metadata({"dataset_size": dataset.storage.get_size()})
return dataset
return MinariDataset(dataset_path)

def _flush_to_storage(self):
if self._buffer is not None and len(self._buffer) > 0:
Expand Down
8 changes: 7 additions & 1 deletion minari/dataset/_storages/arrow_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def update_episode_metadata(
with open(metadata_path, "w") as file:
json.dump(metadata, file, cls=NumpyEncoder)

self.update_metadata({"dataset_size": self.get_size()})

def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
for episode_id in episode_indices:
metadata_path = self.data_path.joinpath(str(episode_id), "metadata.json")
Expand Down Expand Up @@ -148,7 +150,11 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
self.update_episode_metadata([episode_metadata], [episode_id])

self.update_metadata(
{"total_steps": total_steps, "total_episodes": total_episodes}
{
"total_steps": total_steps,
"total_episodes": total_episodes,
"dataset_size": self.get_size(),
}
)


Expand Down
8 changes: 7 additions & 1 deletion minari/dataset/_storages/hdf5_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def update_episode_metadata(
ep_group = file[f"episode_{episode_id}"]
ep_group.attrs.update(metadata)

self.update_metadata({"dataset_size": self.get_size()})

def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]:
with h5py.File(self._file_path, "r") as file:
for ep_idx in episode_indices:
Expand Down Expand Up @@ -174,7 +176,11 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]):

total_steps = self.total_steps + additional_steps
self.update_metadata(
{"total_steps": total_steps, "total_episodes": total_episodes}
{
"total_steps": total_steps,
"total_episodes": total_episodes,
"dataset_size": self.get_size(),
}
)


Expand Down
1 change: 1 addition & 0 deletions minari/dataset/minari_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def update_from_storage(self, storage: MinariStorage):
{
"author": author1.union(author2),
"author_email": email1.union(email2),
"dataset_size": self.get_size(),
}
)

Expand Down
2 changes: 0 additions & 2 deletions minari/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,6 @@ def create_dataset_from_buffers(

storage.update_metadata(metadata)
storage.update_episodes(buffer)
storage.update_metadata({"dataset_size": storage.get_size()})

return MinariDataset(storage)


Expand Down

0 comments on commit 2d77f62

Please sign in to comment.