Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add Pytree #4154

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

[nnx] add Pytree #4154

wants to merge 1 commit into from

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Aug 29, 2024

What does this PR do?

The JAX ecosystem natively supports pytrees, while NNX Modules can be transformed back and forth to pytrees via split / merge the process add some noise and Module in general cannot be used as a drop-in replacement when callable pytrees are expected. Also, because NNX support pytrees it can used Module from other libraries like Equinox and Penzai, however the reverse is not true.

To mitigate the above, this PR adds nnx.Pytree which is a proxy object that wraps Module and implements the pytree protocol.

module = nnx.Pytree(nnx.Linear(2, 3, rngs=nnx.Rngs(0)))
module = jax.tree.map(lambda x: x, module)
assert module.kernel.shape == (2, 3)

y = module(jnp.ones((1, 2)))
assert y.shape == (1, 3)

Pytree respect referential transparency, meaning that it wont share state with Module(s) it wraps, internally the input is split and the GraphDef and State are stored, only when accessing attributes the underlying object is materialized using merge and cached for later use. When methods are called the GraphDef and State are updated in place. Pytree can be used as a context manager to get access to the underlying object when needed for manual modifications, upon termination the pytree is updated with the new state.

@dataclasses.dataclass
class Counter(nnx.Module):
  count: nnx.BatchStat[int]

pt_counter = nnx.Pytree(Counter(nnx.BatchStat(0)))
assert pt_counter.count.value == 0

with pt_counter as counter:
  counter.count += 1

assert pt_counter.count.value == 1

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@@ -252,16 +254,86 @@ print(f'{y.shape = }')
nnx.display(model)
```

How do NNX transforms achieve this? To understand how NNX objects interact with
JAX transforms lets take a look at the Functional API.
## Using Modules as Pytrees
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just realized pytree is a JAX concept, not a wider-spread python concept.
Let's add this link when we talk about pytree: https://jax.readthedocs.io/en/latest/pytrees.html


print(ys)
```

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a line to introduce this piece of code, something like: Let's see another example mixing NNX and Equinox modules, as Equinox modules are pure pytrees.

@property
def state(self):
return self._nnx_state

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a util to convert it back to nnx.Module, like:

def to_module() -> nnx.Module:
  return nnx.merge(self.graphdef, self.state)

@@ -425,61 +425,6 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
assert m2 is m
assert m2.ref is m2

def test_call_jit_update(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we delete these tests?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We would delete nnx.call if we add Pytree, these test were moved to pytree_test.py.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants