
Multi-GPU Question #18

Open

adam-hartshorne opened this issue Oct 27, 2024 · 7 comments

Comments

@adam-hartshorne

I notice you have a stable branch for multi-GPU testing. I was just wondering whether torch2jax actually works out of the box with what I believe is now the standard JAX multi-GPU paradigm of sharding, i.e.

https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://jax.readthedocs.io/en/latest/notebooks/shard_map.html

@rdyro
Owner

rdyro commented Oct 28, 2024

The multi-GPU support there refers to correctly calling functions on a particular GPU, but unfortunately not to multi-GPU sharded arrays (the paradigm from JAX). I have to learn more about sharding in torch before I can work out how to support a function over sharded arrays on the torch side.

I'm currently using the NVIDIA C++ functionality for detecting which GPU the data is on, so as long as torch2jax is called from within shard_map (exactly!), it should hopefully work correctly. I'm planning to test this in the coming days. (I'll leave the issue open until I can test it.)
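
For reference, a minimal sketch of the pattern described above (calling the wrapped function once per shard via shard_map). The `jax_fn` here is a pure-JAX stand-in for a torch2jax-wrapped function, so the example runs without torch; the actual wrapping step follows the torch2jax README and is not shown.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# stand-in for a JAX-callable produced by torch2jax from a torch function
def jax_fn(x):
    return jnp.tanh(x)

# one mesh axis spanning all local devices
mesh = Mesh(jax.devices(), axis_names=("gpu",))

# each device receives its own shard; jax_fn is invoked once per shard
sharded_fn = shard_map(jax_fn, mesh=mesh, in_specs=P("gpu"), out_specs=P("gpu"))

x = jnp.ones((8 * len(jax.devices()), 16))
y = jax.jit(sharded_fn)(x)
print(y.sharding)  # sharded along the first axis across the "gpu" mesh axis
```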

Tangentially, this weekend, I finished porting torch2jax (in the new-ffi branch) to the new FFI interface, so long-term support should be assured now.

@adam-hartshorne
Author

Great job on getting the FFI interface working. I just tried installing from that branch and doing a fresh recompile on one of my use cases, and everything seems to work seamlessly.

@rdyro
Owner

rdyro commented Nov 1, 2024

Awesome! I'll try to switch over permanently this weekend.

@adam-hartshorne
Author

Just wondering what the state of this is now. I haven't done much JAX-based multi-GPU work, but would something like this work if a torch2jax function were called, say, in the loss function?

https://docs.kidger.site/equinox/examples/parallelism/

@rdyro
Owner

rdyro commented Dec 30, 2024

Good question! I believe it should work on multiple devices when torch2jax is called on each shard (i.e., the sharding approach to parallelism).

pmap is the old way of doing device-level parallelism, but its (rough) successor shard_map should work out of the box. One issue might be performance: each shard could end up running sequentially instead of in parallel. I'm investigating this and will provide a code example soon.
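
A hedged sketch of what a sharded loss along these lines could look like. Assumptions: `jax_fn` is again a pure-JAX stand-in for a torch2jax-wrapped model call (differentiating through a real torch2jax call additionally requires its gradient wrapper), and the batch is sharded along a "batch" mesh axis.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def jax_fn(params, x):  # stand-in for a torch2jax-wrapped model call
    return jnp.tanh(x @ params)

mesh = Mesh(jax.devices(), axis_names=("batch",))

@jax.jit
def loss(params, x, y):
    def per_shard(params, x, y):
        pred = jax_fn(params, x)
        local = jnp.mean((pred - y) ** 2)          # mean over the local shard
        return jax.lax.pmean(local, axis_name="batch")  # average across devices

    return shard_map(
        per_shard, mesh=mesh,
        in_specs=(P(), P("batch"), P("batch")),  # params replicated, data sharded
        out_specs=P(),                            # scalar loss, replicated
    )(params, x, y)

params = jnp.ones((16, 1))
n = 8 * len(jax.devices())
x, y = jnp.ones((n, 16)), jnp.zeros((n, 1))
grads = jax.grad(loss)(params, x, y)
```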

@adam-hartshorne
Author

I got an MWE up and running, and performance appeared to be worse than just running it on a single GPU.

@rdyro
Owner

rdyro commented Jan 21, 2025

I've had some time this weekend to implement experimental (performant) multi-GPU support. Can you take a look here: https://github.com/rdyro/torch2jax?tab=readme-ov-file#new-performant-multi-gpu-experimental ?

The biggest takeaway is that pmap doesn't work, but shard_map should work great. Let me know if you want to take a look at the new implementation (main branch) and have any feedback!

Generally, for performance investigations, TensorBoard traces (like here: https://jax.readthedocs.io/en/latest/profiling.html ) are quite useful; if you had those, I'd be happy to take a look too.
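
A short sketch of capturing such a trace, following the linked JAX profiling docs; the traced function and the log directory are placeholders.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):  # placeholder for the real multi-GPU workload
    return jnp.sin(x).sum()

x = jnp.ones((1024, 1024))
# writes a trace viewable with: tensorboard --logdir /tmp/jax-trace
with jax.profiler.trace("/tmp/jax-trace"):
    step(x).block_until_ready()
```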
