change create dataset from buffer to ensure same behavior with create data from env #146

Closed
Changes from 4 commits
16 changes: 13 additions & 3 deletions minari/utils.py
@@ -380,6 +380,7 @@ def create_dataset_from_buffers(
         * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1).
         * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1).

+    If the last trajectory is neither terminated nor truncated, the last step will be marked as truncated.
     Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries.

     Args:
@@ -455,9 +456,18 @@ def create_dataset_from_buffers(
     with h5py.File(data_path, "w", track_order=True) as file:
         for i, eps_buff in enumerate(buffer):
             # check episode terminated or truncated
-            assert (
-                eps_buff["terminations"][-1] or eps_buff["truncations"][-1]
-            ), "Each episode must be terminated or truncated before adding it to a Minari dataset"
+            is_term_or_trunc = eps_buff["terminations"][-1] or eps_buff["truncations"][-1]
+            is_last_episode = i == len(buffer) - 1
+            if is_last_episode and not is_term_or_trunc:
+                warnings.warn(
+                    f"The last episode {i} is not terminated or truncated. The last step will be marked as truncated.",
+                    UserWarning,
+                )
+                eps_buff["truncations"][-1] = True
+            else:
+                assert (
+                    is_term_or_trunc
+                ), "Each episode must be terminated or truncated before adding it to a Minari dataset"
Comment on lines +469 to +472 (Member):

it looks like this branch is not very useful, I would remove it

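One way to read that suggestion (a sketch of mine, not code from this PR): drop the `else` branch, fix up the last episode first, and then assert unconditionally, so every episode is checked without duplicating the condition.

```python
import warnings

# Sketch only: assumes `eps_buff`, `i`, and `buffer` as in the loop above.
is_term_or_trunc = eps_buff["terminations"][-1] or eps_buff["truncations"][-1]
if i == len(buffer) - 1 and not is_term_or_trunc:
    warnings.warn(
        f"The last episode {i} is not terminated or truncated. The last step will be marked as truncated.",
        UserWarning,
    )
    eps_buff["truncations"][-1] = True

# After the fix-up the condition holds for the last episode too,
# so a single unconditional assert covers every episode.
assert (
    eps_buff["terminations"][-1] or eps_buff["truncations"][-1]
), "Each episode must be terminated or truncated before adding it to a Minari dataset"
```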
             assert len(eps_buff["actions"]) + 1 == len(
                 eps_buff["observations"]
             ), f"Number of observations {len(eps_buff['observations'])} must have an additional element compared to the number of action steps {len(eps_buff['actions'])}. The initial and final observation must be included"
103 changes: 103 additions & 0 deletions tests/utils/test_dataset_creation.py
@@ -292,3 +292,106 @@ def test_generate_dataset_with_space_subset_external_buffer():
    env.close()

    check_load_and_delete_dataset(dataset_id)


def test_generate_dataset_with_buffer_incomplete_traj():
    """Test creating a dataset from external buffers when the last trajectory is incomplete."""
    dataset_id = "cartpole-test-v0"
    env_id = "CartPole-v1"

    # delete the test dataset if it already exists
    local_datasets = minari.list_local_datasets()
    if dataset_id in local_datasets:
        minari.delete_dataset(dataset_id)

    env = gym.make(env_id)
    env.reset(seed=42)

    obs_all, act_all, rew_all, term_all, trunc_all = [], [], [], [], []
    buffer = []
    num_episodes = 10
    # Step the environment and collect the transitions into per-episode buffers
    for episode in range(num_episodes):
        observations, actions, rewards, terminations, truncations = [], [], [], [], []

        observation, _ = env.reset()
        observations.append(observation)
        _term_i, _trunc_i = False, False

        while not _term_i and not _trunc_i:
            _act_i = env.action_space.sample()  # user-defined policy function
            _obs_i, _rwd_i, _term_i, _trunc_i, _ = env.step(_act_i)
            observations, actions, rewards, terminations, truncations = map(
                lambda x, y: x + [y],
                [observations, actions, rewards, terminations, truncations],
                [_obs_i, _act_i, _rwd_i, _term_i, _trunc_i],
            )

        # for the last episode, manually reset the final termination and truncation
        # to False to verify that the last step gets marked as truncated
        if episode == num_episodes - 1:
            terminations[-1] = False
            truncations[-1] = False

        obs_all, act_all, rew_all, term_all, trunc_all = map(
            lambda x, y: x + [np.array(y)],
            [obs_all, act_all, rew_all, term_all, trunc_all],
            [observations, actions, rewards, terminations, truncations],
        )

        buffer.append(
            {
                "observations": observations,
                "actions": actions,
                "rewards": rewards,
                "terminations": terminations,
                "truncations": truncations,
            }
        )

    # Create Minari dataset and store locally
    dataset = minari.create_dataset_from_buffers(
        dataset_id=dataset_id,
        env=env,
        buffer=buffer,
        algorithm_name="random_policy",
        code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py",
        author="WillDudley",
        author_email="[email protected]",
    )

    assert isinstance(dataset, MinariDataset)
    assert dataset.total_episodes == num_episodes
    assert dataset.spec.total_episodes == num_episodes
    assert len(dataset.episode_indices) == num_episodes

    check_data_integrity(dataset._data, dataset.episode_indices)
    check_env_recovery(env, dataset)
    env.close()

    dataset_loaded = minari.load_dataset(dataset_id)
    obs_loaded, act_loaded, rew_loaded, term_loaded, trunc_loaded = [], [], [], [], []
    for _eps in dataset_loaded:
        obs_loaded.append(_eps.observations)
        act_loaded.append(_eps.actions)
        rew_loaded.append(_eps.rewards)
        term_loaded.append(_eps.terminations)
        trunc_loaded.append(_eps.truncations)

    obs_loaded, act_loaded, rew_loaded, term_loaded, trunc_loaded = map(
        lambda x: np.concatenate(x),
        [obs_loaded, act_loaded, rew_loaded, term_loaded, trunc_loaded],
    )
    obs_original, act_original, rew_original, term_original, trunc_original = map(
        lambda x: np.concatenate(x), [obs_all, act_all, rew_all, term_all, trunc_all]
    )

    assert np.all(obs_loaded == obs_original)
    assert np.all(act_loaded == act_original)
    assert np.all(rew_loaded == rew_original)
    assert np.all(term_loaded == term_original)
    # every flag except the last must round-trip unchanged; the final truncation
    # flag was False in the buffer and must be rewritten to True in the dataset
    assert np.all(trunc_loaded[:-1] == trunc_original[:-1])
    assert trunc_loaded[-1].item() is True
    assert trunc_original[-1].item() is False

    check_load_and_delete_dataset(dataset_id)