
Encapsulate Arrays inside Variables #2

Open
cgarciae opened this issue Oct 5, 2024 · 2 comments

Comments

@cgarciae

cgarciae commented Oct 5, 2024

Hey @maxencefaldor! This is a very cool project, I love it.

Just wanted to give you a heads up that in the next release of Flax, all state must be stored inside Variables. This resolves issues where an attribute could switch between being treated as state or not depending on its type. As an easy fix, you can wrap each Array in nnx.Param or in a custom nnx.Variable type, e.g.:

self.kernel_fft = nnx.Param(jnp.fft.fft2(jnp.fft.fftshift(kernel_normalized, axes=(0, 1)), axes=(0, 1)))

Most existing code should keep working, since Variables act as proxies for their inner value and can be passed to JAX functions directly. The main difference is that you must go through .value to update them:

self.kernel_fft.value = new_value

Best,
Cristian

@maxencefaldor
Owner

maxencefaldor commented Oct 20, 2024

Hey @cgarciae,

I am glad that you appreciate the library and thank you for opening this issue!

I tried wrapping all state in nnx.Param, but I get this error:

>       state_fft_k = jnp.dot(state_fft, self.reshape_c_k)  # (y, x, k,)
TypeError: Argument 'Param(
E      value=Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=3/0)>
E    )' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.

I would like to avoid writing self.reshape_c_k.value by hand everywhere. Do you have a solution?

@cgarciae
Author

cgarciae commented Oct 22, 2024

@maxencefaldor can you please post this issue on the JAX repo and explain that you would like to treat the Param as a jax.Array?

The issue is that NNX relies on the __jax_array__ protocol so that users don't have to type .value, but JAX's APIs don't fully support __jax_array__ yet, so it's very important that users show their interest in JAX improving that support.
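To illustrate the protocol in isolation: an object that defines __jax_array__ can be converted to an array by JAX. The Wrapped class is a hypothetical stand-in for a Variable, not NNX's actual implementation. This sketch only relies on explicit conversion with jnp.asarray; whether other functions such as jnp.dot accept the object directly depends on how much of JAX's API covers __jax_array__ in your version, which is the gap discussed in this thread.

```python
import jax.numpy as jnp


class Wrapped:
    """Hypothetical stand-in for a Variable: holds an array and exposes __jax_array__."""

    def __init__(self, value):
        self.value = value

    def __jax_array__(self):
        # jnp.asarray calls this to obtain the underlying array.
        return self.value


w = Wrapped(jnp.arange(3.0))
arr = jnp.asarray(w)  # explicit conversion goes through __jax_array__
```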
