diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 9dd06f35..f50efc24 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -111,8 +111,9 @@ def test_data_collector_step_data_callback(data_format, register_dummy_envs): @pytest.mark.parametrize("data_format", get_storage_keys()) +@pytest.mark.parametrize("infos_format", ["dict", "list"]) def test_data_collector_step_data_callback_info_correction( - data_format, register_dummy_envs + data_format, register_dummy_envs, infos_format ): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-inconsistent-info-v0" @@ -123,6 +124,7 @@ def test_data_collector_step_data_callback_info_correction( record_infos=True, step_data_callback=CustomSubsetInfoPadStepDataCallback, data_format=data_format, + infos_format=infos_format, ) num_episodes = 10 @@ -164,7 +166,8 @@ def test_data_collector_step_data_callback_info_correction( record_infos=True, ) # here we are checking to make sure that if we have an environment changing its info - # structure across steps, IT IS OK! + # structure across steps, it results in an error when infos_format is the default (dict) + # behaviour. with pytest.raises(ValueError): num_episodes = 10 env.reset(seed=42) @@ -176,5 +179,8 @@ def test_data_collector_step_data_callback_info_correction( _, _, terminated, truncated, _ = env.step(action) env.reset() - raise ValueError("This should BE reached") + if infos_format == "list": + raise ValueError( + "This should be raised to indicate the test succeeded for infos_format='list'" + ) env.close()