From d39e65f441543af2c1cf3a76f1a6ac7189db3e0a Mon Sep 17 00:00:00 2001 From: R107333 Date: Tue, 15 Oct 2024 21:05:26 +0200 Subject: [PATCH] Restore test for incompatible info step data for default behavior and add test for infos_format="list" option. --- .../callbacks/test_step_data_callback.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 9dd06f35..f50efc24 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -111,8 +111,9 @@ def test_data_collector_step_data_callback(data_format, register_dummy_envs): @pytest.mark.parametrize("data_format", get_storage_keys()) +@pytest.mark.parametrize("infos_format", ["dict", "list"]) def test_data_collector_step_data_callback_info_correction( - data_format, register_dummy_envs + data_format, register_dummy_envs, infos_format ): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-inconsistent-info-v0" @@ -123,6 +124,7 @@ def test_data_collector_step_data_callback_info_correction( record_infos=True, step_data_callback=CustomSubsetInfoPadStepDataCallback, data_format=data_format, + infos_format=infos_format, ) num_episodes = 10 @@ -164,7 +166,8 @@ def test_data_collector_step_data_callback_info_correction( record_infos=True, ) # here we are checking to make sure that if we have an environment changing its info - # structure across steps, IT IS OK! + # structure across steps, it results in an error when infos_format is the default (dict) + # behaviour. with pytest.raises(ValueError): num_episodes = 10 env.reset(seed=42) @@ -176,5 +179,8 @@ def test_data_collector_step_data_callback_info_correction( _, _, terminated, truncated, _ = env.step(action) env.reset() - raise ValueError("This should BE reached") + if infos_format == "list": + raise ValueError( + "This should be raised to indicate the test succeeded for infos_format='list'" + ) env.close()