
Expose memory space in general to JAX users #23152

Open
yliu120 opened this issue Feb 26, 2025 · 1 comment


@yliu120
Contributor

yliu120 commented Feb 26, 2025

Prototype Reference: https://github.com/openxla/xla/pull/23149/files

These are the internal annotation custom calls we might use to annotate a buffer with a specific memory space. However, we think such memory space annotations are too heavy and unnecessary.

To get great performance and allow maximal flexibility, users should be able to pin specific buffers to specific memory spaces. Consider the following case: a JAX user wants a buffer to be a persistent temp buffer or a pinned host buffer so that they can use it inside their FFI custom ops.

In general, we should do the following:

  1. Fix SPMD partitioning to recognize heterogeneous memory spaces.
  2. Don't drop user-set memory spaces in layout assignment.
  3. Fix layout normalization.
  4. Add a bit of memory space propagation.

@nouiz
Contributor

nouiz commented Feb 28, 2025

Can you confirm that by "persistent temp buffer" you want an activation to always use the same buffer?
Why is that useful?

JAX has started supporting activation offloading to host memory via the jax.remat API: instead of recomputing activations, it offloads them. Is that what you want to do here?

Here is some recent documentation about memory spaces for the inputs and outputs of a JAX function:
https://docs.jax.dev/en/latest/sharded-computation.html#sharding-transformation-between-memory-types

And here is documentation on offloading activations instead of recomputing them: https://docs.jax.dev/en/latest/gradient-checkpointing.html#custom-policies-for-offload
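A minimal sketch of the offload policy from that page, applied to a hypothetical one-layer loss (`layer` and the checkpoint name `"hidden"` are illustrative). It assumes the `"device"` and `"pinned_host"` memory kinds exist; on backends without them (e.g. CPU) it falls back to simply saving the named activation on the device.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

cp = jax.checkpoint_policies
kinds = {m.kind for m in jax.devices()[0].addressable_memories()}

if {"device", "pinned_host"} <= kinds:
    # Offload the named activation to pinned host memory instead of recomputing it.
    policy = cp.save_and_offload_only_these_names(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=["hidden"],
        offload_src="device",
        offload_dst="pinned_host",
    )
else:
    # Fallback when those memory kinds are unavailable: keep the activation on-device.
    policy = cp.save_only_these_names("hidden")

def layer(w, x):
    # Tag the intermediate so the policy can refer to it by name.
    hidden = checkpoint_name(jnp.sin(x @ w), "hidden")
    return jnp.sum(hidden ** 2)

loss = jax.checkpoint(layer, policy=policy)

w = jnp.full((4, 4), 0.1, dtype=jnp.float32)
x = jnp.ones((2, 4), dtype=jnp.float32)
grads = jax.jit(jax.grad(loss))(w, x)
```

The policy only changes where the residual lives between the forward and backward passes, so the gradients should match the un-checkpointed function.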
