
Making Streaming Dataset framework agnostic: Removing PyTorch dependency #551

Open
Abhijit-2592 opened this issue Dec 26, 2023 · 5 comments
Labels
enhancement New feature or request

Comments

@Abhijit-2592

🚀 Feature Request

Hey MosaicML team! Thank you so much for this awesome project! I was wondering if there are any plans to make this framework agnostic, i.e. remove the dependency on PyTorch.

Motivation

The general idea of StreamingDataset is very useful, and I believe the broader ML community would be thrilled if it were decoupled from PyTorch.

Implementation

Here are my thoughts on how we can go about this:

  • The torch.utils.data.Dataset class is a simple abstract class with no dependency on the rest of PyTorch (the same is true of IterableDataset), so it can very easily be re-implemented here (a rough sketch follows this list).
  • However, this gets a bit more challenging when porting the distributed.py file. This is where the CuPy project comes to the rescue: we can have seamless, zero-copy interoperability between CuPy, JAX, TensorFlow, and PyTorch tensors via the DLPack API (a small interop sketch also follows the list), and most of the functions in distributed.py have similar implementations in CuPy's distributed API.
  • As for the StreamingDataLoader, it could become an optional extra that is only installed when using the PyTorch backend.
  • So my suggestion is that by using CuPy instead of PyTorch, we can keep the library framework neutral and still have zero-copy interoperability between JAX, TF, and Torch.
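
Here is a minimal sketch of what torch-free base classes could look like; the class and method names mirror the PyTorch ones purely for familiarity, and nothing here is existing streaming code:

```python
# Minimal, framework-free stand-ins for torch.utils.data.Dataset and
# IterableDataset. Purely illustrative; names mirror PyTorch for familiarity.
from typing import Any, Iterator


class Dataset:
    """Map-style dataset: subclasses implement __getitem__ and __len__."""

    def __getitem__(self, index: int) -> Any:
        raise NotImplementedError

    def __len__(self) -> int:
        raise NotImplementedError


class IterableDataset:
    """Iterable-style dataset: subclasses implement __iter__."""

    def __iter__(self) -> Iterator[Any]:
        raise NotImplementedError
```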
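
And for the DLPack point, a hedged sketch of the zero-copy hand-off (these entry points exist in recent CuPy/PyTorch/JAX releases, but treat the snippet as illustrative rather than exact):

```python
# Zero-copy hand-off of one GPU buffer across frameworks via DLPack.
# Illustrative only; exact entry points vary slightly across versions.
import cupy as cp
import jax.dlpack
import torch
from torch.utils.dlpack import to_dlpack

x_cp = cp.arange(10)               # CuPy array on the GPU
x_torch = torch.from_dlpack(x_cp)  # same memory, viewed as a torch tensor

# Hand the same buffer to JAX, still without a copy.
x_jax = jax.dlpack.from_dlpack(to_dlpack(x_torch))
```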

Additional context

If made framework agnostic:

  • It can be used with tf.data pipelines, which work well with both JAX and TensorFlow (see the sketch after this list).
  • It fits neatly into keras.utils.Sequence, which would also let it be used with Keras 3 and its TF/JAX/PyTorch backends.
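
For example, a purely iterable dataset drops straight into tf.data via from_generator (the dataset name and element spec below are placeholders, not real streaming APIs):

```python
# Illustrative only: wrap an iterable, framework-free dataset in a tf.data
# pipeline. The element spec is a made-up example.
import tensorflow as tf


def make_tf_dataset(streaming_dataset):
    return tf.data.Dataset.from_generator(
        lambda: iter(streaming_dataset),
        output_signature={
            'image': tf.TensorSpec(shape=(32, 32, 3), dtype=tf.uint8),
            'label': tf.TensorSpec(shape=(), dtype=tf.int64),
        },
    )


# ds = make_tf_dataset(my_streaming_dataset).batch(64).prefetch(tf.data.AUTOTUNE)
```

A keras.utils.Sequence wrapper would be just as thin: __len__ plus a batched __getitem__ over the same samples.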

I will also be happy to help with this if you think it is a potential future direction!

Abhijit-2592 added the enhancement (New feature or request) label on Dec 26, 2023
@knighton
Contributor

Decoupling from PyTorch would be a hell of a project! We enthusiastically welcome your contributions. Let me list some objections that come to mind offhand -- what do you make of them?

  • StreamingDataset is designed around exactly how the PyTorch DataLoader operates, with each rank iterating round-robin over a set of worker replicas which are typically forked/spawned upon iter, identical samples per DataLoader requirements, etc. What's the CuPy answer to get_worker_info()?

  • Our killer feature, elastically deterministic mid-epoch checkpointing and resumption, currently depends on either our custom StreamingDataLoader subclass of DataLoader, or on tracking time yourself like Composer does (yes, we built it two different ways). It's a core thing too.

  • We use numpy for some things, but no framework-specific tensors, and no GPU. We assume the GPU and interconnect are precious resources and keep hands off. There is a tiny usage of torch dist for barriers in some critical places, which we set up as gloo and tear down if not already in use, IIRC. Theoretically you could swap them out with streaming/base/shared/barrier.py, which is an inter-process barrier backed by FileLock (fcntl) and SharedMemory that we currently use for worker sync, since workers can't necessarily use torch dist (a rough sketch of that kind of barrier follows). It would be nice to remove that last bit of reliance on torch dist, generally speaking.
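
For a sense of what that looks like, here is a rough, illustrative sketch of a FileLock + SharedMemory barrier (not the actual streaming/base/shared/barrier.py, just the shape of the idea):

```python
# Rough illustration of an inter-process barrier backed by FileLock and
# SharedMemory. Single-use and simplified; not the real barrier.py.
from multiprocessing.shared_memory import SharedMemory
from time import sleep

from filelock import FileLock


class FileBarrier:
    """Block until `count` processes on one node have called wait()."""

    def __init__(self, name: str, count: int):
        self.count = count
        self.lock = FileLock(f'/tmp/{name}.lock')
        try:
            # The first process creates the shared arrival counter...
            self.shm = SharedMemory(name=name, create=True, size=4)
            self.shm.buf[:4] = (0).to_bytes(4, 'little')
        except FileExistsError:
            # ...and the rest attach to it.
            self.shm = SharedMemory(name=name)

    def wait(self) -> None:
        # Atomically bump the arrival count under the file lock.
        with self.lock:
            arrived = int.from_bytes(self.shm.buf[:4], 'little') + 1
            self.shm.buf[:4] = arrived.to_bytes(4, 'little')
        # Spin until every participant has arrived.
        while int.from_bytes(self.shm.buf[:4], 'little') < self.count:
            sleep(0.01)
```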

@Abhijit-2592
Author

@knighton thanks for your comment and support.

  1. I have a private port of the PyTorch dataloader that I hacked together for fun to remove the torch dependency and turned into a standalone package (it kind of works, but I have not tested it fully). I remember porting get_worker_info() as part of that, so I'll check whether it works (a rough shim is sketched after this list).
  2. I'm not sure what to do for StreamingDataLoader(). Maybe, as you suggested, we can track time ourselves.
  3. I'll try to use the barrier.py you suggested and have a go at porting distributed.py to see if we can remove the reliance on torch dist.
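
To make (1) concrete, the shim I have in mind looks roughly like this (all names are illustrative; nothing here is existing streaming or torch code). Whatever spawns the workers would call set_worker_info() once in each child process:

```python
# Hypothetical torch-free stand-in for torch.utils.data.get_worker_info().
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkerInfo:
    id: int           # index of this worker within the dataloader
    num_workers: int  # total workers spawned by the dataloader
    seed: int         # per-worker seed derived from the base seed


_worker_info: Optional[WorkerInfo] = None


def set_worker_info(worker_id: int, num_workers: int, base_seed: int) -> None:
    """Called once in each worker process right after it is spawned."""
    global _worker_info
    _worker_info = WorkerInfo(worker_id, num_workers, base_seed + worker_id)


def get_worker_info() -> Optional[WorkerInfo]:
    """Like torch's get_worker_info(): returns None in the main process."""
    return _worker_info
```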

I’ll keep you updated on the same! Thanks!

@knighton
Contributor

Appreciate the updates.

I would recommend just reading our StreamingDataLoader for (2), as what it's doing/needs to do is very simple.

@knighton
Contributor

Experimental PR to remove dependency on torch dist:

#552

@Abhijit-2592
Author

@knighton Wow! That was fast!
