From edc00e2e92041687adfe94726127d823b2a770be Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Mon, 21 Aug 2023 11:27:29 -0400 Subject: [PATCH 01/29] initial commit for info's fix --- minari/data_collector/data_collector.py | 3 ++ minari/dataset/episode_data.py | 4 +- tests/common.py | 11 ++-- tests/utils/test_dataset_creation.py | 70 ++++++++++++++++++++++++- 4 files changed, 83 insertions(+), 5 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 80f893f4..a4712533 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -154,6 +154,9 @@ def _add_to_episode_buffer( if (not self._record_infos and key == "infos") or (value is None): continue + if key == "infos": + print("infos time") + if key not in episode_buffer: if isinstance(value, dict): episode_buffer[key] = self._add_to_episode_buffer({}, value) diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py index 53786144..797eed87 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -19,6 +19,7 @@ class EpisodeData: rewards: np.ndarray terminations: np.ndarray truncations: np.ndarray + infos: dict def __repr__(self) -> str: return ( @@ -30,7 +31,8 @@ def __repr__(self) -> str: f"actions={EpisodeData._repr_space_values(self.actions)}, " f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " - f"truncations=ndarray of {len(self.truncations)} bools" + f"truncations=ndarray of {len(self.truncations)} bools, " + f"infos=dict with keys of :{list(self.infos.keys())}" ")" ) diff --git a/tests/common.py b/tests/common.py index 65f013de..9e8c81fe 100644 --- a/tests/common.py +++ b/tests/common.py @@ -32,12 +32,12 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return self.observation_space.sample(), 0, terminated, False, {"timestep" : self.timestep} def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {} + return self.observation_space.sample(), {"timestep" : self.timestep} class DummyMultiDimensionalBoxEnv(gym.Env): @@ -604,7 +604,7 @@ def create_dummy_dataset_with_collecter_env_helper( def check_episode_data_integrity( - episode_data_list: List[EpisodeData], + episode_data_list: Union[List[EpisodeData],MinariDataset], observation_space: gym.spaces.Space, action_space: gym.spaces.Space, ): @@ -627,7 +627,12 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) + print(episode) + print(episode.observations) + print(obs) + print(observation_space) assert observation_space.contains(obs) + for i in range(episode.total_timesteps): action = _reconstuct_obs_or_action_at_index_recursive(episode.actions, i) assert action_space.contains(action) diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index b1f878c0..31b6a76b 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -14,6 +14,7 @@ check_load_and_delete_dataset, get_sample_buffer_for_dataset_from_env, register_dummy_envs, + check_episode_data_integrity ) @@ -82,7 +83,8 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset.storage, dataset.episode_indices) + check_data_integrity(dataset._data, dataset.episode_indices) + check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset, eval_env) @@ -92,6 +94,70 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): # check load and delete local dataset check_load_and_delete_dataset(dataset_id) +@pytest.mark.parametrize( + "dataset_id,env_id", + [ + ("dummy-box-test-v0", "DummyBoxEnv-v0"), + ], +) +def test_generate_dataset_with_collector_env_infos(dataset_id, env_id): + """Test DataCollectorV0 wrapper and Minari dataset creation.""" + # dataset_id = "cartpole-test-v0" + # delete the test dataset if it already exists + local_datasets = minari.list_local_datasets() + if dataset_id in local_datasets: + minari.delete_dataset(dataset_id) + + env = gym.make(env_id) + + env = DataCollectorV0(env, record_infos = True) + num_episodes = 10 + + # Step the environment, DataCollectorV0 wrapper will do the data collection job + env.reset(seed=42) + + for episode in range(num_episodes): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() # User-defined policy function + _, _, terminated, truncated, info = env.step(action) + print(info) + if terminated or truncated: + assert not env._buffer[-1] + else: + assert env._buffer[-1] + + env.reset() + + # Create Minari dataset and store locally + dataset = minari.create_dataset_from_collector_env( + dataset_id=dataset_id, + collector_env=env, + algorithm_name="random_policy", + code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + author="WillDudley", + author_email="wdudley@farama.org", + ) + + assert isinstance(dataset, MinariDataset) + assert dataset.total_episodes == num_episodes + assert dataset.spec.total_episodes == num_episodes + assert len(dataset.episode_indices) == num_episodes + + check_data_integrity(dataset._data, dataset.episode_indices) + check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) + + + # check that the environment can be recovered from the dataset + check_env_recovery(env.env, dataset) + print("episodedata") + print(dataset[0]) + + env.close() + # check load and delete local dataset + check_load_and_delete_dataset(dataset_id) + @pytest.mark.parametrize( "dataset_id,env_id", @@ -186,6 +252,7 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset.storage, dataset.episode_indices) + check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) check_env_recovery(env, dataset, eval_env) check_load_and_delete_dataset(dataset_id) @@ -254,6 +321,7 @@ def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset.storage, dataset.episode_indices) + check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) if is_env_needed: check_env_recovery_with_subset_spaces( env, dataset, action_space_subset, observation_space_subset From 2d892c27de26145a5b100c4038429b163435e3c1 Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Fri, 25 Aug 2023 00:03:59 -0500 Subject: [PATCH 02/29] tentative draft of info support for EpisodeData --- minari/dataset/minari_storage.py | 16 ++++ tests/common.py | 96 ++++++++++++++++++--- tests/data_collector/test_data_collector.py | 18 ++++ tests/dataset/test_minari_dataset.py | 2 + tests/utils/test_dataset_creation.py | 28 +++--- 5 files changed, 139 insertions(+), 21 deletions(-) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 860c2035..ec32bb61 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -161,6 +161,19 @@ def apply( ep_dicts = self.get_episodes(episode_indices) return map(function, ep_dicts) + def _decode_infos(self, infos: h5py.Group): + result = {} + for key in infos.keys(): + if isinstance(infos[key], h5py.Group): + result[key] = self._decode_infos(infos[key]) + elif isinstance(infos[key], h5py.Dataset): + result[key] = infos[key][()] + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return result + def _decode_space( self, hdf_ref: Union[h5py.Group, h5py.Dataset, h5py.Datatype], @@ -219,6 +232,9 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: "actions": self._decode_space( ep_group["actions"], self.action_space ), + "infos": self._decode_infos(ep_group["infos"]) + if "infos" in ep_group + else {}, } for key in {"rewards", "terminations", "truncations"}: group_value = ep_group[key] diff --git a/tests/common.py b/tests/common.py index 9e8c81fe..977bb586 100644 --- a/tests/common.py +++ b/tests/common.py @@ -32,7 +32,13 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {"timestep" : self.timestep} + return ( + self.observation_space.sample(), + 0, + terminated, + False, + {"timestep": self.timestep}, + ) def reset(self, seed=None, options=None): self.timestep = 0 @@ -117,12 +123,28 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return ( + self.observation_space.sample(), + 0, + terminated, + False, + { + "timestep": self.timestep, + "component_1": {"next_timestep": self.timestep + 1}, + }, + ) def reset(self, seed=None, options=None): self.timestep = 0 +<<<<<<< HEAD self.observation_space.seed(seed) return self.observation_space.sample(), {} +======= + return self.observation_space.sample(), { + "timestep": self.timestep, + "component_1": {"next_timestep": self.timestep + 1}, + } +>>>>>>> 4367f79 (tentative draft of info support for EpisodeData) class DummyTupleEnv(gym.Env): @@ -150,12 +172,26 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + info = { + "info_1": np.ones((2, 2)), + "component_1": {"component_1_info_1": np.ones((2,))}, + } + + return self.observation_space.sample(), 0, terminated, False, info def reset(self, seed=None, options=None): self.timestep = 0 +<<<<<<< HEAD self.observation_space.seed(seed) return self.observation_space.sample(), {} +======= + + info = { + "info_1": np.ones((2, 2)), + "component_1": {"component_1_info_1": np.ones((2,))}, + } + return self.observation_space.sample(), info +>>>>>>> 4367f79 (tentative draft of info support for EpisodeData) class DummyTextEnv(gym.Env): @@ -512,6 +548,45 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert total_steps == data.total_steps +def assert_infos_same_shape(info_1, info_2): + for key in info_1.keys(): + if isinstance(info_1[key], dict): + if not assert_infos_same_shape(info_1[key], info_2[key]): + return False + elif isinstance(info_1[key], np.ndarray): + if not (info_1[key].shape == info_2[key].shape) and ( + info_1[key].dtype == info_2[key].dtype + ): + return False + elif np.issubdtype(type(info_1[key]), np.integer) and np.issubdtype( + type(info_2[key]), np.integer + ): + pass + elif np.issubdtype(type(info_1[key]), np.float) and np.issubdtype( + type(info_2[key]), np.float + ): + pass + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return True + + +def _get_step_from_infos_dict(infos, step_index): + result = {} + for key in infos.keys(): + if isinstance(infos[key], dict): + result[key] = _get_step_from_infos_dict(infos[key], step_index) + elif isinstance(infos[key], np.ndarray): + result[key] = infos[key][step_index] + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return result + + def _reconstuct_obs_or_action_at_index_recursive( data: Union[dict, tuple, np.ndarray], index: int ) -> Union[np.ndarray, dict, tuple]: @@ -604,9 +679,10 @@ def create_dummy_dataset_with_collecter_env_helper( def check_episode_data_integrity( - episode_data_list: Union[List[EpisodeData],MinariDataset], + episode_data_list: Union[List[EpisodeData], MinariDataset], observation_space: gym.spaces.Space, action_space: gym.spaces.Space, + info_sample: Optional[dict] = None, ): """Checks to see if a list of EpisodeData instances has consistent data and that the observations and actions are in the appropriate spaces. @@ -614,7 +690,7 @@ def check_episode_data_integrity( episode_data_list (List[EpisodeData]): A list of EpisodeData instances representing episodes. observation_space(gym.spaces.Space): The environment's observation space. action_space(gym.spaces.Space): The environment's action space. - + info_sample(dict): An info returned by the environment used to build the dataset. """ # verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct for episode in episode_data_list: @@ -627,12 +703,12 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) - print(episode) - print(episode.observations) - print(obs) - print(observation_space) + + assert assert_infos_same_shape( + _get_step_from_infos_dict(episode.infos, i), info_sample + ) assert observation_space.contains(obs) - + for i in range(episode.total_timesteps): action = _reconstuct_obs_or_action_at_index_recursive(episode.actions, i) assert action_space.contains(action) diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index e183842f..3141a5bb 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -1,4 +1,5 @@ import gymnasium as gym +import h5py import numpy as np import pytest @@ -29,6 +30,20 @@ def __call__(self, env, **kwargs): return step_data +def _get_step_from_infos(infos, step_index): + result = {} + for key in infos.keys(): + if isinstance(infos[key], h5py.Group): + result[key] = _get_step_from_infos(infos[key]) + elif isinstance(infos[key], h5py.Dataset): + result[key] = infos[key][step_index] + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return result + + def _get_step_from_dictionary_space(episode_data, step_index): step_data = {} assert isinstance(episode_data, dict) @@ -71,6 +86,8 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat else: action = episode.actions[index] + infos = _get_step_from_infos(episode.infos, index) + step_data = { "id": episode.id, "total_timesteps": 1, @@ -80,6 +97,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat "rewards": episode.rewards[index], "terminations": episode.terminations[index], "truncations": episode.truncations[index], + "infos": infos, } return EpisodeData(**step_data) diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 39449678..be9363c0 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -44,6 +44,7 @@ def test_episode_data(space: gym.Space): rewards=rewards, terminations=terminations, truncations=truncations, + infos={}, ) pattern = r"EpisodeData\(" @@ -55,6 +56,7 @@ def test_episode_data(space: gym.Space): pattern += r"rewards=.+, " pattern += r"terminations=.+, " pattern += r"truncations=.+" + pattern += r"infos=.+" pattern += r"\)" assert re.fullmatch(pattern, repr(episode_data)) diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 31b6a76b..9d930d1d 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -11,10 +11,10 @@ check_data_integrity, check_env_recovery, check_env_recovery_with_subset_spaces, + check_episode_data_integrity, check_load_and_delete_dataset, get_sample_buffer_for_dataset_from_env, register_dummy_envs, - check_episode_data_integrity ) @@ -84,7 +84,9 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset._data, dataset.episode_indices) - check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) + check_episode_data_integrity( + dataset, dataset.spec.observation_space, dataset.spec.action_space + ) # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset, eval_env) @@ -94,10 +96,13 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): # check load and delete local dataset check_load_and_delete_dataset(dataset_id) + @pytest.mark.parametrize( "dataset_id,env_id", [ + ("dummy-dict-test-v0", "DummyDictEnv-v0"), ("dummy-box-test-v0", "DummyBoxEnv-v0"), + ("dummy-tuple-test-v0", "DummyTupleEnv-v0"), ], ) def test_generate_dataset_with_collector_env_infos(dataset_id, env_id): @@ -110,19 +115,18 @@ def test_generate_dataset_with_collector_env_infos(dataset_id, env_id): env = gym.make(env_id) - env = DataCollectorV0(env, record_infos = True) + env = DataCollectorV0(env, record_infos=True) num_episodes = 10 # Step the environment, DataCollectorV0 wrapper will do the data collection job - env.reset(seed=42) + _, info_sample = env.reset(seed=42) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: action = env.action_space.sample() # User-defined policy function - _, _, terminated, truncated, info = env.step(action) - print(info) + _, _, terminated, truncated, _ = env.step(action) if terminated or truncated: assert not env._buffer[-1] else: @@ -146,16 +150,18 @@ def test_generate_dataset_with_collector_env_infos(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes check_data_integrity(dataset._data, dataset.episode_indices) - check_episode_data_integrity(dataset, dataset.spec.observation_space, dataset.spec.action_space) - + check_episode_data_integrity( + dataset, + dataset.spec.observation_space, + dataset.spec.action_space, + info_sample=info_sample, + ) # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset) - print("episodedata") - print(dataset[0]) env.close() - # check load and delete local dataset + check_load_and_delete_dataset(dataset_id) From 08d708aac7cf2fde5d6d64e0561447e9b156da17 Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Fri, 25 Aug 2023 00:09:03 -0500 Subject: [PATCH 03/29] typing fix --- tests/data_collector/test_data_collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index 3141a5bb..a0089e14 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -30,7 +30,7 @@ def __call__(self, env, **kwargs): return step_data -def _get_step_from_infos(infos, step_index): +def _get_step_from_infos(infos, step_index: int): result = {} for key in infos.keys(): if isinstance(infos[key], h5py.Group): From 5dce7f4b92a2126e915f2f591aaddd589b3101ae Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Fri, 25 Aug 2023 00:15:15 -0500 Subject: [PATCH 04/29] typing fixed, removed print --- minari/data_collector/data_collector.py | 3 --- tests/data_collector/test_data_collector.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index a4712533..80f893f4 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -154,9 +154,6 @@ def _add_to_episode_buffer( if (not self._record_infos and key == "infos") or (value is None): continue - if key == "infos": - print("infos time") - if key not in episode_buffer: if isinstance(value, dict): episode_buffer[key] = self._add_to_episode_buffer({}, value) diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index a0089e14..e0c94dee 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -34,7 +34,7 @@ def _get_step_from_infos(infos, step_index: int): result = {} for key in infos.keys(): if isinstance(infos[key], h5py.Group): - result[key] = _get_step_from_infos(infos[key]) + result[key] = _get_step_from_infos(infos[key], step_index) elif isinstance(infos[key], h5py.Dataset): result[key] = infos[key][step_index] else: From e93cfad53111db953ce18b0bf394594d0b1b52ac Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Mon, 28 Aug 2023 02:00:23 -0500 Subject: [PATCH 05/29] added some information to dataset stabdards started work on test for varying infos --- docs/content/dataset_standards.md | 10 ++++ tests/common.py | 14 ++--- .../callbacks/test_step_data_callback.py | 56 +++++++++++++++++++ 3 files changed, 70 insertions(+), 10 deletions(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index 280f5d80..e04cf594 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -554,5 +554,15 @@ The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, eac | `rewards` | `np.ndarray` | Rewards for each timestep. | | `terminations` | `np.ndarray` | Terminations for each timestep. | | `truncations` | `np.ndarray` | Truncations for each timestep. | +| `infos` | `dict` | A dictionary containing additional information. | As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported so the data type for these fields are dependent on the environment being used. + +## Additional Information Formatting + + + + +When creating a dataset with `DataCollectorV0` the additional information stored in the `infos` group of the hdf5 file must be provided to Minari as a dict, which can only contain other dictionaries or `np.ndarray` as values. An info dict must be provided with every observation(including the one from the initial reset), and the shape of each `np.ndarray` must stay the same across timesteps. + +Since it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide a wrapper to allow for creating infos for environments which do not comply with the info format by default. diff --git a/tests/common.py b/tests/common.py index 977bb586..5667b44a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -86,12 +86,13 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {} + return self.observation_space.sample(), 0, terminated, False, {"timestep": self.timestep} if self.timestep %2 == 0 else {} + def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {} + return self.observation_space.sample(), {"timestep": self.timestep} if self.timestep %2 == 0 else {} class DummyDictEnv(gym.Env): @@ -136,15 +137,12 @@ def step(self, action): def reset(self, seed=None, options=None): self.timestep = 0 -<<<<<<< HEAD self.observation_space.seed(seed) - return self.observation_space.sample(), {} -======= + return self.observation_space.sample(), { "timestep": self.timestep, "component_1": {"next_timestep": self.timestep + 1}, } ->>>>>>> 4367f79 (tentative draft of info support for EpisodeData) class DummyTupleEnv(gym.Env): @@ -181,17 +179,13 @@ def step(self, action): def reset(self, seed=None, options=None): self.timestep = 0 -<<<<<<< HEAD self.observation_space.seed(seed) - return self.observation_space.sample(), {} -======= info = { "info_1": np.ones((2, 2)), "component_1": {"component_1_info_1": np.ones((2,))}, } return self.observation_space.sample(), info ->>>>>>> 4367f79 (tentative draft of info support for EpisodeData) class DummyTextEnv(gym.Env): diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index c16df008..d16b0ecf 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -10,6 +10,7 @@ check_env_recovery_with_subset_spaces, check_load_and_delete_dataset, register_dummy_envs, + check_env_recovery, ) @@ -110,3 +111,58 @@ def test_data_collector_step_data_callback(): env.close() # check load and delete local dataset check_load_and_delete_dataset(dataset_id) + + +def test_data_collector_step_data_callback_info_correction(): + """Test DataCollectorV0 wrapper and Minari dataset creation.""" + dataset_id = "dummy-tuple-discrete-box-v0" + # delete the test dataset if it already exists + local_datasets = minari.list_local_datasets() + if dataset_id in local_datasets: + minari.delete_dataset(dataset_id) + + env = gym.make("DummyTupleDiscreteBoxEnv-v0") + + env = DataCollectorV0( + env, + record_infos = True, + ) + num_episodes = 10 + + # Step the environment, DataCollectorV0 wrapper will do the data collection job + env.reset(seed=42) + + for episode in range(num_episodes): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() # User-defined policy function + _, _, terminated, truncated, _ = env.step(action) + + env.reset() + + # Create Minari dataset and store locally + dataset = minari.create_dataset_from_collector_env( + dataset_id=dataset_id, + collector_env=env, + algorithm_name="random_policy", + code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + author="WillDudley", + author_email="wdudley@farama.org", + ) + + assert isinstance(dataset, MinariDataset) + assert dataset.total_episodes == num_episodes + assert dataset.spec.total_episodes == num_episodes + assert len(dataset.episode_indices) == num_episodes + + check_data_integrity(dataset._data, dataset.episode_indices) + + # check that the environment can be recovered from the dataset + check_env_recovery( + env.env, dataset + ) + + env.close() + # check load and delete local dataset + #check_load_and_delete_dataset(dataset_id) From 3e05c49db8c5b90d18d3662f83e528f02b828ee2 Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Wed, 30 Aug 2023 04:32:41 -0500 Subject: [PATCH 06/29] added tests for error in response to infos with timestep variant structure, and test of using StepDataCallback to fix it --- docs/content/dataset_standards.md | 2 +- minari/data_collector/data_collector.py | 50 +++++++++++++++++++ tests/common.py | 24 ++++++--- .../callbacks/test_step_data_callback.py | 45 ++++++++++++++--- 4 files changed, 107 insertions(+), 14 deletions(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index e04cf594..b382a747 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -565,4 +565,4 @@ As mentioned in the `Supported Spaces` section, many different observation and a When creating a dataset with `DataCollectorV0` the additional information stored in the `infos` group of the hdf5 file must be provided to Minari as a dict, which can only contain other dictionaries or `np.ndarray` as values. An info dict must be provided with every observation(including the one from the initial reset), and the shape of each `np.ndarray` must stay the same across timesteps. -Since it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide a wrapper to allow for creating infos for environments which do not comply with the info format by default. +Since it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the info's from a non-compliant environment so they are the same shape at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 80f893f4..d3ce8900 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -128,6 +128,9 @@ def __init__( ) self._record_infos = record_infos + if self._record_infos: + self._reference_info = None # initialized to None, determined + # from the info returned by the first call to `self.env.reset()` self.max_buffer_steps = max_buffer_steps # Initialzie empty buffer @@ -191,6 +194,16 @@ def step( terminated=terminated, truncated=truncated, ) + + if self._record_infos and not self.check_infos_same_shape( + self._reference_info, step_data["infos"] + ): + raise ValueError( + "Info structure inconsistent with info structure returned by original reset." + ) + + # Force step data dictionary to include keys corresponding to Gymnasium step returns: + # actions, observations, rewards, terminations, truncations, and infos assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'." @@ -253,6 +266,17 @@ def reset( step_data = self._step_data_callback(env=self.env, obs=obs, info=info) self._episode_id += 1 + if self._record_infos: + if self._reference_info is None: + self._reference_info = step_data["infos"] + else: + if not self.check_infos_same_shape( + self._reference_info, step_data["infos"] + ): + raise ValueError( + "Info structure inconsistent with info structure returned by original reset." + ) + assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'" @@ -409,6 +433,32 @@ def save_to_disk( env_spec=self.env.spec, ) + def check_infos_same_shape(self, info_1, info_2): + if len(info_1.keys()) != len(info_2.keys()): + return False + for key in info_1.keys(): + if isinstance(info_1[key], dict): + if not self.check_infos_same_shape(info_1[key], info_2[key]): + return False + elif isinstance(info_1[key], np.ndarray): + if not (info_1[key].shape == info_2[key].shape) and ( + info_1[key].dtype == info_2[key].dtype + ): + return False + elif np.issubdtype(type(info_1[key]), np.integer) and np.issubdtype( + type(info_2[key]), np.integer + ): + pass + elif np.issubdtype(type(info_1[key]), np.float) and np.issubdtype( + type(info_2[key]), np.float + ): + pass + else: + raise ValueError( + "Infos are in an unsupported format; see Minari documentation for supported formats." + ) + return True + def close(self): """Close the DataCollector. diff --git a/tests/common.py b/tests/common.py index 5667b44a..e55bfdb3 100644 --- a/tests/common.py +++ b/tests/common.py @@ -86,13 +86,21 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, {"timestep": self.timestep} if self.timestep %2 == 0 else {} - + return ( + self.observation_space.sample(), + 0, + terminated, + False, + {"timestep": self.timestep} if self.timestep % 2 == 0 else {}, + ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {"timestep": self.timestep} if self.timestep %2 == 0 else {} + return ( + self.observation_space.sample(), + {"timestep": self.timestep} if self.timestep % 2 == 0 else {}, + ) class DummyDictEnv(gym.Env): @@ -543,6 +551,8 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): def assert_infos_same_shape(info_1, info_2): + if len(info_1.keys()) != len(info_2.keys()): + return False for key in info_1.keys(): if isinstance(info_1[key], dict): if not assert_infos_same_shape(info_1[key], info_2[key]): @@ -697,10 +707,10 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) - - assert assert_infos_same_shape( - _get_step_from_infos_dict(episode.infos, i), info_sample - ) + if info_sample is not None: + assert assert_infos_same_shape( + _get_step_from_infos_dict(episode.infos, i), info_sample + ) assert observation_space.contains(obs) for i in range(episode.total_timesteps): diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index d16b0ecf..d955e410 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -1,5 +1,6 @@ import gymnasium as gym import numpy as np +import pytest from gymnasium import spaces import minari @@ -7,10 +8,10 @@ from minari.data_collector.callbacks import StepDataCallback from tests.common import ( check_data_integrity, + check_env_recovery, check_env_recovery_with_subset_spaces, check_load_and_delete_dataset, register_dummy_envs, - check_env_recovery, ) @@ -38,6 +39,14 @@ def __call__(self, env, **kwargs): return step_data +class CustomSubsetInfoPadStepDataCallback(StepDataCallback): + def __call__(self, env, **kwargs): + step_data = super().__call__(env, **kwargs) + if step_data["infos"] == {}: + step_data["infos"] = {"timestep": -1} + return step_data + + def test_data_collector_step_data_callback(): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-dict-test-v0" @@ -125,7 +134,8 @@ def test_data_collector_step_data_callback_info_correction(): env = DataCollectorV0( env, - record_infos = True, + record_infos=True, + step_data_callback=CustomSubsetInfoPadStepDataCallback, ) num_episodes = 10 @@ -159,10 +169,33 @@ def test_data_collector_step_data_callback_info_correction(): check_data_integrity(dataset._data, dataset.episode_indices) # check that the environment can be recovered from the dataset - check_env_recovery( - env.env, dataset - ) + check_env_recovery(env.env, dataset) env.close() # check load and delete local dataset - #check_load_and_delete_dataset(dataset_id) + check_load_and_delete_dataset(dataset_id) + + env = gym.make("DummyTupleDiscreteBoxEnv-v0") + + env = DataCollectorV0( + env, + record_infos=True, + ) + # here we are checking to make sure that if we have an environment changing its info + # structure across timesteps, it is caught by the data_collector + with pytest.raises(ValueError): + + num_episodes = 10 + + # Step the environment, DataCollectorV0 wrapper will do the data collection job + env.reset(seed=42) + + for episode in range(num_episodes): + terminated = False + truncated = False + while not terminated and not truncated: + action = env.action_space.sample() # User-defined policy function + _, _, terminated, truncated, _ = env.step(action) + + env.reset() + env.close() From 214052d54d8dbf770af1a6b4d0c50ed60c39a55a Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Sun, 3 Sep 2023 01:41:30 -0500 Subject: [PATCH 07/29] added explicit np.array dtype support, documentation, and tests --- docs/content/dataset_standards.md | 5 +- minari/data_collector/data_collector.py | 11 +-- tests/common.py | 52 ++++++++---- .../callbacks/test_step_data_callback.py | 2 +- tests/utils/test_dataset_creation.py | 79 +++++++++++++++++-- 5 files changed, 115 insertions(+), 34 deletions(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index b382a747..4400bc5a 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -562,7 +562,8 @@ As mentioned in the `Supported Spaces` section, many different observation and a +When creating a dataset with `DataCollectorV0`, the additional information stored in the `infos` group of the hdf5 file must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided for every observation(including the one from the initial reset), and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. -When creating a dataset with `DataCollectorV0` the additional information stored in the `infos` group of the hdf5 file must be provided to Minari as a dict, which can only contain other dictionaries or `np.ndarray` as values. An info dict must be provided with every observation(including the one from the initial reset), and the shape of each `np.ndarray` must stay the same across timesteps. +Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. -Since it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the info's from a non-compliant environment so they are the same shape at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. +For `np.ndarray` typed arrays including in an info, we support the list of data data types supported by [h5py](https://docs.h5py.org/en/stable/faq.html#faq). We provide tests to guarantee support for the following numpy data types: `np.int8`,`np.int16`,`np.int32`,`np.int64`, `np.uint8`,`np.uint16`,`np.uint32`,`np.iunt64`,`np.float16`,`np.float32`,`np.float64`. In addition, the info values can contain `int`, or `float` types, and these will be promoted to `np.int64` and `np.float64` respectively. \ No newline at end of file diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index d3ce8900..ca9fafbc 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -433,6 +433,7 @@ def save_to_disk( env_spec=self.env.spec, ) + # This function is designed the same way as `assert_infos_same_shape` in tests/common.py, but is a class function so has a `self` argument. def check_infos_same_shape(self, info_1, info_2): if len(info_1.keys()) != len(info_2.keys()): return False @@ -445,17 +446,9 @@ def check_infos_same_shape(self, info_1, info_2): info_1[key].dtype == info_2[key].dtype ): return False - elif np.issubdtype(type(info_1[key]), np.integer) and np.issubdtype( - type(info_2[key]), np.integer - ): - pass - elif np.issubdtype(type(info_1[key]), np.float) and np.issubdtype( - type(info_2[key]), np.float - ): - pass else: raise ValueError( - "Infos are in an unsupported format; see Minari documentation for supported formats." + "Infos are in an unsupported format; see [Minari documentation](http://minari.farama.org/content/dataset_standards/) for supported formats." ) return True diff --git a/tests/common.py b/tests/common.py index e55bfdb3..1ff7416d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -37,13 +37,33 @@ def step(self, action): 0, terminated, False, - {"timestep": self.timestep}, + {"timestep": np.array([self.timestep])}, ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {"timestep" : self.timestep} + return self.observation_space.sample(), {"timestep": np.array([self.timestep])} + + +# this returns whatever is set to `self.info` as the info, making it easy to create parameterized tests with a lot of different info datatypes. +class DummyMutableInfoBoxEnv(gym.Env): + def __init__(self): + self.action_space = spaces.Box(low=-1, high=4, shape=(2,), dtype=np.float32) + self.observation_space = spaces.Box( + low=-1, high=4, shape=(3,), dtype=np.float32 + ) + self.info = {} + + def step(self, action): + terminated = self.timestep > 5 + self.timestep += 1 + + return (self.observation_space.sample(), 0, terminated, False, self.info) + + def reset(self, seed=None, options=None): + self.timestep = 0 + return self.observation_space.sample(), self.info class DummyMultiDimensionalBoxEnv(gym.Env): @@ -91,7 +111,7 @@ def step(self, action): 0, terminated, False, - {"timestep": self.timestep} if self.timestep % 2 == 0 else {}, + {"timestep": np.array([self.timestep])} if self.timestep % 2 == 0 else {}, ) def reset(self, seed=None, options=None): @@ -99,7 +119,7 @@ def reset(self, seed=None, options=None): self.observation_space.seed(seed) return ( self.observation_space.sample(), - {"timestep": self.timestep} if self.timestep % 2 == 0 else {}, + {"timestep": np.array([self.timestep])} if self.timestep % 2 == 0 else {}, ) @@ -138,8 +158,8 @@ def step(self, action): terminated, False, { - "timestep": self.timestep, - "component_1": {"next_timestep": self.timestep + 1}, + "timestep": np.array([self.timestep]), + "component_1": {"next_timestep": np.array([self.timestep + 1])}, }, ) @@ -148,8 +168,8 @@ def reset(self, seed=None, options=None): self.observation_space.seed(seed) return self.observation_space.sample(), { - "timestep": self.timestep, - "component_1": {"next_timestep": self.timestep + 1}, + "timestep": np.array([self.timestep]), + "component_1": {"next_timestep": np.array([self.timestep + 1])}, } @@ -275,6 +295,12 @@ def register_dummy_envs(): max_episode_steps=5, ) + register( + id="DummyMutableInfoBoxEnv-v0", + entry_point="tests.common:DummyMutableInfoBoxEnv", + max_episode_steps=5, + ) + register( id="DummyMultiDimensionalBoxEnv-v0", entry_point="tests.common:DummyMultiDimensionalBoxEnv", @@ -562,17 +588,9 @@ def assert_infos_same_shape(info_1, info_2): info_1[key].dtype == info_2[key].dtype ): return False - elif np.issubdtype(type(info_1[key]), np.integer) and np.issubdtype( - type(info_2[key]), np.integer - ): - pass - elif np.issubdtype(type(info_1[key]), np.float) and np.issubdtype( - type(info_2[key]), np.float - ): - pass else: raise ValueError( - "Infos are in an unsupported format; see Minari documentation for supported formats." + "Infos are in an unsupported format; see [Minari documentation](http://minari.farama.org/content/dataset_standards/) for supported formats." ) return True diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index d955e410..26aa3269 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -43,7 +43,7 @@ class CustomSubsetInfoPadStepDataCallback(StepDataCallback): def __call__(self, env, **kwargs): step_data = super().__call__(env, **kwargs) if step_data["infos"] == {}: - step_data["infos"] = {"timestep": -1} + step_data["infos"] = {"timestep": np.array([-1])} return step_data diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 9d930d1d..8975d512 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -98,14 +98,79 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): @pytest.mark.parametrize( - "dataset_id,env_id", + "dataset_id,env_id,info_override", [ - ("dummy-dict-test-v0", "DummyDictEnv-v0"), - ("dummy-box-test-v0", "DummyBoxEnv-v0"), - ("dummy-tuple-test-v0", "DummyTupleEnv-v0"), + ("dummy-dict-test-v0", "DummyDictEnv-v0", None), + ("dummy-box-test-v0", "DummyBoxEnv-v0", None), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.int64)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.int32)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.int16)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.int8)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.uint64)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.uint32)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.uint16)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.uint8)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.float64)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.float32)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.ones((5, 5), np.float16)}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.array([1])}, + ), + ( + "dummy-mutable-info-box-test-v0", + "DummyMutableInfoBoxEnv-v0", + {"misc": np.array([1])}, + ), + ("dummy-tuple-test-v0", "DummyTupleEnv-v0", None), ], ) -def test_generate_dataset_with_collector_env_infos(dataset_id, env_id): +def test_generate_dataset_with_collector_env_infos(dataset_id, env_id, info_override): """Test DataCollectorV0 wrapper and Minari dataset creation.""" # dataset_id = "cartpole-test-v0" # delete the test dataset if it already exists @@ -115,11 +180,15 @@ def test_generate_dataset_with_collector_env_infos(dataset_id, env_id): env = gym.make(env_id) + if env_id == "DummyMutableInfoBoxEnv-v0": + env.unwrapped.info = info_override + env = DataCollectorV0(env, record_infos=True) num_episodes = 10 # Step the environment, DataCollectorV0 wrapper will do the data collection job _, info_sample = env.reset(seed=42) + print(info_sample) for episode in range(num_episodes): terminated = False From 2693fd1e8bf8b4dca817e560c33c5cec45d5da9e Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Sun, 3 Sep 2023 20:29:46 -0500 Subject: [PATCH 08/29] updated doc page --- docs/content/dataset_standards.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index 4400bc5a..81ea430c 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -562,8 +562,17 @@ As mentioned in the `Supported Spaces` section, many different observation and a -When creating a dataset with `DataCollectorV0`, the additional information stored in the `infos` group of the hdf5 file must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided for every observation(including the one from the initial reset), and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. +When creating a dataset with `DataCollectorV0`, if the `DataCollectorV0` is initialized with `record_infos=True`, the additional information stored as step infos must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided from every call to the wrapped evironment's `step` and `reset` function, and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. + +Here is an example of what a valid `info` might look like: + + +```python +info = {'value_1':np.array([1]), 'value_2': {"sub_value_1":np.asarray([[2.3],[4.5]])}} +``` + +Note that this shows how `infos` can be structured hierarchically, and that the nesting of dicts can go to arbitrary depth. Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. -For `np.ndarray` typed arrays including in an info, we support the list of data data types supported by [h5py](https://docs.h5py.org/en/stable/faq.html#faq). We provide tests to guarantee support for the following numpy data types: `np.int8`,`np.int16`,`np.int32`,`np.int64`, `np.uint8`,`np.uint16`,`np.uint32`,`np.iunt64`,`np.float16`,`np.float32`,`np.float64`. In addition, the info values can contain `int`, or `float` types, and these will be promoted to `np.int64` and `np.float64` respectively. \ No newline at end of file +We provide tests to guarantee support for the following `numpy.ndarray` data types: `np.int8`,`np.int16`,`np.int32`,`np.int64`, `np.uint8`,`np.uint16`,`np.uint32`,`np.iunt64`,`np.float16`,`np.float32`,`np.float64`. \ No newline at end of file From 961c534162a7d00bd8fae5879bbca6eb55eb1183 Mon Sep 17 00:00:00 2001 From: "John U. Balis" Date: Sun, 3 Sep 2023 20:31:11 -0500 Subject: [PATCH 09/29] table syntax change --- docs/content/dataset_standards.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index 81ea430c..32e676f9 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -554,7 +554,7 @@ The `sampled_episodes` variable will be a list of 10 `EpisodeData` elements, eac | `rewards` | `np.ndarray` | Rewards for each timestep. | | `terminations` | `np.ndarray` | Terminations for each timestep. | | `truncations` | `np.ndarray` | Truncations for each timestep. | -| `infos` | `dict` | A dictionary containing additional information. | +| `infos` | `dict` | A dictionary containing additional information. | As mentioned in the `Supported Spaces` section, many different observation and action spaces are supported so the data type for these fields are dependent on the environment being used. From a23d434002e7a62851f9e28be9150a1dbc251c4d Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Tue, 28 Nov 2023 12:49:59 -0500 Subject: [PATCH 10/29] remove print --- tests/utils/test_dataset_creation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 8975d512..6afcd315 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -188,7 +188,6 @@ def test_generate_dataset_with_collector_env_infos(dataset_id, env_id, info_over # Step the environment, DataCollectorV0 wrapper will do the data collection job _, info_sample = env.reset(seed=42) - print(info_sample) for episode in range(num_episodes): terminated = False From 8dfe4989a0abd028732ba5e120252b013aa16778 Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Tue, 28 Nov 2023 18:46:14 -0500 Subject: [PATCH 11/29] DataCollectorV0 -> DataCollector --- .../callbacks/test_step_data_callback.py | 10 +++++----- tests/utils/test_dataset_creation.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 26aa3269..0b03fba3 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -123,7 +123,7 @@ def test_data_collector_step_data_callback(): def test_data_collector_step_data_callback_info_correction(): - """Test DataCollectorV0 wrapper and Minari dataset creation.""" + """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-tuple-discrete-box-v0" # delete the test dataset if it already exists local_datasets = minari.list_local_datasets() @@ -132,14 +132,14 @@ def test_data_collector_step_data_callback_info_correction(): env = gym.make("DummyTupleDiscreteBoxEnv-v0") - env = DataCollectorV0( + env = DataCollector( env, record_infos=True, step_data_callback=CustomSubsetInfoPadStepDataCallback, ) num_episodes = 10 - # Step the environment, DataCollectorV0 wrapper will do the data collection job + # Step the environment, DataCollector wrapper will do the data collection job env.reset(seed=42) for episode in range(num_episodes): @@ -177,7 +177,7 @@ def test_data_collector_step_data_callback_info_correction(): env = gym.make("DummyTupleDiscreteBoxEnv-v0") - env = DataCollectorV0( + env = DataCollector( env, record_infos=True, ) @@ -187,7 +187,7 @@ def test_data_collector_step_data_callback_info_correction(): num_episodes = 10 - # Step the environment, DataCollectorV0 wrapper will do the data collection job + # Step the environment, DataCollector wrapper will do the data collection job env.reset(seed=42) for episode in range(num_episodes): diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 6afcd315..a7f689bb 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -171,7 +171,7 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): ], ) def test_generate_dataset_with_collector_env_infos(dataset_id, env_id, info_override): - """Test DataCollectorV0 wrapper and Minari dataset creation.""" + """Test DataCollector wrapper and Minari dataset creation.""" # dataset_id = "cartpole-test-v0" # delete the test dataset if it already exists local_datasets = minari.list_local_datasets() @@ -183,10 +183,10 @@ def test_generate_dataset_with_collector_env_infos(dataset_id, env_id, info_over if env_id == "DummyMutableInfoBoxEnv-v0": env.unwrapped.info = info_override - env = DataCollectorV0(env, record_infos=True) + env = DataCollector(env, record_infos=True) num_episodes = 10 - # Step the environment, DataCollectorV0 wrapper will do the data collection job + # Step the environment, DataCollector wrapper will do the data collection job _, info_sample = env.reset(seed=42) for episode in range(num_episodes): From 02e8dd257d525072460ca067c9a5a58f765b1079 Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Wed, 29 Nov 2023 12:29:54 -0500 Subject: [PATCH 12/29] rename test --- docs/content/dataset_standards.md | 6 +++--- tests/utils/test_dataset_creation.py | 8 ++------ 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index 32e676f9..c340efa2 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -562,7 +562,7 @@ As mentioned in the `Supported Spaces` section, many different observation and a -When creating a dataset with `DataCollectorV0`, if the `DataCollectorV0` is initialized with `record_infos=True`, the additional information stored as step infos must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided from every call to the wrapped evironment's `step` and `reset` function, and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. +When creating a dataset with `DataCollectorV0`, if the `DataCollectorV0` is initialized with `record_infos=True`, the additional information stored as step infos must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided from every call to the wrapped evironment's `step` and `reset` function, and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. Here is an example of what a valid `info` might look like: @@ -571,8 +571,8 @@ Here is an example of what a valid `info` might look like: info = {'value_1':np.array([1]), 'value_2': {"sub_value_1":np.asarray([[2.3],[4.5]])}} ``` -Note that this shows how `infos` can be structured hierarchically, and that the nesting of dicts can go to arbitrary depth. +Note that this shows how `infos` can be structured hierarchically, and that the nesting of dicts can go to arbitrary depth. Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. -We provide tests to guarantee support for the following `numpy.ndarray` data types: `np.int8`,`np.int16`,`np.int32`,`np.int64`, `np.uint8`,`np.uint16`,`np.uint32`,`np.iunt64`,`np.float16`,`np.float32`,`np.float64`. \ No newline at end of file +We provide tests to guarantee support for the following `numpy.ndarray` data types: `np.int8`,`np.int16`,`np.int32`,`np.int64`, `np.uint8`,`np.uint16`,`np.uint32`,`np.iunt64`,`np.float16`,`np.float32`,`np.float64`. diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index a7f689bb..70ca35c5 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -170,8 +170,8 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): ("dummy-tuple-test-v0", "DummyTupleEnv-v0", None), ], ) -def test_generate_dataset_with_collector_env_infos(dataset_id, env_id, info_override): - """Test DataCollector wrapper and Minari dataset creation.""" +def test_record_infos_collector_env(dataset_id, env_id, info_override): + """Test DataCollector wrapper and Minari dataset creation including infos.""" # dataset_id = "cartpole-test-v0" # delete the test dataset if it already exists local_datasets = minari.list_local_datasets() @@ -195,10 +195,6 @@ def test_generate_dataset_with_collector_env_infos(dataset_id, env_id, info_over while not terminated and not truncated: action = env.action_space.sample() # User-defined policy function _, _, terminated, truncated, _ = env.step(action) - if terminated or truncated: - assert not env._buffer[-1] - else: - assert env._buffer[-1] env.reset() From e84433bb3bd4dc126bf8e56f606f938f62699994 Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Wed, 29 Nov 2023 21:35:46 -0500 Subject: [PATCH 13/29] move info shape check to add to buffer --- minari/data_collector/data_collector.py | 27 +++++++++---------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index ca9fafbc..0e3e2f92 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -153,6 +153,14 @@ def _add_to_episode_buffer( Returns: Dict: new dictionary episode buffer with added values from step_data """ + + if self._record_infos and not self.check_infos_same_shape( + self._reference_info, step_data["infos"] + ): + raise ValueError( + "Info structure inconsistent with info structure returned by original reset." + ) + for key, value in step_data.items(): if (not self._record_infos and key == "infos") or (value is None): continue @@ -195,13 +203,6 @@ def step( truncated=truncated, ) - if self._record_infos and not self.check_infos_same_shape( - self._reference_info, step_data["infos"] - ): - raise ValueError( - "Info structure inconsistent with info structure returned by original reset." - ) - # Force step data dictionary to include keys corresponding to Gymnasium step returns: # actions, observations, rewards, terminations, truncations, and infos assert STEP_DATA_KEYS.issubset( @@ -266,16 +267,8 @@ def reset( step_data = self._step_data_callback(env=self.env, obs=obs, info=info) self._episode_id += 1 - if self._record_infos: - if self._reference_info is None: - self._reference_info = step_data["infos"] - else: - if not self.check_infos_same_shape( - self._reference_info, step_data["infos"] - ): - raise ValueError( - "Info structure inconsistent with info structure returned by original reset." - ) + if self._record_infos and self._reference_info is None: + self._reference_info = step_data["infos"] assert STEP_DATA_KEYS.issubset( step_data.keys() From fca36c19a034caece7981652a867f78f2f1a1095 Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Wed, 29 Nov 2023 22:08:21 -0500 Subject: [PATCH 14/29] add _get_info in dummy test envs --- tests/common.py | 56 +++++++++++++++++++++++++++---------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/tests/common.py b/tests/common.py index 1ff7416d..b8413e91 100644 --- a/tests/common.py +++ b/tests/common.py @@ -28,6 +28,9 @@ def __init__(self): low=-1, high=4, shape=(3,), dtype=np.float32 ) + def _get_info(self): + return {"timestep": np.array([self.timestep])} + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 @@ -37,13 +40,13 @@ def step(self, action): 0, terminated, False, - {"timestep": np.array([self.timestep])}, + self._get_info(), ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), {"timestep": np.array([self.timestep])} + return self.observation_space.sample(), self._get_info() # this returns whatever is set to `self.info` as the info, making it easy to create parameterized tests with a lot of different info datatypes. @@ -55,15 +58,18 @@ def __init__(self): ) self.info = {} + def _get_info(self): + return self.info + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return (self.observation_space.sample(), 0, terminated, False, self.info) + return (self.observation_space.sample(), 0, terminated, False, self._get_info()) def reset(self, seed=None, options=None): self.timestep = 0 - return self.observation_space.sample(), self.info + return self.observation_space.sample(), self._get_info() class DummyMultiDimensionalBoxEnv(gym.Env): @@ -102,6 +108,9 @@ def __init__(self): ) ) + def _get_info(self): + return {"timestep": np.array([self.timestep])} if self.timestep % 2 == 0 else {} + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 @@ -111,7 +120,7 @@ def step(self, action): 0, terminated, False, - {"timestep": np.array([self.timestep])} if self.timestep % 2 == 0 else {}, + self._get_info(), ) def reset(self, seed=None, options=None): @@ -119,7 +128,7 @@ def reset(self, seed=None, options=None): self.observation_space.seed(seed) return ( self.observation_space.sample(), - {"timestep": np.array([self.timestep])} if self.timestep % 2 == 0 else {}, + self._get_info(), ) @@ -148,6 +157,12 @@ def __init__(self): } ) + def _get_info(self): + return { + "timestep": np.array([self.timestep]), + "component_1": {"next_timestep": np.array([self.timestep + 1])}, + } + def step(self, action): terminated = self.timestep > 5 self.timestep += 1 @@ -157,20 +172,14 @@ def step(self, action): 0, terminated, False, - { - "timestep": np.array([self.timestep]), - "component_1": {"next_timestep": np.array([self.timestep + 1])}, - }, + self._get_info(), ) def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - return self.observation_space.sample(), { - "timestep": np.array([self.timestep]), - "component_1": {"next_timestep": np.array([self.timestep + 1])}, - } + return self.observation_space.sample(), self._get_info() class DummyTupleEnv(gym.Env): @@ -194,26 +203,23 @@ def __init__(self): ) ) - def step(self, action): - terminated = self.timestep > 5 - self.timestep += 1 - - info = { + def _get_info(self): + return { "info_1": np.ones((2, 2)), "component_1": {"component_1_info_1": np.ones((2,))}, } - return self.observation_space.sample(), 0, terminated, False, info + def step(self, action): + terminated = self.timestep > 5 + self.timestep += 1 + + return self.observation_space.sample(), 0, terminated, False, self._get_info() def reset(self, seed=None, options=None): self.timestep = 0 self.observation_space.seed(seed) - info = { - "info_1": np.ones((2, 2)), - "component_1": {"component_1_info_1": np.ones((2,))}, - } - return self.observation_space.sample(), info + return self.observation_space.sample(), self._get_info() class DummyTextEnv(gym.Env): From 701b8e1ade2d55436e043a5cc32d71bf9bd19732 Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Wed, 29 Nov 2023 22:15:09 -0500 Subject: [PATCH 15/29] _get_info_at_step_index --- tests/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/common.py b/tests/common.py index b8413e91..e5a210a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -601,11 +601,11 @@ def assert_infos_same_shape(info_1, info_2): return True -def _get_step_from_infos_dict(infos, step_index): +def _get_info_at_step_index(infos, step_index): result = {} for key in infos.keys(): if isinstance(infos[key], dict): - result[key] = _get_step_from_infos_dict(infos[key], step_index) + result[key] = _get_info_at_step_index(infos[key], step_index) elif isinstance(infos[key], np.ndarray): result[key] = infos[key][step_index] else: @@ -733,7 +733,7 @@ def check_episode_data_integrity( obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) if info_sample is not None: assert assert_infos_same_shape( - _get_step_from_infos_dict(episode.infos, i), info_sample + _get_info_at_step_index(episode.infos, i), info_sample ) assert observation_space.contains(obs) From cf35cd7dbf9e964ef41d6bc411d2c673b316e59e Mon Sep 17 00:00:00 2001 From: rodrigodelazcano Date: Thu, 7 Dec 2023 11:10:17 -0500 Subject: [PATCH 16/29] fix pre-commit --- minari/data_collector/data_collector.py | 1 - tests/common.py | 12 +++++------ tests/data_collector/test_data_collector.py | 23 ++++++--------------- 3 files changed, 12 insertions(+), 24 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 0e3e2f92..f354ee41 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -153,7 +153,6 @@ def _add_to_episode_buffer( Returns: Dict: new dictionary episode buffer with added values from step_data """ - if self._record_infos and not self.check_infos_same_shape( self._reference_info, step_data["infos"] ): diff --git a/tests/common.py b/tests/common.py index e5a210a7..ac87946e 100644 --- a/tests/common.py +++ b/tests/common.py @@ -582,12 +582,12 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert total_steps == data.total_steps -def assert_infos_same_shape(info_1, info_2): +def assert_infos_same_structure(info_1, info_2): if len(info_1.keys()) != len(info_2.keys()): return False for key in info_1.keys(): if isinstance(info_1[key], dict): - if not assert_infos_same_shape(info_1[key], info_2[key]): + if not assert_infos_same_structure(info_1[key], info_2[key]): return False elif isinstance(info_1[key], np.ndarray): if not (info_1[key].shape == info_2[key].shape) and ( @@ -601,11 +601,11 @@ def assert_infos_same_shape(info_1, info_2): return True -def _get_info_at_step_index(infos, step_index): +def get_info_at_step_index(infos, step_index): result = {} for key in infos.keys(): if isinstance(infos[key], dict): - result[key] = _get_info_at_step_index(infos[key], step_index) + result[key] = get_info_at_step_index(infos[key], step_index) elif isinstance(infos[key], np.ndarray): result[key] = infos[key][step_index] else: @@ -732,8 +732,8 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) if info_sample is not None: - assert assert_infos_same_shape( - _get_info_at_step_index(episode.infos, i), info_sample + assert assert_infos_same_structure( + get_info_at_step_index(episode.infos, i), info_sample ) assert observation_space.contains(obs) diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index e0c94dee..03c67f5b 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -1,10 +1,13 @@ import gymnasium as gym -import h5py import numpy as np import pytest from minari import DataCollector, EpisodeData, MinariDataset, StepDataCallback -from tests.common import check_load_and_delete_dataset, register_dummy_envs +from tests.common import ( + check_load_and_delete_dataset, + get_info_at_step_index, + register_dummy_envs, +) MAX_UINT64 = np.iinfo(np.uint64).max @@ -30,20 +33,6 @@ def __call__(self, env, **kwargs): return step_data -def _get_step_from_infos(infos, step_index: int): - result = {} - for key in infos.keys(): - if isinstance(infos[key], h5py.Group): - result[key] = _get_step_from_infos(infos[key], step_index) - elif isinstance(infos[key], h5py.Dataset): - result[key] = infos[key][step_index] - else: - raise ValueError( - "Infos are in an unsupported format; see Minari documentation for supported formats." - ) - return result - - def _get_step_from_dictionary_space(episode_data, step_index): step_data = {} assert isinstance(episode_data, dict) @@ -86,7 +75,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat else: action = episode.actions[index] - infos = _get_step_from_infos(episode.infos, index) + infos = get_info_at_step_index(episode.infos, index) step_data = { "id": episode.id, From 4cd3d4da1bd8fbbc4aadf6b7e73970d2ac3a3193 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 20 Jan 2024 03:37:30 +0100 Subject: [PATCH 17/29] fix tests --- minari/data_collector/data_collector.py | 65 +++++++++++++------------ 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index f354ee41..8c4d252d 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -7,7 +7,7 @@ import shutil import tempfile import warnings -from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type, Union +from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type import gymnasium as gym import numpy as np @@ -128,9 +128,7 @@ def __init__( ) self._record_infos = record_infos - if self._record_infos: - self._reference_info = None # initialized to None, determined - # from the info returned by the first call to `self.env.reset()` + self._reference_info = None self.max_buffer_steps = max_buffer_steps # Initialzie empty buffer @@ -139,11 +137,11 @@ def __init__( self._step_id = -1 self._episode_id = -1 - def _add_to_episode_buffer( + def _add_step_data( self, episode_buffer: EpisodeBuffer, - step_data: Union[StepData, Dict[str, StepData]], - ) -> EpisodeBuffer: + step_data: StepData, + ): """Add step data dictionary to episode buffer. Args: @@ -153,38 +151,41 @@ def _add_to_episode_buffer( Returns: Dict: new dictionary episode buffer with added values from step_data """ - if self._record_infos and not self.check_infos_same_shape( - self._reference_info, step_data["infos"] + dict_data = dict(step_data) + if not self._record_infos: + dict_data = {k: v for k, v in step_data.items() if k != "infos"} + elif not self.check_infos_same_shape( + self._reference_info, dict_data["infos"] ): raise ValueError( "Info structure inconsistent with info structure returned by original reset." ) + self._add_to_episode_buffer(episode_buffer, dict_data) + + def _add_to_episode_buffer( + self, + episode_buffer: EpisodeBuffer, + step_data: Dict[str, Any], + ): for key, value in step_data.items(): - if (not self._record_infos and key == "infos") or (value is None): + if value is None: continue if key not in episode_buffer: - if isinstance(value, dict): - episode_buffer[key] = self._add_to_episode_buffer({}, value) - else: - episode_buffer[key] = [value] + episode_buffer[key] = {} if isinstance(value, dict) else [] + + if isinstance(value, dict): + assert isinstance( + episode_buffer[key], dict + ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}" + + self._add_to_episode_buffer(episode_buffer[key], value) else: - if isinstance(value, dict): - assert isinstance( - episode_buffer[key], dict - ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}" - - episode_buffer[key] = self._add_to_episode_buffer( - episode_buffer[key], value - ) - else: - assert isinstance( - episode_buffer[key], list - ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}" - episode_buffer[key].append(value) - - return episode_buffer + assert isinstance( + episode_buffer[key], list + ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}" + episode_buffer[key].append(value) def step( self, action: ActType @@ -216,7 +217,7 @@ def step( ), "Actions are not in action space." self._step_id += 1 - self._buffer[-1] = self._add_to_episode_buffer(self._buffer[-1], step_data) + self._add_step_data(self._buffer[-1], step_data) if ( self.max_buffer_steps is not None @@ -232,7 +233,7 @@ def step( "observations": step_data["observations"], "infos": step_data["infos"], } - eps_buff = self._add_to_episode_buffer(eps_buff, previous_data) + self._add_step_data(eps_buff, previous_data) self._buffer.append(eps_buff) return obs, rew, terminated, truncated, info @@ -278,7 +279,7 @@ def reset( "seed": str(None) if seed is None else seed, "id": self._episode_id } - episode_buffer = self._add_to_episode_buffer(episode_buffer, step_data) + self._add_step_data(episode_buffer, step_data) self._buffer.append(episode_buffer) return obs, info From 14c1630484902a9744970c269d39c4b51eebda3b Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 20 Jan 2024 15:34:05 +0100 Subject: [PATCH 18/29] fix docs --- docs/content/basic_usage.md | 2 +- docs/content/dataset_standards.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index 6af27742..54dfabd3 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -125,7 +125,7 @@ env = gym.make('CartPole-v1') env = DataCollector(env, record_infos=True, max_buffer_steps=100000) total_episodes = 100 -dataset_name = "CartPole-v1-test-v0" +dataset_name = "cartpole-test-v0" dataset = None if dataset_name in minari.list_local_datasets(): dataset = minari.load_dataset(dataset_name) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index c340efa2..5b4dfa3d 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -562,13 +562,13 @@ As mentioned in the `Supported Spaces` section, many different observation and a -When creating a dataset with `DataCollectorV0`, if the `DataCollectorV0` is initialized with `record_infos=True`, the additional information stored as step infos must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided from every call to the wrapped evironment's `step` and `reset` function, and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. +When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, the additional information stored as step infos must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided from every call to the wrapped evironment's `step` and `reset` function, and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. Here is an example of what a valid `info` might look like: - ```python -info = {'value_1':np.array([1]), 'value_2': {"sub_value_1":np.asarray([[2.3],[4.5]])}} +import numpy as np +info = {'value_1': np.array([1]), 'value_2': {"sub_value_1": np.array([[2.3], [4.5]])}} ``` Note that this shows how `infos` can be structured hierarchically, and that the nesting of dicts can go to arbitrary depth. From 4ed86121c21361fbe1b7c4db198ef81cb890454a Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 20 Jan 2024 15:49:54 +0100 Subject: [PATCH 19/29] fix pre-commit --- minari/data_collector/data_collector.py | 6 +++++- tests/data_collector/callbacks/test_step_data_callback.py | 2 +- tests/utils/test_dataset_creation.py | 6 ++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 8c4d252d..a78a174f 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -229,9 +229,13 @@ def step( if step_data["terminations"] or step_data["truncations"]: self._episode_id += 1 eps_buff = {"id": self._episode_id} - previous_data = { + previous_data: StepData = { "observations": step_data["observations"], "infos": step_data["infos"], + "rewards": None, + "actions": None, + "terminations": None, + "truncations": None } self._add_step_data(eps_buff, previous_data) self._buffer.append(eps_buff) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 0b03fba3..cf23515c 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -166,7 +166,7 @@ def test_data_collector_step_data_callback_info_correction(): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset) diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 70ca35c5..ead59122 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -8,6 +8,7 @@ import minari from minari import DataCollector, MinariDataset from tests.common import ( + DummyMutableInfoBoxEnv, check_data_integrity, check_env_recovery, check_env_recovery_with_subset_spaces, @@ -83,7 +84,7 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) check_episode_data_integrity( dataset, dataset.spec.observation_space, dataset.spec.action_space ) @@ -181,6 +182,7 @@ def test_record_infos_collector_env(dataset_id, env_id, info_override): env = gym.make(env_id) if env_id == "DummyMutableInfoBoxEnv-v0": + assert isinstance(env.unwrapped, DummyMutableInfoBoxEnv) env.unwrapped.info = info_override env = DataCollector(env, record_infos=True) @@ -213,7 +215,7 @@ def test_record_infos_collector_env(dataset_id, env_id, info_override): assert dataset.spec.total_episodes == num_episodes assert len(dataset.episode_indices) == num_episodes - check_data_integrity(dataset._data, dataset.episode_indices) + check_data_integrity(dataset.storage, dataset.episode_indices) check_episode_data_integrity( dataset, dataset.spec.observation_space, From ebbce779186a0f2cc7c3d2218c96e65482745550 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 21 Jan 2024 00:23:09 +0100 Subject: [PATCH 20/29] refactor --- minari/data_collector/data_collector.py | 46 +++++++++++-------------- tests/common.py | 22 ++---------- 2 files changed, 23 insertions(+), 45 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index a78a174f..623faeb0 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -154,12 +154,14 @@ def _add_step_data( dict_data = dict(step_data) if not self._record_infos: dict_data = {k: v for k, v in step_data.items() if k != "infos"} - elif not self.check_infos_same_shape( - self._reference_info, dict_data["infos"] - ): - raise ValueError( - "Info structure inconsistent with info structure returned by original reset." - ) + else: + assert self._reference_info is not None + if not check_infos_same_shape( + self._reference_info, step_data["infos"] + ): + raise ValueError( + "Info structure inconsistent with info structure returned by original reset." + ) self._add_to_episode_buffer(episode_buffer, dict_data) @@ -430,25 +432,6 @@ def save_to_disk( env_spec=self.env.spec, ) - # This function is designed the same way as `assert_infos_same_shape` in tests/common.py, but is a class function so has a `self` argument. - def check_infos_same_shape(self, info_1, info_2): - if len(info_1.keys()) != len(info_2.keys()): - return False - for key in info_1.keys(): - if isinstance(info_1[key], dict): - if not self.check_infos_same_shape(info_1[key], info_2[key]): - return False - elif isinstance(info_1[key], np.ndarray): - if not (info_1[key].shape == info_2[key].shape) and ( - info_1[key].dtype == info_2[key].dtype - ): - return False - else: - raise ValueError( - "Infos are in an unsupported format; see [Minari documentation](http://minari.farama.org/content/dataset_standards/) for supported formats." - ) - return True - def close(self): """Close the DataCollector. @@ -458,3 +441,16 @@ def close(self): self._buffer.clear() shutil.rmtree(self._tmp_dir.name) + + +def check_infos_same_shape(info_1: dict, info_2: dict): + if info_1.keys() != info_2.keys(): + return False + for key in info_1.keys(): + if type(info_1[key]) is not type(info_2[key]): + return False + if isinstance(info_1[key], dict): + return check_infos_same_shape(info_1[key], info_2[key]) + elif isinstance(info_1[key], np.ndarray): + return (info_1[key].shape == info_2[key].shape) and (info_1[key].dtype == info_2[key].dtype) + return True diff --git a/tests/common.py b/tests/common.py index ac87946e..27519592 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,6 +12,7 @@ import minari from minari import DataCollector, MinariDataset +from minari.data_collector.data_collector import check_infos_same_shape from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import MinariStorage @@ -582,25 +583,6 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert total_steps == data.total_steps -def assert_infos_same_structure(info_1, info_2): - if len(info_1.keys()) != len(info_2.keys()): - return False - for key in info_1.keys(): - if isinstance(info_1[key], dict): - if not assert_infos_same_structure(info_1[key], info_2[key]): - return False - elif isinstance(info_1[key], np.ndarray): - if not (info_1[key].shape == info_2[key].shape) and ( - info_1[key].dtype == info_2[key].dtype - ): - return False - else: - raise ValueError( - "Infos are in an unsupported format; see [Minari documentation](http://minari.farama.org/content/dataset_standards/) for supported formats." - ) - return True - - def get_info_at_step_index(infos, step_index): result = {} for key in infos.keys(): @@ -732,7 +714,7 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) if info_sample is not None: - assert assert_infos_same_structure( + assert check_infos_same_shape( get_info_at_step_index(episode.infos, i), info_sample ) assert observation_space.contains(obs) From d39aacae81016c7f7a0690a80c91c55b882b5efa Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 21 Jan 2024 02:11:26 +0100 Subject: [PATCH 21/29] simplify tests --- docs/content/dataset_standards.md | 15 +--- minari/data_collector/data_collector.py | 6 +- minari/dataset/minari_storage.py | 5 +- tests/common.py | 32 +++++--- tests/utils/test_dataset_creation.py | 101 ++++-------------------- 5 files changed, 46 insertions(+), 113 deletions(-) diff --git a/docs/content/dataset_standards.md b/docs/content/dataset_standards.md index 5b4dfa3d..a2af8ee5 100644 --- a/docs/content/dataset_standards.md +++ b/docs/content/dataset_standards.md @@ -560,19 +560,6 @@ As mentioned in the `Supported Spaces` section, many different observation and a ## Additional Information Formatting - - -When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, the additional information stored as step infos must be provided to Minari as a dict, which can only contain strings as keys and other dictionaries or `np.ndarray` as values. An info dict must be provided from every call to the wrapped evironment's `step` and `reset` function, and the shape of each `np.ndarray` must stay the same across timesteps, and the keys must remain the same in all `dicts` across timesteps. - -Here is an example of what a valid `info` might look like: - -```python -import numpy as np -info = {'value_1': np.array([1]), 'value_2': {"sub_value_1": np.array([[2.3], [4.5]])}} -``` - -Note that this shows how `infos` can be structured hierarchically, and that the nesting of dicts can go to arbitrary depth. +When creating a dataset with `DataCollector`, if the `DataCollector` is initialized with `record_infos=True`, an info dict must be provided from every call to the environment's `step` and `reset` function. The structure of the info dictionary must be the same across timesteps. Given that it is not guaranteed that all Gymnasium environments provide infos at every timestep, we provide the `StepDataCallback` which can modify the infos from a non-compliant environment so they have the same structure at every timestep. An example of this pattern is available in our test `test_data_collector_step_data_callback_info_correction` in test_step_data_callback.py. - -We provide tests to guarantee support for the following `numpy.ndarray` data types: `np.int8`,`np.int16`,`np.int32`,`np.int64`, `np.uint8`,`np.uint16`,`np.uint32`,`np.iunt64`,`np.float16`,`np.float32`,`np.float64`. diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 623faeb0..bf3e820f 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -156,7 +156,7 @@ def _add_step_data( dict_data = {k: v for k, v in step_data.items() if k != "infos"} else: assert self._reference_info is not None - if not check_infos_same_shape( + if not _check_infos_same_shape( self._reference_info, step_data["infos"] ): raise ValueError( @@ -443,14 +443,14 @@ def close(self): shutil.rmtree(self._tmp_dir.name) -def check_infos_same_shape(info_1: dict, info_2: dict): +def _check_infos_same_shape(info_1: dict, info_2: dict): if info_1.keys() != info_2.keys(): return False for key in info_1.keys(): if type(info_1[key]) is not type(info_2[key]): return False if isinstance(info_1[key], dict): - return check_infos_same_shape(info_1[key], info_2[key]) + return _check_infos_same_shape(info_1[key], info_2[key]) elif isinstance(info_1[key], np.ndarray): return (info_1[key].shape == info_2[key].shape) and (info_1[key].dtype == info_2[key].dtype) return True diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index ec32bb61..b076d0e4 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -88,7 +88,10 @@ def new( obj._action_space = action_space if env_spec is not None: - metadata["env_spec"] = env_spec.to_json() + try: + metadata["env_spec"] = env_spec.to_json() + except TypeError: + pass with h5py.File(obj._file_path, "a") as file: file.attrs.update(metadata) return obj diff --git a/tests/common.py b/tests/common.py index 27519592..88b82699 100644 --- a/tests/common.py +++ b/tests/common.py @@ -12,7 +12,6 @@ import minari from minari import DataCollector, MinariDataset -from minari.data_collector.data_collector import check_infos_same_shape from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import MinariStorage @@ -51,13 +50,13 @@ def reset(self, seed=None, options=None): # this returns whatever is set to `self.info` as the info, making it easy to create parameterized tests with a lot of different info datatypes. -class DummyMutableInfoBoxEnv(gym.Env): - def __init__(self): +class DummyInfoBoxEnv(gym.Env): + def __init__(self, info = None): self.action_space = spaces.Box(low=-1, high=4, shape=(2,), dtype=np.float32) self.observation_space = spaces.Box( low=-1, high=4, shape=(3,), dtype=np.float32 ) - self.info = {} + self.info = info if info is not None else {} def _get_info(self): return self.info @@ -66,7 +65,7 @@ def step(self, action): terminated = self.timestep > 5 self.timestep += 1 - return (self.observation_space.sample(), 0, terminated, False, self._get_info()) + return self.observation_space.sample(), 0, terminated, False, self._get_info() def reset(self, seed=None, options=None): self.timestep = 0 @@ -303,8 +302,8 @@ def register_dummy_envs(): ) register( - id="DummyMutableInfoBoxEnv-v0", - entry_point="tests.common:DummyMutableInfoBoxEnv", + id="DummyInfoBoxEnv-v0", + entry_point="tests.common:DummyInfoBoxEnv", max_episode_steps=5, ) @@ -714,9 +713,11 @@ def check_episode_data_integrity( for i in range(episode.total_timesteps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) if info_sample is not None: - assert check_infos_same_shape( - get_info_at_step_index(episode.infos, i), info_sample + assert check_infos_equal( + get_info_at_step_index(episode.infos, i), + info_sample ) + assert observation_space.contains(obs) for i in range(episode.total_timesteps): @@ -728,6 +729,19 @@ def check_episode_data_integrity( assert episode.total_timesteps == len(episode.truncations) +def check_infos_equal(info_1: dict, info_2: dict): + if info_1.keys() != info_2.keys(): + return False + for key in info_1.keys(): + if isinstance(info_1[key], dict): + return check_infos_equal(info_1[key], info_2[key]) + elif isinstance(info_1[key], np.ndarray): + return np.all(info_1[key] == info_2[key]) + else: + return info_1[key] == info_2[key] + return True + + def _space_subset_helper(entry: Dict): return OrderedDict( diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index ead59122..b26f29ef 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -8,7 +8,6 @@ import minari from minari import DataCollector, MinariDataset from tests.common import ( - DummyMutableInfoBoxEnv, check_data_integrity, check_env_recovery, check_env_recovery_with_subset_spaces, @@ -99,91 +98,24 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): @pytest.mark.parametrize( - "dataset_id,env_id,info_override", + "info_override", [ - ("dummy-dict-test-v0", "DummyDictEnv-v0", None), - ("dummy-box-test-v0", "DummyBoxEnv-v0", None), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.int64)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.int32)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.int16)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.int8)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.uint64)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.uint32)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.uint16)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.uint8)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.float64)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.float32)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.ones((5, 5), np.float16)}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.array([1])}, - ), - ( - "dummy-mutable-info-box-test-v0", - "DummyMutableInfoBoxEnv-v0", - {"misc": np.array([1])}, - ), - ("dummy-tuple-test-v0", "DummyTupleEnv-v0", None), + None, {}, {"foo": np.ones((10, 10), dtype=np.float32)}, + {"int": 1}, {"bool": False}, + { + "value1": True, + "value2": 5, + "value3": { + "nested1": False, + "nested2": np.empty(10) + } + }, ], ) -def test_record_infos_collector_env(dataset_id, env_id, info_override): +def test_record_infos_collector_env(info_override): """Test DataCollector wrapper and Minari dataset creation including infos.""" - # dataset_id = "cartpole-test-v0" - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - - env = gym.make(env_id) - - if env_id == "DummyMutableInfoBoxEnv-v0": - assert isinstance(env.unwrapped, DummyMutableInfoBoxEnv) - env.unwrapped.info = info_override + dataset_id = "dummy-mutable-info-box-test-v0" + env = gym.make("DummyInfoBoxEnv-v0", info=info_override) env = DataCollector(env, record_infos=True) num_episodes = 10 @@ -205,7 +137,7 @@ def test_record_infos_collector_env(dataset_id, env_id, info_override): dataset_id=dataset_id, collector_env=env, algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + code_permalink=CODELINK, author="WillDudley", author_email="wdudley@farama.org", ) @@ -223,9 +155,6 @@ def test_record_infos_collector_env(dataset_id, env_id, info_override): info_sample=info_sample, ) - # check that the environment can be recovered from the dataset - check_env_recovery(env.env, dataset) - env.close() check_load_and_delete_dataset(dataset_id) From 0983ec2cc32337ad6e4a79b87bbad853d5915ba0 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 21 Jan 2024 02:16:54 +0100 Subject: [PATCH 22/29] fix pre-commit --- tests/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/common.py b/tests/common.py index 88b82699..d2ac2436 100644 --- a/tests/common.py +++ b/tests/common.py @@ -51,7 +51,7 @@ def reset(self, seed=None, options=None): # this returns whatever is set to `self.info` as the info, making it easy to create parameterized tests with a lot of different info datatypes. class DummyInfoBoxEnv(gym.Env): - def __init__(self, info = None): + def __init__(self, info=None): self.action_space = spaces.Box(low=-1, high=4, shape=(2,), dtype=np.float32) self.observation_space = spaces.Box( low=-1, high=4, shape=(3,), dtype=np.float32 From dd252e992fe43f32d0cdf69238dbfcde3f6aaf91 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 21 Jan 2024 13:41:25 +0100 Subject: [PATCH 23/29] remove redundant comments --- tests/common.py | 1 - .../callbacks/test_step_data_callback.py | 39 +++++-------------- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/tests/common.py b/tests/common.py index d2ac2436..a287ecf2 100644 --- a/tests/common.py +++ b/tests/common.py @@ -49,7 +49,6 @@ def reset(self, seed=None, options=None): return self.observation_space.sample(), self._get_info() -# this returns whatever is set to `self.info` as the info, making it easy to create parameterized tests with a lot of different info datatypes. class DummyInfoBoxEnv(gym.Env): def __init__(self, info=None): self.action_space = spaces.Box(low=-1, high=4, shape=(2,), dtype=np.float32) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index cf23515c..31831b8d 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -50,10 +50,6 @@ def __call__(self, env, **kwargs): def test_data_collector_step_data_callback(): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-dict-test-v0" - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) env = gym.make("DummyDictEnv-v0") @@ -84,23 +80,20 @@ def test_data_collector_step_data_callback(): ) num_episodes = 10 - # Step the environment, DataCollector wrapper will do the data collection job env.reset(seed=42) - for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) env.reset() - # Create Minari dataset and store locally dataset = env.create_dataset( dataset_id=dataset_id, algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + code_permalink=str(__file__), author="WillDudley", author_email="wdudley@farama.org", ) @@ -112,24 +105,17 @@ def test_data_collector_step_data_callback(): check_data_integrity(dataset.storage, dataset.episode_indices) - # check that the environment can be recovered from the dataset check_env_recovery_with_subset_spaces( env.env, dataset, action_space_subset, observation_space_subset ) env.close() - # check load and delete local dataset check_load_and_delete_dataset(dataset_id) def test_data_collector_step_data_callback_info_correction(): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-tuple-discrete-box-v0" - # delete the test dataset if it already exists - local_datasets = minari.list_local_datasets() - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - env = gym.make("DummyTupleDiscreteBoxEnv-v0") env = DataCollector( @@ -139,24 +125,21 @@ def test_data_collector_step_data_callback_info_correction(): ) num_episodes = 10 - # Step the environment, DataCollector wrapper will do the data collection job env.reset(seed=42) - for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) env.reset() - # Create Minari dataset and store locally dataset = minari.create_dataset_from_collector_env( dataset_id=dataset_id, collector_env=env, algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", + code_permalink=str(__file__), author="WillDudley", author_email="wdudley@farama.org", ) @@ -168,11 +151,9 @@ def test_data_collector_step_data_callback_info_correction(): check_data_integrity(dataset.storage, dataset.episode_indices) - # check that the environment can be recovered from the dataset check_env_recovery(env.env, dataset) env.close() - # check load and delete local dataset check_load_and_delete_dataset(dataset_id) env = gym.make("DummyTupleDiscreteBoxEnv-v0") @@ -183,18 +164,18 @@ def test_data_collector_step_data_callback_info_correction(): ) # here we are checking to make sure that if we have an environment changing its info # structure across timesteps, it is caught by the data_collector - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=r"Info structure inconsistent with info structure returned by original reset." + ): num_episodes = 10 - - # Step the environment, DataCollector wrapper will do the data collection job env.reset(seed=42) - - for episode in range(num_episodes): + for _ in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) env.reset() From 7661f20095b019704d0cbb15dcac79fe4262dc0e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 21 Jan 2024 13:48:18 +0100 Subject: [PATCH 24/29] fix pre-commit --- tests/data_collector/callbacks/test_step_data_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 31831b8d..ccfe5adc 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -85,7 +85,7 @@ def test_data_collector_step_data_callback(): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() + action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) env.reset() From d42879eb9fed1876cf66ccedd6f32ceaeec766aa Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 22 Jan 2024 18:08:43 +0100 Subject: [PATCH 25/29] fixes --- minari/data_collector/data_collector.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index bf3e820f..28310e91 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -7,7 +7,7 @@ import shutil import tempfile import warnings -from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type +from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type, Union import gymnasium as gym import numpy as np @@ -140,7 +140,7 @@ def __init__( def _add_step_data( self, episode_buffer: EpisodeBuffer, - step_data: StepData, + step_data: Union[StepData, Dict], ): """Add step data dictionary to episode buffer. @@ -231,13 +231,9 @@ def step( if step_data["terminations"] or step_data["truncations"]: self._episode_id += 1 eps_buff = {"id": self._episode_id} - previous_data: StepData = { + previous_data = { "observations": step_data["observations"], "infos": step_data["infos"], - "rewards": None, - "actions": None, - "terminations": None, - "truncations": None } self._add_step_data(eps_buff, previous_data) self._buffer.append(eps_buff) @@ -265,7 +261,7 @@ def reset( observation (ObsType): Observation of the initial state. info (dictionary): Auxiliary information complementing ``observation``. """ - autoseed_enabled = (not options) or options.get("minari_autoseed", False) + autoseed_enabled = (not options) or options.get("minari_autoseed", True) if seed is None and autoseed_enabled: seed = secrets.randbits(AUTOSEED_BIT_SIZE) From 348513e2a3949580d5d1116f9370e89d24c85a57 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 25 Jan 2024 15:20:40 +0100 Subject: [PATCH 26/29] fix basic_usage --- docs/content/basic_usage.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index 54dfabd3..a6f9f558 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -77,7 +77,7 @@ for _ in range(total_episodes): if terminated or truncated: break -dataset = env.create_dataset(dataset_id="CartPole-v1-test-v0", +dataset = env.create_dataset(dataset_id="cartpole-test-v0", algorithm_name="Random-Policy", code_permalink="https://github.com/Farama-Foundation/Minari", author="Farama", @@ -96,7 +96,7 @@ Once the dataset has been created we can check if the Minari dataset id appears >>> import minari >>> local_datasets = minari.list_local_datasets() >>> local_datasets.keys() -dict_keys(['CartPole-v1-test-v0']) +dict_keys(['cartpole-test-v0']) ``` ```{eval-rst} @@ -161,9 +161,9 @@ Minari will only be able to load datasets that are stored in your `local root di ```python >>> import minari ->>> dataset = minari.load_dataset('CartPole-v1-test-v0') +>>> dataset = minari.load_dataset('cartpole-test-v0') >>> dataset.name -'CartPole-v1-test-v0' +'cartpole-test-v0' ``` ### Download Remote Datasets @@ -323,7 +323,7 @@ From a :class:`minari.MinariDataset` object we can also recover the Gymnasium en ```python import minari -dataset = minari.load_dataset('CartPole-v1-test-v0') +dataset = minari.load_dataset('cartpole-test-v0') env = dataset.recover_environment() env.reset() From 5dd2de5cef2d218debcfb229346724322f7c0e9e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 25 Jan 2024 15:23:16 +0100 Subject: [PATCH 27/29] fix episode_data repr --- minari/dataset/episode_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py index 797eed87..8d32d754 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -32,7 +32,7 @@ def __repr__(self) -> str: f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " f"truncations=ndarray of {len(self.truncations)} bools, " - f"infos=dict with keys of :{list(self.infos.keys())}" + f"infos=dict with the following keys: {list(self.infos.keys())}" ")" ) From 0da6a84d8774075ea7798b5024dc75a3afa39f35 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 25 Jan 2024 15:32:51 +0100 Subject: [PATCH 28/29] fix common --- tests/common.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/common.py b/tests/common.py index a287ecf2..3e7e26ba 100644 --- a/tests/common.py +++ b/tests/common.py @@ -581,7 +581,7 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert total_steps == data.total_steps -def get_info_at_step_index(infos, step_index): +def get_info_at_step_index(infos: Dict, step_index: int) -> Dict: result = {} for key in infos.keys(): if isinstance(infos[key], dict): @@ -696,9 +696,10 @@ def check_episode_data_integrity( Args: episode_data_list (List[EpisodeData]): A list of EpisodeData instances representing episodes. - observation_space(gym.spaces.Space): The environment's observation space. - action_space(gym.spaces.Space): The environment's action space. - info_sample(dict): An info returned by the environment used to build the dataset. + observation_space (gym.spaces.Space): The environment's observation space. + action_space (gym.spaces.Space): The environment's action space. + info_sample (dict): An info returned by the environment used to build the dataset. + """ # verify the actions and observations are in the appropriate action space and observation space, and that the episode lengths are correct for episode in episode_data_list: @@ -728,14 +729,14 @@ def check_episode_data_integrity( assert episode.total_timesteps == len(episode.truncations) -def check_infos_equal(info_1: dict, info_2: dict): +def check_infos_equal(info_1: Dict, info_2: Dict) -> bool: if info_1.keys() != info_2.keys(): return False for key in info_1.keys(): if isinstance(info_1[key], dict): return check_infos_equal(info_1[key], info_2[key]) elif isinstance(info_1[key], np.ndarray): - return np.all(info_1[key] == info_2[key]) + return bool(np.all(info_1[key] == info_2[key])) else: return info_1[key] == info_2[key] return True @@ -752,7 +753,7 @@ def _space_subset_helper(entry: Dict): ) -def get_sample_buffer_for_dataset_from_env(env, num_episodes=10): +def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): buffer = [] observations = [] @@ -771,7 +772,7 @@ def get_sample_buffer_for_dataset_from_env(env, num_episodes=10): truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) observations.append(_space_subset_helper(observation)) actions.append(_space_subset_helper(action)) From 9578d84f79775cf5ba278d0430da95c974d5dc46 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 25 Jan 2024 17:02:50 +0100 Subject: [PATCH 29/29] improe tests --- tests/common.py | 31 ++++++++++--------- .../callbacks/test_step_data_callback.py | 6 ++-- tests/data_collector/test_data_collector.py | 26 +++++++++------- tests/utils/test_dataset_creation.py | 8 ++--- 4 files changed, 36 insertions(+), 35 deletions(-) diff --git a/tests/common.py b/tests/common.py index 3e7e26ba..8c4a99c7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -49,26 +49,21 @@ def reset(self, seed=None, options=None): return self.observation_space.sample(), self._get_info() -class DummyInfoBoxEnv(gym.Env): +class DummyInfoEnv(DummyBoxEnv): def __init__(self, info=None): - self.action_space = spaces.Box(low=-1, high=4, shape=(2,), dtype=np.float32) - self.observation_space = spaces.Box( - low=-1, high=4, shape=(3,), dtype=np.float32 - ) + super().__init__() self.info = info if info is not None else {} def _get_info(self): return self.info - def step(self, action): - terminated = self.timestep > 5 - self.timestep += 1 - return self.observation_space.sample(), 0, terminated, False, self._get_info() +class DummyInconsistentInfoEnv(DummyBoxEnv): + def __init__(self): + super().__init__() - def reset(self, seed=None, options=None): - self.timestep = 0 - return self.observation_space.sample(), self._get_info() + def _get_info(self): + return super()._get_info() if self.timestep % 2 == 0 else {} class DummyMultiDimensionalBoxEnv(gym.Env): @@ -108,7 +103,7 @@ def __init__(self): ) def _get_info(self): - return {"timestep": np.array([self.timestep])} if self.timestep % 2 == 0 else {} + return {"timestep": np.array([self.timestep])} def step(self, action): terminated = self.timestep > 5 @@ -301,8 +296,14 @@ def register_dummy_envs(): ) register( - id="DummyInfoBoxEnv-v0", - entry_point="tests.common:DummyInfoBoxEnv", + id="DummyInfoEnv-v0", + entry_point="tests.common:DummyInfoEnv", + max_episode_steps=5, + ) + + register( + id="DummyInconsistentInfoEnv-v0", + entry_point="tests.common:DummyInconsistentInfoEnv", max_episode_steps=5, ) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index ccfe5adc..d26cda5a 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -115,8 +115,8 @@ def test_data_collector_step_data_callback(): def test_data_collector_step_data_callback_info_correction(): """Test DataCollector wrapper and Minari dataset creation.""" - dataset_id = "dummy-tuple-discrete-box-v0" - env = gym.make("DummyTupleDiscreteBoxEnv-v0") + dataset_id = "dummy-inconsistent-info-v0" + env = gym.make("DummyInconsistentInfoEnv-v0") env = DataCollector( env, @@ -156,7 +156,7 @@ def test_data_collector_step_data_callback_info_correction(): env.close() check_load_and_delete_dataset(dataset_id) - env = gym.make("DummyTupleDiscreteBoxEnv-v0") + env = gym.make("DummyInconsistentInfoEnv-v0") env = DataCollector( env, diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index 03c67f5b..ee932cdc 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -4,6 +4,7 @@ from minari import DataCollector, EpisodeData, MinariDataset, StepDataCallback from tests.common import ( + check_infos_equal, check_load_and_delete_dataset, get_info_at_step_index, register_dummy_envs, @@ -110,10 +111,10 @@ def test_truncation_without_reset(dataset_id, env_id): env = DataCollector( env, step_data_callback=ForceTruncateStepDataCallback, + record_infos=True, ) env.reset() - for _ in range(num_steps): env.step(env.action_space.sample()) @@ -132,19 +133,20 @@ def test_truncation_without_reset(dataset_id, env_id): assert len(dataset.episode_indices) == num_episodes episodes_generator = dataset.iterate_episodes() - last_step = None + last_step = get_single_step_from_episode(next(episodes_generator), -1) for episode in episodes_generator: assert episode.total_timesteps == ForceTruncateStepDataCallback.episode_steps - if last_step is not None: - first_step = get_single_step_from_episode(episode, 0) - # Check that the last observation of the previous episode is carried over to the next episode - # as the reset observation. - if isinstance(first_step.observations, dict) or isinstance( - first_step.observations, tuple - ): - assert first_step.observations == last_step.observations - else: - assert np.array_equal(first_step.observations, last_step.observations) + first_step = get_single_step_from_episode(episode, 0) + # Check that the last observation of the previous episode is carried over to the next episode + # as the reset observation. + if isinstance(first_step.observations, dict) or isinstance( + first_step.observations, tuple + ): + assert first_step.observations == last_step.observations + else: + assert np.array_equal(first_step.observations, last_step.observations) + + check_infos_equal(last_step.infos, first_step.infos) last_step = get_single_step_from_episode(episode, -1) assert bool(last_step.truncations) is True diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index b26f29ef..6dfc7afc 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -38,7 +38,7 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): """Test DataCollector wrapper and Minari dataset creation.""" env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, record_infos=True) num_episodes = 10 # Step the environment, DataCollector wrapper will do the data collection job @@ -115,24 +115,22 @@ def test_generate_dataset_with_collector_env(dataset_id, env_id): def test_record_infos_collector_env(info_override): """Test DataCollector wrapper and Minari dataset creation including infos.""" dataset_id = "dummy-mutable-info-box-test-v0" - env = gym.make("DummyInfoBoxEnv-v0", info=info_override) + env = gym.make("DummyInfoEnv-v0", info=info_override) env = DataCollector(env, record_infos=True) num_episodes = 10 - # Step the environment, DataCollector wrapper will do the data collection job _, info_sample = env.reset(seed=42) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() _, _, terminated, truncated, _ = env.step(action) env.reset() - # Create Minari dataset and store locally dataset = minari.create_dataset_from_collector_env( dataset_id=dataset_id, collector_env=env,