Upgrade Flax NNX Haiku Linen migration doc #4200

Merged · 2 commits · Nov 5, 2024

``docs_nnx/guides/haiku_to_flax.rst``: 38 additions and 27 deletions

Migrating from Haiku to Flax
============================

This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku.

If you are new to Flax NNX, make sure you familiarize yourself with `Flax NNX basics <https://flax.readthedocs.io/en/latest/nnx_basics.html>`__, which covers the :class:`nnx.Module<flax.nnx.Module>` system, `Flax transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__, and the `Functional API <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`__ with examples.

Let’s start with some imports.

.. testsetup:: Haiku, Flax NNX

   # The diff view collapses this setup block; these are the imports the
   # guide's examples assume (haiku, jax, and flax are the libraries compared).
   import jax
   import jax.numpy as jnp
   import haiku as hk
   from flax import nnx
   from typing import Any


Basic Module definition
=======================

Both Haiku and Flax use the ``Module`` class as the default unit for expressing a neural network layer. For example, to create a small network with dropout and a ReLU activation function, you:

* First, create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function.
* Then, use ``Block`` as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer.

There are two fundamental differences between Haiku and Flax ``Module`` objects:

* **Stateless vs. stateful**:

  * A ``haiku.Module`` instance is stateless: its variables are returned from a purely functional ``Module.init()`` call and managed separately.
  * A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object.

* **Lazy vs. eager**:

  * A ``haiku.Module`` only allocates space for its variables when it first sees an input, that is, when the user calls the model (lazy).
  * A ``flax.nnx.Module`` instance creates its variables the moment it is instantiated, before seeing a sample input (eager).


.. codediff::
   :title: Haiku, Flax NNX

   # (Haiku listing collapsed in this diff view)

   ---

   # (Flax NNX listing collapsed in this diff view)
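
Since the listing itself is collapsed here, below is a minimal sketch of the Flax NNX side under stated assumptions: the layer sizes match the shape assertion used later in this guide (784, 256, 10), and the 0.5 dropout rate is an assumed value, not taken from the original listing.

.. code-block:: python

   class Block(nnx.Module):
     def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
       # Input and output sizes are given up front; variables are created eagerly.
       self.linear = nnx.Linear(din, dout, rngs=rngs)
       self.dropout = nnx.Dropout(0.5, rngs=rngs)  # 0.5 is an assumed rate

     def __call__(self, x):
       return jax.nn.relu(self.dropout(self.linear(x)))

   class Model(nnx.Module):
     def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
       self.block = Block(din, dmid, rngs=rngs)
       self.linear = nnx.Linear(dmid, dout, rngs=rngs)

     def __call__(self, x):
       return self.linear(self.block(x))

On the Haiku side, the corresponding ``Block`` and ``Model`` subclass ``hk.Module`` and call ``hk.Linear`` and ``hk.dropout`` inside ``__call__``, so no input sizes or PRNG keys are needed until the model is actually called.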


Variable creation
=================

This section covers instantiating a model and initializing its parameters.

* To generate model parameters for a Haiku model, you need to put it inside a forward function and use ``haiku.transform`` to make it purely functional. This results in a nested dictionary of `JAX Arrays <https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array>`__ (``jax.Array`` data types) to be carried around and maintained separately.

* In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable<flax.nnx.Variable>` objects) are stored inside the :class:`nnx.Module<flax.nnx.Module>` (or its sub-Module) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) <https://jax.readthedocs.io/en/latest/random-numbers.html>`__ key, but that key will be wrapped inside an :class:`nnx.Rngs<flax.nnx.Rngs>` class and stored inside, generating more PRNG keys when needed.

If you want to access Flax model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API <https://flax.readthedocs.io/en/latest/nnx_basics.html#state-and-graphdef>`__ (:func:`nnx.split<flax.nnx.split>` / :func:`nnx.merge<flax.nnx.merge>`).
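
As a brief sketch of that split/merge pattern (assuming a ``model`` like the one instantiated just below):

.. code-block:: python

   # Split the stateful Module into its static structure and a dict-like pytree.
   graphdef, state = nnx.split(model)
   # ... save `state` to a checkpoint, or edit it for model surgery ...
   model = nnx.merge(graphdef, state)  # rebuild the stateful Module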

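The side-by-side listing for this section is collapsed in this diff view. Below is a minimal sketch under stated assumptions: ``hk.nets.MLP`` stands in for the guide's Haiku ``Model``, and the Flax NNX ``Model`` is the one sketched earlier.

.. code-block:: python

   # Haiku: wrap the model in a forward function, transform it, then init.
   def forward(x):
     return hk.nets.MLP([256, 10])(x)  # a stand-in for the guide's Model

   params = hk.transform(forward).init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
   # `params` is a nested dictionary of jax.Array leaves, managed separately.

   # Flax NNX: instantiation itself creates the variables, eagerly.
   model = Model(784, 256, 10, rngs=nnx.Rngs(0))
   # The variables live on the Module as attributes: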
   assert model.block.linear.kernel.value.shape == (784, 256)

Training step and compilation
=============================

This section covers writing a training step and compiling it using `JAX just-in-time compilation <https://jax.readthedocs.io/en/latest/jit-compilation.html>`__.

When compiling the training step:

* Haiku uses ``@jax.jit`` - a `JAX transformation <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ - to compile a purely functional training step.
* Flax NNX uses :meth:`@nnx.jit<flax.nnx.jit>` - a `Flax NNX transformation <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax objects <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`__). While ``jax.jit`` only accepts functions with pure, stateless arguments, ``flax.nnx.jit`` allows the arguments to be stateful Modules. This greatly reduces the number of lines needed for a training step.

When taking gradients:

* Similarly, Haiku uses ``jax.grad`` (a JAX transformation for `automatic differentiation <https://jax.readthedocs.io/en/latest/automatic-differentiation.html#taking-gradients-with-jax-grad>`__) to return a raw dictionary of gradients.
* Meanwhile, Flax NNX uses :meth:`flax.nnx.grad<flax.nnx.grad>` (a Flax NNX transformation) to return the gradients of Flax NNX Modules as :class:`flax.nnx.State<flax.nnx.State>` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX, you need to use the `split/merge API <https://flax.readthedocs.io/en/latest/nnx_basics.html#state-and-graphdef>`__.

For optimizers:

* If you are already using `Optax <https://optax.readthedocs.io/>`__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Haiku, check out the :class:`flax.nnx.Optimizer<flax.nnx.Optimizer>` example in the `Flax basics <https://flax.readthedocs.io/en/latest/nnx_basics.html#transforms>`__ guide for a much more concise way of training and updating your model (see the sketch below).
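
A minimal sketch of that pattern, assuming ``optax`` is available and using the ``nnx.Optimizer`` API as it existed around the time this PR merged (late 2024); the learning rate is an arbitrary choice:

.. code-block:: python

   import optax

   optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=1e-3))

   @nnx.jit
   def train_step(model, optimizer, x, y):
     def loss_fn(model):
       return jnp.mean((model(x) - y) ** 2)

     loss, grads = nnx.value_and_grad(loss_fn)(model)
     optimizer.update(grads)  # applies the Optax update to the model in place
     return loss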

Model updates during each training step:

* The Haiku training step needs to return a `JAX pytree <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__ of parameters as the input of the next step.
* The Flax NNX training step does not need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit<flax.nnx.jit>`.
* In addition, :class:`nnx.Module<flax.nnx.Module>` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``flax.nnx.BatchNorm`` stats. That is why you don't need to explicitly pass a PRNG key in at every step. Also note that you can use :meth:`flax.nnx.reseed<flax.nnx.reseed>` to reset its underlying PRNG state.

Dropout behavior:

* In Haiku, you need to explicitly define and pass in the ``training`` argument to toggle ``haiku.dropout``, making sure that random dropout only happens when ``training=True``.
* In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`flax.nnx.Dropout<flax.nnx.Dropout>` to training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off training mode. You can learn more about what ``flax.nnx.Module.train`` does in its `API reference <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module.train>`__.

.. codediff::
   :title: Haiku, Flax NNX

   # (Haiku listing collapsed in this diff view)

   ---

   # (Flax NNX listing collapsed in this diff view)
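
With the listing collapsed here, the following is a minimal sketch of the Flax NNX side only, assuming a mean-squared-error loss and the raw ``jax.tree.map`` update mentioned above (the 0.1 learning rate is arbitrary):

.. code-block:: python

   model.train()  # switch nnx.Dropout to training mode

   @nnx.jit  # stateful Modules can be passed in directly
   def train_step(model, x, y):
     def loss_fn(model):
       return jnp.mean((model(x) - y) ** 2)

     grads = nnx.grad(loss_fn)(model)
     params = nnx.state(model, nnx.Param)
     params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
     nnx.update(model, params)  # in-place update: nothing is returned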


Handling non-parameter states
=============================

Haiku makes a distinction between trainable parameters and all other data ("state") that the model tracks. For example, the batch statistics used in batch normalization are considered state. Models with state need to be transformed with ``hk.transform_with_state`` so that their ``.init()`` returns both params and state.
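
A hedged sketch of the contrast, assuming a variant of ``Model`` that contains a batch-normalization layer (``hk.BatchNorm`` on the Haiku side, ``nnx.BatchNorm`` on the Flax NNX side); names and shapes are illustrative:

.. code-block:: python

   # Haiku: .init() now returns (params, state), and .apply() threads state too.
   model_hk = hk.transform_with_state(lambda x: Model(256, 10)(x))
   params, state = model_hk.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
   out, state = model_hk.apply(params, state, None, jnp.ones((1, 784)))

   # Flax NNX: batch stats are ordinary nnx.BatchStat variables on the Module;
   # no special transform is needed, and they can be filtered out on demand:
   graphdef, params, batch_stats = nnx.split(model, nnx.Param, nnx.BatchStat)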

*(The remainder of the guide is collapsed in this diff view; only the tail of its final section, on custom variable types, survives:)*

…be set and accessed as normal using regular Python class semantics.
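
The surviving fragment below references ``Counter``, which the collapsed text presumably defines as a custom :class:`nnx.Variable<flax.nnx.Variable>` subclass; the class definition here is a hypothetical reconstruction:

.. code-block:: python

   class Counter(nnx.Variable):
     pass  # hypothetical: the collapsed section defines the real Counter type

   # Splitting by variable type separates params from the custom Counter state: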
   _, params, counter = nnx.split(model, nnx.Param, Counter)



