JAX / XLA adding support for importing external DLPack-aware Python arrays #6

Open
adam-hartshorne opened this issue Aug 23, 2023 · 1 comment

Comments

@adam-hartshorne

Not a bug, but I thought I would make you aware of this pull request from the main JAX devs:

[JAX] Implement importing external dlpack-aware Python arrays, which allows for creating jax.Arrays from external GPU arrays asynchronously.
jax-ml/jax#17238

@rdyro
Owner

rdyro commented Aug 24, 2023

That's very interesting. The main improvement that could be made to torch2jax is somehow combining PyTorch's and JAX's lazy CUDA execution. Right now, wrapping a CUDA tensor torch->jax is already zero-overhead, but we need to instruct PyTorch to synchronize its CUDA stream both before and after executing the wrapped function. Perhaps this would be a step toward getting rid of that.

I'll look into it, thanks.
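
For readers following along, here is a minimal sketch of the synchronize-then-wrap pattern described above. It is not the torch2jax implementation. The helper names `torch_to_jax` and `jax_to_torch` are hypothetical, and the blocking calls are an assumption about where synchronization would go. The demo tensor lives on the CPU so the snippet runs without a GPU, which means the stream synchronization branch is only taken for CUDA tensors.

```python
import torch
import jax
import jax.dlpack


def torch_to_jax(t: torch.Tensor) -> jax.Array:
    """Hypothetical helper: hand a torch tensor to JAX through DLPack."""
    if t.is_cuda:
        # Assumption: block until PyTorch's pending kernels on this device's
        # current stream have finished, so JAX reads fully written data.
        torch.cuda.current_stream(t.device).synchronize()
    # torch tensors implement __dlpack__, which jax.dlpack.from_dlpack accepts.
    return jax.dlpack.from_dlpack(t.contiguous())


def jax_to_torch(x: jax.Array) -> torch.Tensor:
    """Hypothetical helper: hand a JAX array back to torch through DLPack."""
    # JAX dispatches asynchronously, so wait for the result before sharing it.
    x = x.block_until_ready()
    return torch.from_dlpack(x)


# CPU demo tensor; on a CUDA tensor the synchronize branch above would run.
t = torch.arange(4, dtype=torch.float32)
y = jax_to_torch(torch_to_jax(t) * 2)
```

The two blocking calls are exactly the overhead discussed in the comment above: each one stalls the host until the device has caught up, which is what an asynchronous import path could potentially avoid.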
