Skip to content

Commit

Permalink
Restore test for incompatible info step data for default behavior and…
Browse files Browse the repository at this point in the history
… add test for infos_format="list" option.
  • Loading branch information
jamartinh committed Oct 15, 2024
1 parent da58ce3 commit d39e65f
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/data_collector/callbacks/test_step_data_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def test_data_collector_step_data_callback(data_format, register_dummy_envs):


@pytest.mark.parametrize("data_format", get_storage_keys())
@pytest.mark.parametrize("infos_format", ["dict", "list"])
def test_data_collector_step_data_callback_info_correction(
data_format, register_dummy_envs
data_format, register_dummy_envs, infos_format
):
"""Test DataCollector wrapper and Minari dataset creation."""
dataset_id = "dummy-inconsistent-info-v0"
Expand All @@ -123,6 +124,7 @@ def test_data_collector_step_data_callback_info_correction(
record_infos=True,
step_data_callback=CustomSubsetInfoPadStepDataCallback,
data_format=data_format,
infos_format=infos_format,
)
num_episodes = 10

Expand Down Expand Up @@ -164,7 +166,8 @@ def test_data_collector_step_data_callback_info_correction(
record_infos=True,
)
# here we are checking to make sure that if we have an environment changing its info
# structure across steps, IT IS OK!
# structure across steps, it results in an error when infos_format is the default (dict)
# behaviour.
with pytest.raises(ValueError):
num_episodes = 10
env.reset(seed=42)
Expand All @@ -176,5 +179,8 @@ def test_data_collector_step_data_callback_info_correction(
_, _, terminated, truncated, _ = env.step(action)

env.reset()
raise ValueError("This should BE reached")
if infos_format == "list":
raise ValueError(
"This should be raised to indicate the test succeeded for infos_format='list'"
)
env.close()

0 comments on commit d39e65f

Please sign in to comment.