Skip to content

Commit

Permalink
add split primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
n-gao committed Jan 30, 2025
1 parent f5fc768 commit 5150e33
Show file tree
Hide file tree
Showing 5 changed files with 568 additions and 2 deletions.
2 changes: 1 addition & 1 deletion folx/jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def jvp(tangents):
grad_y = tree_take(y_tangent, slice(None, -1), axis=JAC_DIM)
lapl_y = tree_take(y_tangent, -1, axis=JAC_DIM)

assert grad_y.shape == mask.shape
assert jtu.tree_all(jtu.tree_map(lambda a, b: a.shape == b.shape, grad_y, mask))
grad_y = jtu.tree_map(FwdJacobian, grad_y, mask)
return y, grad_y, lapl_y

Expand Down
9 changes: 9 additions & 0 deletions folx/wrapped_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,3 +503,12 @@ def get_laplacian(
register_function(
jax.lax.square_p, wrap_forward_laplacian(jax.lax.square, in_axes=())
)
if hasattr(jax.lax, 'split_p'):
register_function(
jax.lax.split_p,
wrap_forward_laplacian(
jax.lax.split,
flags=FunctionFlags.INDEXING,
index_static_args=(1, 2),
),
)
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ build-backend = "hatchling.build"

[dependency-groups]
dev = [
"flax>=0.10.2",
"jax[cuda12]>=0.4.38",
"parameterized>=0.9.0",
"pre-commit-uv>=4.1.4",
"pytest>=8.3.4",
Expand All @@ -66,3 +68,6 @@ addopts = "-n auto"
JAX_PLATFORMS = "cpu"
JAX_ENABLE_X64 = "True"
XLA_FLAGS = "--xla_force_host_platform_device_count=4"

[tool.uv.sources]
folx = { workspace = true }
11 changes: 11 additions & 0 deletions test/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,14 @@ def f(x, dtype):
with self.subTest(dtype=dtype):
y = jax.jit(forward_laplacian(functools.partial(f, dtype=dtype)))(x)
self.assertIsInstance(y, jax.Array)

def test_split(self):
x = jax.random.normal(jax.random.PRNGKey(0), (16,))

def f(x):
return jnp.split(x, 2)

# Check that the output is still sparse
y_fwd = forward_laplacian(f, sparsity_threshold=1)(x)
assert y_fwd[0].jacobian.data.shape == (1, 8)
assert y_fwd[1].jacobian.data.shape == (1, 8)
Loading

0 comments on commit 5150e33

Please sign in to comment.