Not a bug, but I thought I would make you aware of this pull request from the main JAX devs:
[JAX] Implement importing external dlpack-aware Python arrays, which allows for creating jax.Arrays from external GPU arrays asynchronously.
jax-ml/jax#17238
That's very interesting. The main improvement that could be made to torch2jax is finding a way to combine torch and JAX lazy CUDA execution. Right now, wrapping CUDA tensors torch->jax is already zero-overhead, but we need to instruct PyTorch to synchronize its CUDA stream both before and after executing the wrapped function. Perhaps this would be a step toward getting rid of that.
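To make the "synchronize before and after" points mentioned above concrete, here is a minimal sketch of an eager JAX -> PyTorch -> JAX bridge built on the DLPack protocol. It is not how torch2jax is actually implemented; `call_torch_eagerly` and `_sync_torch` are hypothetical names, and the sketch assumes recent PyTorch and JAX versions in which `torch.from_dlpack` and `jax.dlpack.from_dlpack` accept DLPack-aware objects directly. On a CPU-only machine the CUDA synchronization is simply skipped.

```python
import jax
import jax.dlpack
import jax.numpy as jnp
import torch


def _sync_torch() -> None:
    # Hypothetical helper: block the host until all queued PyTorch CUDA
    # work on the current device has finished; no-op without CUDA.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def call_torch_eagerly(fn, *jax_args):
    """Hypothetical eager bridge: run a PyTorch function on jax.Arrays.

    Only illustrates where the synchronization points sit; it is not
    torch2jax's real mechanism.
    """
    # Before: wait until JAX has finished computing the inputs, so the
    # PyTorch kernels never read device memory that is still being written.
    ready = [a.block_until_ready() for a in jax_args]

    # Zero-copy JAX -> torch views via DLPack. JAX arrays are immutable,
    # so `fn` must not modify its inputs in place.
    torch_args = [torch.from_dlpack(a) for a in ready]

    out = fn(*torch_args)

    # After: PyTorch launches kernels asynchronously on its own stream;
    # wait for them before JAX (on a different stream) reads the result.
    _sync_torch()

    # Zero-copy torch -> JAX; DLPack export needs a contiguous tensor.
    return jax.dlpack.from_dlpack(out.contiguous())


if __name__ == "__main__":
    x = jnp.arange(4, dtype=jnp.float32)
    y = jnp.ones(4, dtype=jnp.float32)
    print(call_torch_eagerly(lambda a, b: a * 2 + b, x, y))
```

Both waits in this sketch block the host and serialize the two runtimes, which is the overhead the comment hopes an asynchronous import path could eventually remove.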