Multi-GPU Question #18
The multi-GPU support there refers to correctly calling functions on a particular GPU, but unfortunately not to multi-GPU sharded arrays (the paradigm from JAX). I have to learn more about sharding in torch to figure out how to support a sharded-array function in torch. I'm currently using the NVIDIA C++ functionality for detecting which GPU the data is on, so as long as […]. Tangentially, this weekend I finished porting the package to the new JAX FFI interface.
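For what it's worth, here is a minimal sketch of the per-device behavior described above, assuming the `torch2jax(fn, *example_args)` wrapping convention from the README; the torch function, shapes, and device index are made up for illustration.

```python
import torch
import jax
import jax.numpy as jnp
from torch2jax import torch2jax  # assumed import path

# hypothetical torch function we want to call from JAX
def torch_fn(x):
    return torch.nn.functional.relu(x) * 2.0

# assumed wrapping convention: example torch tensors fix the input shapes
jax_fn = torch2jax(torch_fn, torch.zeros(16))

# place the input explicitly on one GPU; the extension detects which GPU
# the buffer lives on and runs the torch function on that same device
device = jax.devices("gpu")[0]
x = jax.device_put(jnp.ones(16, dtype=jnp.float32), device)
y = jax_fn(x)
print(y.devices())  # the output stays on the same GPU
```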
Great job on getting the FFI interface working. I just tried installing from that branch and doing a fresh recompile on one of my use cases, and all seems to work seamlessly.
Awesome! I'll try to switch permanently this weekend.
Just wondering what the state of this is now. I haven't done much JAX-based multi-GPU work, but would something like this work if a torch2jax function were called, say, in the loss function?
Good question! I believe it should work on multiple devices when torch2jax is called on each shard (i.e., the sharding approach to parallelism).
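To make that concrete, here is a minimal sketch of the per-shard approach using `shard_map`, again assuming the `torch2jax(fn, *example_args)` wrapping convention; the torch function and shard sizes are hypothetical.

```python
import numpy as np
import torch
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map
from torch2jax import torch2jax  # assumed import path

# hypothetical torch function, e.g. part of a loss computation
def torch_fn(x):
    return torch.sin(x) ** 2

n_dev = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))

# wrap once with an example input shaped like a *single* shard,
# since shard_map hands each device only its local slice
per_shard = 8
jax_fn = torch2jax(torch_fn, torch.zeros(per_shard))

# each device calls the torch function on its own local shard
sharded_fn = shard_map(jax_fn, mesh=mesh, in_specs=P("x"), out_specs=P("x"))

x = jnp.arange(n_dev * per_shard, dtype=jnp.float32)
loss_terms = jax.jit(sharded_fn)(x)
```

The same pattern should apply inside a loss function: the wrapped call only ever sees its device-local shard, so no cross-device array needs to pass through torch.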
I got an MWE up and running, and it appeared that performance became worse than just running it on a single GPU.
I've had some time this weekend to implement experimental (performant) multi-GPU support; can you take a look here: https://github.com/rdyro/torch2jax?tab=readme-ov-file#new-performant-multi-gpu-experimental ? The biggest takeaway is that […]. Generally, for performance investigation, TensorBoard traces (like here: https://jax.readthedocs.io/en/latest/profiling.html ) are quite useful; if you had those, I'd be happy to take a look too.
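In case it helps, here is roughly how to capture such a trace following that profiling guide (the log directory is arbitrary):

```python
import jax
import jax.numpy as jnp

# capture a trace of the region of interest, as in the JAX profiling docs
with jax.profiler.trace("/tmp/jax-trace"):
    x = jnp.ones((4096, 4096))
    y = (x @ x).block_until_ready()  # block so the work lands inside the trace

# view it afterwards with: tensorboard --logdir=/tmp/jax-trace
```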
I notice you have a stable branch for multi-GPU testing. I was just wondering whether torch2jax actually works out of the box when using what I believe is now the standard JAX multi-GPU paradigm of sharding (a minimal sketch follows the links below), i.e.
https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html
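For reference, a minimal sketch of that sharding paradigm from the linked notebooks (the device count, shapes, and function are arbitrary):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# build a 1D mesh over all available devices and shard the leading axis
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))

n_dev = jax.device_count()
x = jnp.arange(n_dev * 4 * 4, dtype=jnp.float32).reshape(n_dev * 4, 4)
x_sharded = jax.device_put(x, sharding)  # rows are now split across devices

# jit-compiled functions consume and produce sharded arrays transparently
y = jax.jit(lambda a: jnp.sin(a) @ a.T)(x_sharded)
jax.debug.visualize_array_sharding(x_sharded)
```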