From 4ed2ba8c3e6ef32ab6b65bbe23bddffbe258e239 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 28 Oct 2024 21:21:29 +0000 Subject: [PATCH] Upgrade Flax NNX Migrating from Haiku doc --- docs_nnx/guides/haiku_to_flax.rst | 65 ++++++++++++++++++------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst index 4349b53d0b..0e4684714b 100644 --- a/docs_nnx/guides/haiku_to_flax.rst +++ b/docs_nnx/guides/haiku_to_flax.rst @@ -3,8 +3,9 @@ Migrating from Haiku to Flax This guide demonstrates the differences between Haiku and Flax NNX models, providing side-by-side example code to help you migrate to the Flax NNX API from Haiku. -To get the most out of this guide, it is highly recommended to go through `Flax NNX basics `__ document, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples. +If you are new to Flax NNX, make sure you become familiarized with `Flax NNX basics `__, which covers the :class:`nnx.Module` system, `Flax transformations `__, and the `Functional API `__ with examples. +Let’s start with some imports. .. testsetup:: Haiku, Flax NNX @@ -14,16 +15,25 @@ To get the most out of this guide, it is highly recommended to go through `Flax from typing import Any -Basic Module Definition +Basic Module definition ======================= -Both Haiku and Flax use the ``Module`` class as the default unit to express a neural network library layer. In the example below, you first create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function; then you use it as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer. +Both Haiku and Flax use the ``Module`` class as the default unit to express a neural network library layer. For example, to create a one-layer network with dropout and a ReLU activation function, you: + +* First, create a ``Block`` (by subclassing ``Module``) composed of one linear layer with dropout and a ReLU activation function. +* Then, use ``Block`` as a sub-``Module`` when creating a ``Model`` (also by subclassing ``Module``), which is made up of ``Block`` and a linear layer. There are two fundamental differences between Haiku and Flax ``Module`` objects: -* **Stateless vs. stateful**: A ``hk.Module`` instance is stateless - the variables are returned from a purely functional ``Module.init()`` call and managed separately. A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object. +* **Stateless vs. stateful**: + + * A ``haiku.Module`` instance is stateless. This means, the variables are returned from a purely functional ``Module.init()`` call and managed separately. + * A :class:`flax.nnx.Module`, however, owns its variables as attributes of this Python object. + +* **Lazy vs. eager**: -* **Lazy vs. eager**: A ``hk.Module`` only allocates space to create variables when they actually see the input when the user calls the model (lazy). A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager). + * A ``haiku.Module`` only allocates space to create variables when they actually see the input when the user calls the model (lazy). + * A ``flax.nnx.Module`` instance creates variables the moment they are instantiated, before seeing a sample input (eager). .. codediff:: @@ -82,13 +92,13 @@ There are two fundamental differences between Haiku and Flax ``Module`` objects: Variable creation -====================== +================= -Next, let's discuss instantiating the model and initializing its parameters +This section is about instantiating a model and initializing its parameters. -* To generate model parameters for a Haiku model, you need to put it inside a forward function and use ``hk.transform`` to make it purely functional. This results in a nested dictionary of `JAX Arrays `__ (``jax.Array`` data types) to be carried around and maintained separately. +* To generate model parameters for a Haiku model, you need to put it inside a forward function and use ``haiku.transform`` to make it purely functional. This results in a nested dictionary of `JAX Arrays `__ (``jax.Array`` data types) to be carried around and maintained separately. -* In Flax, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable` objects) are stored inside the :class:`nnx.Module` (or its sub-Module) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) `__ key, but that key will be wrapped inside an :class:`nnx.Rngs` class and stored inside, generating more PRNG keys when needed. +* In Flax NNX, the model parameters are automatically initialized when you instantiate the model, and the variables (:class:`nnx.Variable` objects) are stored inside the :class:`nnx.Module` (or its sub-Module) as attributes. You still need to provide it with a `pseudorandom number generator (PRNG) `__ key, but that key will be wrapped inside an :class:`nnx.Rngs` class and stored inside, generating more PRNG keys when needed. If you want to access Flax model parameters in the stateless, dictionary-like fashion for checkpoint saving or model surgery, check out the `Flax NNX split/merge API `__ (:func:`nnx.split` / :func:`nnx.merge`). @@ -122,35 +132,34 @@ If you want to access Flax model parameters in the stateless, dictionary-like fa assert model.block.linear.kernel.value.shape == (784, 256) Training step and compilation -====================== +============================= +This section covers writing a training step and compiling it using the `JAX just-in-time compilation `__. -Now, let's proceed to writing a training step and compiling it using `JAX just-in-time compilation `__. Below are certain differences between Haiku and Flax NNX approaches. +When compiling the training step: -Compiling the training step: +* Haiku uses ``@jax.jit`` - a `JAX transformation `__ - to compile a purely functional training step. +* Flax NNX uses :meth:`@nnx.jit` - a `Flax NNX transformation `__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax objects `__). While ``jax.jit`` only accepts functions with pure stateless arguments, ``flax.nnx.jit`` allows the arguments to be stateful Modules. This greatly reduces the number of lines needed for a train step. -* Haiku uses ``@jax.jit`` - a `JAX transform `__ - to compile a purely functional training step. -* Flax NNX uses :meth:`@nnx.jit` - a `Flax NNX transform `__ (one of several transform APIs that behave similarly to JAX transforms, but also `work well with Flax objects `__). So, while ``jax.jit`` only accepts functions with pure stateless arguments, ``nnx.jit`` allows the arguments to be stateful Modules. This greatly reduced the number of lines needed for a train step. +When taking gradients: -Taking gradients: +* Similarly, Haiku uses ``jax.grad`` (a JAX transformation for `automatic differentiation `__) to return a raw dictionary of gradients. +* Meanwhile, Flax NNX uses :meth:`flax.nnx.grad` (a Flax NNX transformation) to return the gradients of Flax NNX Modules as :class:`flax.nnx.State` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX, you need to use the `split/merge API `__. -* Similarly, Haiku uses ``jax.grad`` (a JAX transform for `automatic differentiation `__) to return a raw dictionary of gradients. -* Flax NNX uses :meth:`nnx.grad` (a Flax NNX transform) to return the gradients of NNX Modules as :class:`nnx.State` dictionaries. If you want to use regular ``jax.grad`` with Flax NNX you need to use the `split/merge API `__. +For optimizers: -Optimizers: - -* If you are already using `Optax `__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Haiku, check out the :class:`nnx.Optimizer` example in the `Flax basics `__ guide for a much more concise way of training and updating your model. +* If you are already using `Optax `__ optimizers like ``optax.adamw`` (instead of the raw ``jax.tree.map`` computation shown here) with Haiku, check out the :class:`flax.nnx.Optimizer` example in the `Flax basics `__ guide for a much more concise way of training and updating your model. Model updates during each training step: -* The Haiku training step needs to return a `pytree `__ of parameters as the input of the next step. -* The Flax training step doesn't need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit`. -* In addition, :class:`nnx.Module` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``BatchNorm`` stats. That is why you don't need to explicitly pass an PRNG key in on every step. Also note that you can use :meth:`nnx.reseed` to reset its underlying PRNG state. +* The Haiku training step needs to return a `JAX pytree `__ of parameters as the input of the next step. +* The Flax NNX training step does not need to return anything, because the ``model`` was already updated in-place within :meth:`nnx.jit`. +* In addition, :class:`nnx.Module` objects are stateful, and ``Module`` automatically tracks several things within it, such as PRNG keys and ``flax.nnx.BatchNorm`` stats. That is why you don't need to explicitly pass a PRNG key in at every step. Also note that you can use :meth:`flax.nnx.reseed` to reset its underlying PRNG state. -Dropout behavior: +The dropout behavior: -* In Haiku, you need to explicitly define and pass in the ``training`` argument to toggle ``hk.dropout`` and make sure random dropout only happens if ``training=True``. -* In Flax, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`nnx.Dropout` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``nnx.Module.train`` does in its `API reference `__. +* In Haiku, you need to explicitly define and pass in the ``training`` argument to toggle ``haiku.dropout`` and make sure that random dropout only happens if ``training=True``. +* In Flax NNX, you can call ``model.train()`` (:meth:`flax.nnx.Module.train`) to automatically switch :class:`flax.nnx.Dropout` to the training mode. Conversely, you can call ``model.eval()`` (:meth:`flax.nnx.Module.eval`) to turn off the training mode. You can learn more about what ``flax.nnx.Module.train`` does in its `API reference `__. .. codediff:: :title: Haiku, Flax NNX @@ -208,7 +217,7 @@ Dropout behavior: Handling non-parameter states -====================== +============================= Haiku makes a distinction between trainable parameters and all other data ("states") that the model tracks. For example, the batch stats used in batch norm is considered a state. Models with states needs to be transformed with ``hk.transform_with_state`` so that their ``.init()`` returns both params and states. @@ -693,3 +702,5 @@ be set and accessed as normal using regular Python class semantics. _, params, counter = nnx.split(model, nnx.Param, Counter) + +