diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index f380b915..a3f20135 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -289,10 +289,7 @@ def create_dataset( ) self._save_to_disk(dataset_path, metadata) - - dataset = MinariDataset(dataset_path) - dataset.storage.update_metadata({"dataset_size": dataset.storage.get_size()}) - return dataset + return MinariDataset(dataset_path) def _flush_to_storage(self): if self._buffer is not None and len(self._buffer) > 0: diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index 0cc0530c..b3582c11 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -65,6 +65,8 @@ def update_episode_metadata( with open(metadata_path, "w") as file: json.dump(metadata, file, cls=NumpyEncoder) + self.update_metadata({"dataset_size": self.get_size()}) + def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]: for episode_id in episode_indices: metadata_path = self.data_path.joinpath(str(episode_id), "metadata.json") @@ -148,7 +150,11 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): self.update_episode_metadata([episode_metadata], [episode_id]) self.update_metadata( - {"total_steps": total_steps, "total_episodes": total_episodes} + { + "total_steps": total_steps, + "total_episodes": total_episodes, + "dataset_size": self.get_size(), + } ) diff --git a/minari/dataset/_storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py index 03559c98..bca359b7 100644 --- a/minari/dataset/_storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -68,6 +68,8 @@ def update_episode_metadata( ep_group = file[f"episode_{episode_id}"] ep_group.attrs.update(metadata) + self.update_metadata({"dataset_size": self.get_size()}) + def get_episode_metadata(self, episode_indices: Iterable[int]) -> Iterable[Dict]: with h5py.File(self._file_path, "r") as file: for ep_idx in episode_indices: @@ -174,7 +176,11 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): total_steps = self.total_steps + additional_steps self.update_metadata( - {"total_steps": total_steps, "total_episodes": total_episodes} + { + "total_steps": total_steps, + "total_episodes": total_episodes, + "dataset_size": self.get_size(), + } ) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index e6743116..deba6929 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -339,6 +339,7 @@ def update_from_storage(self, storage: MinariStorage): { "author": author1.union(author2), "author_email": email1.union(email2), + "dataset_size": self.get_size(), } ) diff --git a/minari/utils.py b/minari/utils.py index 3e24b8da..0183804a 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -429,8 +429,6 @@ def create_dataset_from_buffers( storage.update_metadata(metadata) storage.update_episodes(buffer) - storage.update_metadata({"dataset_size": storage.get_size()}) - return MinariDataset(storage)