Hey @maxencefaldor! This is a very cool project, I love it.
Just wanted to give you a heads up that in the next release of Flax, all state must live inside Variables. This fixes issues where an attribute could switch between being treated as state or not depending on its type. As an easy fix, you can wrap each Array in nnx.Param or create a custom nnx.Variable type.
Most of the code should keep working, because Variables act as proxies for their inner value and can be passed to JAX functions. The main difference is that updates must go through .value:
self.kernel_fft.value = new_value
Best,
Cristian
I am glad that you appreciate the library and thank you for opening this issue!
I tried to wrap all state inside nnx.Param but I get this error:
> state_fft_k = jnp.dot(state_fft, self.reshape_c_k) # (y, x, k,)
TypeError: Argument 'Param(
E value=Traced<ShapedArray(float32[3,3])>with<DynamicJaxprTrace(level=3/0)>
E )' of type <class 'flax.nnx.variablelib.Param'> is not a valid JAX type.
I would like to avoid manually doing self.reshape_c_k.value everywhere. Do you have a solution?
@maxencefaldor can you please post this issue on the JAX repo and explain that you would like to treat the Param as a jax.Array?
The issue is that NNX relies on the __jax_array__ protocol so that users don't have to type .value, but JAX doesn't support __jax_array__ across all of its APIs, so it's very important that users show their interest in JAX improving that support.