
[Question] Parallel Sampling #136

Open
daniellawson9999 opened this issue Aug 31, 2023 · 4 comments

Comments

@daniellawson9999

daniellawson9999 commented Aug 31, 2023

Parallel episode sampling

I have a use case with a dataset of image-based observations, and I've noticed that sampling is slower than with 1D observations. Looking at how sampling works internally, I saw that Minari fetches episodes serially rather than in parallel. Parallelizing this call may already have been considered, so I'm curious about recommendations on the best way to do it, and whether it is something that will be added in the future.

There is one more layer of complexity on top of this: instead of 1 dataset, I have, say, 10 datasets from different envs, each with image-based observations (think multi-task Atari). I have 10 Minari datasets and want, say, 30 episodes from each per gradient update. I also want to do this in parallel, and will experiment with different parallelization techniques, but I'm curious whether others have intuition about this.

def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]:
    """Get a list of episodes.

    Args:
        episode_indices (Iterable[int]): ids of the episodes to return

    Returns:
        episodes (List[dict]): list of episodes data
    """
    out = []
    with h5py.File(self._data_path, "r") as file:
        for ep_idx in episode_indices:
            ep_group = file[f"episode_{ep_idx}"]
            out.append(
                {
                    "id": ep_group.attrs.get("id"),
                    "total_timesteps": ep_group.attrs.get("total_steps"),
                    "seed": ep_group.attrs.get("seed"),
                    "observations": self._decode_space(
                        ep_group["observations"], self.observation_space
                    ),
                    "actions": self._decode_space(
                        ep_group["actions"], self.action_space
                    ),
                    "rewards": ep_group["rewards"][()],
                    "terminations": ep_group["terminations"][()],
                    "truncations": ep_group["truncations"][()],
                }
            )

    return out
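
For context, here is a minimal sketch of one way that loop could be parallelized with a process pool, where every worker opens its own read-only HDF5 handle instead of sharing one. The data_path argument and the omission of the space-aware _decode_space step are simplifications for illustration, not Minari's actual API.

from concurrent.futures import ProcessPoolExecutor
from functools import partial
from typing import Iterable, List

import h5py


def _read_episode(data_path: str, ep_idx: int) -> dict:
    # Each worker opens its own read-only handle; HDF5 supports concurrent readers.
    with h5py.File(data_path, "r") as file:
        ep_group = file[f"episode_{ep_idx}"]
        return {
            "id": ep_group.attrs.get("id"),
            "total_timesteps": ep_group.attrs.get("total_steps"),
            "rewards": ep_group["rewards"][()],
            "terminations": ep_group["terminations"][()],
            "truncations": ep_group["truncations"][()],
            # Observations/actions of composite spaces would still need the
            # space-aware decoding that _decode_space performs above.
        }


def get_episodes_parallel(
    data_path: str, episode_indices: Iterable[int], max_workers: int = 4
) -> List[dict]:
    """Read episodes in separate processes, one HDF5 handle per worker."""
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(partial(_read_episode, data_path), episode_indices))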

@balisujohn
Collaborator

First, thanks for using Minari! And these questions are really helpful for us; it's difficult to refine a product without hearing from users.

OK, so for the first part: we are working on an optional StreamingDataset backend (https://docs.mosaicml.com/projects/streaming/en/stable/). We are open to design suggestions for parallel sampling for both the streaming and h5py backends (I'm not sure what we need to do to get true parallelism in Python when memory is shared between physical threads; maybe this is easy). A parallel sampling implementation could be a strict improvement over our current one on any machine with more than one physical CPU core.

For the second one: we don't have any built-in features for sampling from multiple datasets at once. The closest thing that comes to mind is generating the list of indices to sample externally to Minari, then sampling from each dataset using iterate_episodes with that list as an argument (you can also use square brackets directly on the MinariDataset object to get an episode by index), for example as sketched below. That gives you fine-grained enough control to sample without replacement, sample the same indices from different datasets, etc. We also have sub-episode trajectory sampling code in development.
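
A minimal sketch of that multi-dataset approach, assuming hypothetical dataset IDs and a batch of 30 episodes per dataset:

import numpy as np
import minari

# Hypothetical dataset IDs standing in for the 10 image-based datasets.
dataset_ids = ["env-A-expert-v0", "env-B-expert-v0"]
datasets = [minari.load_dataset(ds_id) for ds_id in dataset_ids]

rng = np.random.default_rng(seed=0)
batch = []
for ds in datasets:
    # Draw 30 episode indices without replacement, generated outside Minari.
    indices = rng.choice(ds.total_episodes, size=30, replace=False).tolist()
    # iterate_episodes takes the list of indices; square brackets (ds[i])
    # would fetch a single episode instead.
    batch.extend(ds.iterate_episodes(indices))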

We are open to feature requests, so feel free to propose any features you think would support your use-case.

@daniellawson9999
Author

Thanks for the response! Regarding the development of the streaming dataset backend, is this currently in a public fork? I'm curious to take a look and see if I could patch together something similar in the meantime, before this becomes an official feature.

@jamartinh
Contributor

jamartinh commented Oct 9, 2024

Have you tried the joblib library by now?
Also, just using torch's DataLoader with n worker processes should work, along the lines of the sketch below.
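
A rough sketch of the DataLoader route, assuming a dataset that supports integer indexing as described above; the dataset name, batch size, and worker count are placeholders.

import minari
from torch.utils.data import DataLoader, Dataset


class MinariEpisodeDataset(Dataset):
    """Thin torch Dataset over a MinariDataset, indexed by episode."""

    def __init__(self, dataset_id: str):
        self.dataset = minari.load_dataset(dataset_id)

    def __len__(self):
        return self.dataset.total_episodes

    def __getitem__(self, idx):
        # Square-bracket indexing returns a single episode.
        return self.dataset[idx]


def collate_episodes(episodes):
    # Keep variable-length episodes as a plain list instead of stacking tensors.
    return episodes


loader = DataLoader(
    MinariEpisodeDataset("CartPole-v1-dataset"),  # placeholder dataset name
    batch_size=30,
    shuffle=True,
    num_workers=4,  # episodes are loaded in parallel worker processes
    collate_fn=collate_episodes,
)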

@jamartinh
Contributor

from joblib import Parallel, delayed
import minari


class MinariParallelWrapper:
    def __init__(self, dataset_name):
        # Load the dataset using Minari
        self.dataset = minari.load_dataset(dataset_name)

    def get_episodes_parallel(self, n, n_jobs=-1):
        """Fetch 'n' episodes in parallel using joblib."""
        # Square-bracket indexing (dataset[i]) returns a single episode.
        return Parallel(n_jobs=n_jobs)(
            delayed(self.dataset.__getitem__)(i) for i in range(n)
        )

    # Add other dataset methods here, with or without parallelism
    def get_metadata(self):
        return self.dataset.metadata

    def episode_statistics(self):
        return self.dataset.episode_statistics()

    # Wrapping other methods (optionally parallelized)
    # For example, getting rewards from episodes:
    def get_rewards_parallel(self, n, n_jobs=-1):
        """Fetch rewards from 'n' episodes in parallel."""
        # Load each episode inside the worker so the I/O itself runs in
        # parallel, then return only its rewards.
        return Parallel(n_jobs=n_jobs)(
            delayed(lambda idx: self.dataset[idx].rewards)(i) for i in range(n)
        )


# Example usage
if __name__ == "__main__":
    # Assume dataset already created and named "CartPole-v1-dataset"
    wrapper = MinariParallelWrapper("CartPole-v1-dataset")

    # Get 10 episodes in parallel
    episodes = wrapper.get_episodes_parallel(10)
    print(episodes)

    # Example of getting rewards from the first 5 episodes in parallel
    rewards = wrapper.get_rewards_parallel(5)
    print(rewards)
