-
Notifications
You must be signed in to change notification settings - Fork 644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] add Pytree #4154
base: main
Are you sure you want to change the base?
[nnx] add Pytree #4154
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@@ -252,16 +254,86 @@ print(f'{y.shape = }') | |||
nnx.display(model) | |||
``` | |||
|
|||
How do NNX transforms achieve this? To understand how NNX objects interact with | |||
JAX transforms lets take a look at the Functional API. | |||
## Using Modules as Pytrees |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realized pytree is a JAX concept, not a wider-spread python concept.
Let's add this link when we talk about pytree: https://jax.readthedocs.io/en/latest/pytrees.html
|
||
print(ys) | ||
``` | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a line to introduce this piece of code, something like: Let's see another example mixing NNX and Equinox modules, as Equinox modules are pure pytrees.
@property | ||
def state(self): | ||
return self._nnx_state | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we have a util to convert it back to nnx.Module
, like:
def to_module() -> nnx.Module:
return nnx.merge(self.graphdef, self.state)
@@ -425,61 +425,6 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): | |||
assert m2 is m | |||
assert m2.ref is m2 | |||
|
|||
def test_call_jit_update(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we delete these tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would delete nnx.call
if we add Pytree
, these test were moved to pytree_test.py
.
What does this PR do?
The JAX ecosystem natively supports pytrees, while NNX Modules can be transformed back and forth to pytrees via
split
/merge
the process add some noise and Module in general cannot be used as a drop-in replacement when callable pytrees are expected. Also, because NNX support pytrees it can used Module from other libraries like Equinox and Penzai, however the reverse is not true.To mitigate the above, this PR adds
nnx.Pytree
which is a proxy object that wraps Module and implements the pytree protocol.Pytree
respect referential transparency, meaning that it wont share state with Module(s) it wraps, internally the input issplit
and theGraphDef
andState
are stored, only when accessing attributes the underlying object is materialized usingmerge
and cached for later use. When methods are called theGraphDef
andState
are updated in place.Pytree
can be used as a context manager to get access to the underlying object when needed for manual modifications, upon termination the pytree is updated with the new state.