From b1ba25f986c02d50b2f1574f86539acb0c69a3e2 Mon Sep 17 00:00:00 2001
From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com>
Date: Mon, 7 Oct 2024 00:28:04 +0000
Subject: [PATCH] Upgrade Flax NNX Bridge guide
---
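Note on the "Basics" section touched by this patch: the guide contrasts stateless, lazily initialized Flax Linen `Module`s with stateful, eagerly initialized Flax NNX `Module`s, and says a bridge wrapper defers variable creation until it sees a sample input. The following is a dependency-free sketch of that idea only, not the Flax implementation. `LazyDot`, `EagerDot`, `LazyWrapper`, and `lazy_setup` are hypothetical names made up for illustration and are not Flax APIs.

```python
import random


def matmul(x, w):
    """Multiply a (batch, in_dim) nested list by an (in_dim, out_dim) nested list."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def make_weight(seed, in_dim, out_dim):
    """Create an (in_dim, out_dim) weight from a seed."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(out_dim)] for _ in range(in_dim)]


class LazyDot:
    """Linen-like: holds only hyperparameters; variables live outside the object."""

    def __init__(self, out_dim):
        self.out_dim = out_dim

    def init(self, seed, x):
        # The weight shape depends on the input, so a sample input is required.
        return {"w": make_weight(seed, len(x[0]), self.out_dim)}

    def apply(self, variables, x):
        return matmul(x, variables["w"])


class EagerDot:
    """NNX-like: the weight is created in __init__ and owned as an attribute."""

    def __init__(self, in_dim, out_dim, seed):
        self.w = make_weight(seed, in_dim, out_dim)

    def __call__(self, x):
        return matmul(x, self.w)


class LazyWrapper:
    """Bridge-like: stateful on the outside, initialized once a sample input arrives."""

    def __init__(self, lazy_module, seed):
        self.module = lazy_module
        self.seed = seed
        self.w = None  # Stays unset until `lazy_setup` sees a sample input.

    def lazy_setup(self, x):
        self.w = self.module.init(self.seed, x)["w"]
        return self

    def __call__(self, x):
        if self.w is None:
            raise RuntimeError("call lazy_setup(x) with a sample input first")
        return self.module.apply({"w": self.w}, x)


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # Batch of 2, feature size 3.
wrapped = LazyWrapper(LazyDot(out_dim=4), seed=0).lazy_setup(x)
y = wrapped(x)
```

The wrapper owns its weight as an attribute after setup, which is why the weight can then be inspected or swapped in place, in the same spirit as the `model.w.value = ...` line in the notebook.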
docs_nnx/guides/bridge_guide.ipynb | 498 +++++++++++++++++++----------
docs_nnx/guides/bridge_guide.md | 287 +++++++++--------
2 files changed, 475 insertions(+), 310 deletions(-)
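Note on the "Handling the JAX PRNG keys" section touched by this patch: the guide describes two styles, a stateful one where the module tracks its own key and split count (and can be reseeded), and a functional one where a different key is passed on every call. The sketch below mimics both with the standard library only. It does not reproduce how JAX or Flax split keys; `dropout`, `StatefulDropout`, and `reseed` here are hypothetical stand-ins, and string seeds are an assumption chosen to keep the example small.

```python
import random


def dropout(x, rate, seed):
    """Zero each element with probability `rate`; the result depends only on `seed`."""
    rng = random.Random(seed)
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in x]


class StatefulDropout:
    """NNX-like: owns a base seed and a call counter, so callers pass no key."""

    def __init__(self, rate, seed):
        self.rate = rate
        self.seed = seed
        self.count = 0

    def __call__(self, x):
        out = dropout(x, self.rate, seed=f"{self.seed}-{self.count}")
        self.count += 1  # Advancing the counter stands in for a key split.
        return out

    def reseed(self, seed):
        # Reset the internal state so the next call repeats the first one.
        self.seed = seed
        self.count = 0


x = [1.0] * 64

# Stateful style: two calls differ, and reseeding replays the first call.
layer = StatefulDropout(rate=0.5, seed=0)
y1, y2 = layer(x), layer(x)
layer.reseed(0)
y1_again = layer(x)

# Functional style: the caller supplies a different seed on every call.
y3 = dropout(x, 0.5, seed="1")
y4 = dropout(x, 0.5, seed="2")
```

In the stateful style the counter has to be carried forward between calls, which corresponds to marking `RngCount` as mutable and merging the updates back into `variables` in the notebook's `ToLinen` example.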
diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb
index e41836a93e..39b553161b 100644
--- a/docs_nnx/guides/bridge_guide.ipynb
+++ b/docs_nnx/guides/bridge_guide.ipynb
@@ -4,22 +4,29 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Use Flax NNX and Linen together\n",
+ "# Use Flax NNX and Linen together via `nnx.bridge`\n",
"\n",
- "This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.\n",
+ "This guide is designed to assist existing Flax users who want to mix Flax NNX and Flax Linen `Module`s in their codebase. Bridging NNX and Linen code is made possible with the help of the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) API. This document should enable you to move to and try out Flax NNX at your own pace, and leverage \"the best of both worlds\". This can be particularly helpful if you:\n",
"\n",
- "This will be helpful if you:\n",
+ "* Want to migrate your codebase to [Flax NNX](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) from [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) gradually, one `Module` at a time; and/or\n",
+ "* Have an external dependency that has already moved to Flax NNX while your code is still in Flax Linen, or a dependency that is still in Flax Linen while your code has already moved to Flax NNX.\n",
"\n",
- "* Want to migrate your codebase to NNX gradually, one module at a time;\n",
- "* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX.\n",
+ "You will also learn how to resolve the caveats of interoperating between the two APIs, which stem from a few aspects in which Flax Linen and Flax NNX are fundamentally different.\n",
"\n",
- "We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.\n",
+ "Table of contents:\n",
"\n",
- "**Note**:\n",
+ "- A sub-`Module` is all you need\n",
+ "- Basics\n",
+ " - Flax Linen to NNX with `nnx.bridge.lazy_init`/`ToNNX`\n",
+ " - Flax NNX to Linen with `nnx.bridge.ToLinen`\n",
+ "- Handling the JAX PRNG keys\n",
+ "- Flax NNX variable types vs Flax Linen collections\n",
+ "- Partition metadata\n",
+ "- Lifted transformations - go ahead and do it\n",
"\n",
- "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. \n",
+ "**Note**: This guide describes how to glue a [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) together with a [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). If you need to _migrate_ an existing Linen `Module` (a.k.a. `nn.Module`) to an NNX `Module`, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide instead. In addition, all [built-in Flax Linen layers](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html) should have [equivalent Flax NNX versions](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).\n",
"\n",
- "And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html)."
+ "First, let's import some necessary dependencies:"
]
},
{
@@ -44,44 +51,89 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Submodule is all you need\n",
+ "## A sub-`Module` is all you need\n",
"\n",
- "A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`). \n",
+ "A Flax model is always a tree of `Module`s, each of which is either an old [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) (usually written as `nn.Module`) or a new [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).\n",
"\n",
- "An `nnx.bridge` wrapper glues the two types together, in both ways:\n",
+ "The [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrapper API enables you to glue these two types of `Module`s together in both directions:\n",
"\n",
- "* `nnx.bridge.ToNNX`: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops.\n",
- "* `nnx.bridge.ToLinen`: Vice versa, convert a NNX module to Linen.\n",
+ "* [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX): Converts a `flax.linen.Module` to NNX, so that it can be a sub-`Module` of another `flax.nnx.Module`, or a standalone `Module` to be trained in NNX-style training loops.\n",
+ "* [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen): The opposite of `nnx.bridge.ToNNX` - it converts a `flax.nnx.Module` to `flax.linen.Module`.\n",
"\n",
- "This means you can move in either top-down or bottom-up behavior: convert the whole Linen module to NNX, then gradually move down, or convert all the lower level modules to NNX then move up.\n"
+ "Therefore, you can convert the entire `flax.linen.Module` to Flax NNX, and then gradually “move down” (the “top-down” way), or convert all the lower-level `flax.linen.Module`s to Flax NNX and then “move up” (the “bottom-up” way)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "## The Basics\n",
+ "## Basics\n",
"\n",
- "There are two fundamental difference between Linen and NNX modules:\n",
+ "There are two fundamental differences between [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) and [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module):\n",
"\n",
- "* **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional `.init()` call and managed separately. NNX modules, however, owns its variables as instance attributes.\n",
+ "* **Stateless vs stateful**:\n",
+ " - Flax Linen `Module` instances are stateless: Variables are returned from a purely functional `Module.init()` call and managed separately.\n",
+ " - Flax NNX `Module`s, however, own their variables as instance attributes.\n",
"\n",
- "* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.\n",
+ "* **Lazy vs eager**:\n",
+ " - Flax Linen `Module`s only allocate space to create variables when they actually see their input.\n",
+ " - In comparison, Flax NNX `Module` instances create their [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) the moment they are instantiated without seeing a sample input.\n",
"\n",
- "With that in mind, let's look at how the `nnx.bridge` wrappers tackle the differences."
+ "With that in mind, let's review how the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrappers tackle these differences.\n",
+ "\n",
+ "### Basics: Flax Linen to NNX with `nnx.bridge.lazy_init`/`ToNNX`\n",
+ "\n",
+ "Since `flax.linen.Module`s may require an input to create variables, the Flax team semi-formally supports lazy initialization in the `flax.nnx.Module`s converted from Flax Linen: the Flax Linen variables are created when you give the converted `Module` a sample input. In practice, you call [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) (`nnx.bridge.ToNNX.lazy_init`) where you would call `module.init()` in Flax Linen code.\n",
+ "\n",
+ "> **Note:** To inspect all `nnx.Module` variables and state, you can call [`nnx.display`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/visualization.html#flax.nnx.display)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a3db6428",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class LinenDot(nn.Module):\n",
+ " out_dim: int\n",
+ " w_init: Callable[..., Any] = nn.initializers.lecun_normal()\n",
+ " @nn.compact\n",
+ " def __call__(self, x):\n",
+ " # Flax Linen might need the input shape to create the weight!\n",
+ " w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))\n",
+ " return x @ w\n",
+ "\n",
+ "x = jax.random.normal(jax.random.key(42), (4, 32))\n",
+ "model = bridge.ToNNX(LinenDot(64),\n",
+ " rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen.\n",
+ "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen.\n",
+ "y = model(x) # => `y = model.apply(var, x)` in Linen.\n",
+ "\n",
+ "nnx.display(model)\n",
+ "\n",
+ "# In-place swap your weight array and the model still works!\n",
+ "model.w.value = jax.random.normal(jax.random.key(1), (32, 64))\n",
+ "assert not jnp.allclose(y, model(x))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Linen -> NNX\n",
+ "\n",
"\n",
- "Since Linen modules may require an input to create variables, we semi-formally supported lazy initialization in the NNX modules converted from Linen. The Linen variables are created when you give it a sample input.\n",
"\n",
- "For you, it's calling `nnx.bridge.lazy_init()` where you call `module.init()` in Linen code.\n",
"\n",
- "(Note: you can call `nnx.display` upon any NNX module to inspect all its variables and state.)"
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "982639de",
+ "metadata": {},
+ "source": [
+ "The [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) function also works when the top-level `Module` is a pure Flax NNX one, so you can nest converted sub-`Module`s as you wish:"
]
},
{
@@ -115,33 +167,38 @@
}
],
"source": [
- "class LinenDot(nn.Module):\n",
- " out_dim: int\n",
- " w_init: Callable[..., Any] = nn.initializers.lecun_normal()\n",
- " @nn.compact\n",
+ "class NNXOuter(nnx.Module):\n",
+ " def __init__(self, out_dim: int, rngs: nnx.Rngs):\n",
+ " self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)\n",
+ " self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))\n",
+ "\n",
" def __call__(self, x):\n",
- " # Linen might need the input shape to create the weight!\n",
- " w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))\n",
- " return x @ w\n",
+ " return self.dot(x) + self.b\n",
"\n",
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
- "model = bridge.ToNNX(LinenDot(64), \n",
- " rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n",
- "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n",
- "y = model(x) # => `y = model.apply(var, x)` in Linen\n",
+ "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit them on one line.\n",
+ "nnx.display(model)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
"\n",
- "nnx.display(model)\n",
"\n",
- "# In-place swap your weight array and the model still works!\n",
- "model.w.value = jax.random.normal(jax.random.key(1), (32, 64))\n",
- "assert not jnp.allclose(y, model(x))"
+ "\n",
+ ""
]
},
{
"cell_type": "markdown",
+ "id": "a5bc171f",
"metadata": {},
"source": [
- "`nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish:"
+ "The Flax Linen weight is already converted to a typical Flax NNX variable ([`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)), which is a thin wrapper around the actual [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) value within. Here, `w` is an [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) because it belongs to the `params` collection of the `LinenDot` `flax.linen.Module`.\n",
+ "\n",
+ "Different collections and types are covered in more detail in the _Flax NNX variable types vs Flax Linen collections_ section. Right now, you just need to know that they are converted to Flax `nnx.Variable`s like native ones."
]
},
{
@@ -174,34 +231,6 @@
"output_type": "display_data"
}
],
- "source": [
- "class NNXOuter(nnx.Module):\n",
- " def __init__(self, out_dim: int, rngs: nnx.Rngs):\n",
- " self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)\n",
- " self.b = nnx.Param(jax.random.uniform(rngs.params(), (1, out_dim,)))\n",
- "\n",
- " def __call__(self, x):\n",
- " return self.dot(x) + self.b\n",
- "\n",
- "x = jax.random.normal(jax.random.key(42), (4, 32))\n",
- "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line\n",
- "nnx.display(model)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module.\n",
- "\n",
- "We will talk more about different collections and types in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, just know that they are converted to NNX variables like native ones."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
"source": [
"assert isinstance(model.dot.w, nnx.Param)\n",
"assert isinstance(model.dot.w.value, jax.Array)"
@@ -211,7 +240,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not."
+ "If you create this model without using [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init), the Flax [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) defined outside will be initialized as usual, but the Flax Linen part (that is wrapped inside of [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX)) will not."
]
},
{
@@ -249,6 +278,18 @@
"nnx.display(partial_model)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "7a551761",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ ""
+ ]
+ },
{
"cell_type": "code",
"execution_count": 6,
@@ -288,11 +329,25 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### NNX -> Linen\n",
+ "\n",
+ "\n",
"\n",
- "To convert an NNX module to Linen, you should forward your creation arguments to `bridge.ToLinen` and let it handle the actual creation process.\n",
"\n",
- "This is because NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. So `bridge.to_linen` will handle the actual module creation and make sure no memory is allocated twice."
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d6c2dffa",
+ "metadata": {},
+ "source": [
+ "### Basics: Flax NNX to Linen with `nnx.bridge.ToLinen`\n",
+ "\n",
+ "To convert a `flax.nnx.Module` to Flax Linen, you should forward your creation arguments to [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) and let it handle the actual creation process.\n",
+ "\n",
+ "This is because:\n",
+ "- The `flax.nnx.Module` instance initializes all its variables eagerly when it is created, which consumes memory and compute.\n",
+ "- On the other hand, `flax.linen.Module`s are stateless, and the typical `init` and `apply` process creates them multiple times. Therefore, [`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) will handle the actual `Module` creation and make sure no memory is allocated twice."
]
},
{
@@ -319,21 +374,31 @@
" return x @ self.w\n",
"\n",
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
- "# Pass in the arguments, not an actual module\n",
+ "# Pass in the arguments, not an actual `Module`.\n",
"model = bridge.to_linen(NNXDot, 32, out_dim=64)\n",
"variables = model.init(jax.random.key(0), x)\n",
"y = model.apply(variables, x)\n",
"\n",
"print(list(variables.keys()))\n",
"print(variables['params']['w'].shape) # => (32, 64)\n",
- "print(y.shape) # => (4, 64)\n"
+ "print(y.shape) # => (4, 64)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note that `ToLinen` modules need to track an extra variable collection - `nnx` - for the static metadata of the underlying NNX module."
+ " ['nnx', 'params']\n",
+ " (32, 64)\n",
+ " (4, 64)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "de1b26a5",
+ "metadata": {},
+ "source": [
+ "Note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) `Module`s need to track an extra variable collection - `nnx` - for the static metadata of the underlying `nnx.Module`."
]
},
{
@@ -358,7 +423,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "`bridge.to_linen` is actually a convenience wrapper around the Linen module `bridge.ToLinen`. Most likely you won't need to use `ToLinen` directly at all, unless you are using one of the built-in arguments of `ToLinen`. For example, if your NNX module doesn't want to be initialized with RNG handling:"
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c8880236",
+ "metadata": {},
+ "source": [
+ "[`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) is actually a convenience wrapper around the Flax Linen `Module` [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen). Most likely you won't need to use `nnx.bridge.ToLinen` directly at all, unless you are using one of its built-in arguments. For example, if your `nnx.Module` should be initialized without PRNG handling:"
]
},
{
@@ -373,8 +446,8 @@
" def __call__(self, x):\n",
" return x + self.constant\n",
"\n",
- "# You have to use `skip_rng=True` because this module's `__init__` don't\n",
- "# take `rng` as argument\n",
+ "# You have to use `skip_rng=True` because this module's `__init__` doesn't\n",
+ "# take `rngs` as an argument.\n",
"model = bridge.ToLinen(NNXAddConstant, skip_rng=True)\n",
"y, var = model.init_with_output(jax.random.key(0), x)"
]
@@ -383,22 +456,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module. "
+ "Similar to [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX), you can use [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) to create a sub-`Module` of another `flax.linen.Module`."
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(32, 64) (1, 64) (4, 64)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"class LinenOuter(nn.Module):\n",
" out_dim: int\n",
@@ -419,37 +484,53 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Handling RNG keys\n",
- "\n",
- "All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.\n",
- "\n",
- "Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves."
+ " (32, 64) (1, 64) (4, 64)"
]
},
{
"cell_type": "markdown",
+ "id": "3ded9b4d",
"metadata": {},
"source": [
- "### Linen to NNX\n",
+ "## Handling the JAX PRNG keys\n",
"\n",
- "If you convert a Linen module to NNX, you enjoy the stateful benefit and don't need to pass in extra RNG keys on every module call. You can use always `nnx.reseed` to reset the RNG state within."
+ "All Flax `Module`s - in [Linen](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) or [NNX](https://flax.readthedocs.io/en/latest/guides/randomness.html) - can automatically handle the JAX [pseudorandom number generator (PRNG)](https://jax.readthedocs.io/en/latest/random-numbers.html) keys for variable creation and random layers like dropout. However, the specific PRNG key-splitting logic differs between the two, so you cannot generate the same params in Linen and NNX `Module`s, even if you pass in the same keys.\n",
+ "\n",
+ "Another difference is that Flax NNX `Module`s are stateful, so they can track and update the PRNG keys within themselves.\n",
+ "\n",
+ "> **Note:** To refresh your memory of PRNG key handling, review [JAX PRNG 101](https://jax.readthedocs.io/en/latest/random-numbers.html), [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng), [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html), and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html).\n",
+ "\n",
+ "### PRNG keys: Flax Linen to NNX - Enjoy the stateful benefits!\n",
+ "\n",
+ "If you convert a Flax Linen `Module` to NNX, you can enjoy the stateful benefits and don't need to pass in extra PRNG keys on every `nnx.Module` call. You can always use [`nnx.reseed`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.reseed) to reset the PRNG state within."
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The RNG key in state: Array((), dtype=key) overlaying:\n",
+ "[1428664606 3351135085]\n",
+ "Number of key splits: 0\n",
+ "Number of key splits after y2: 2\n"
+ ]
+ }
+ ],
"source": [
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
"model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))\n",
- "# We don't really need to call lazy_init because no extra params were created here,\n",
+ "# You don't really need to call `lazy_init` because no extra params were created here,\n",
"# but it's a good practice to always add this line.\n",
"bridge.lazy_init(model, x)\n",
"y1, y2 = model(x), model(x)\n",
"assert not jnp.allclose(y1, y2) # Two runs yield different outputs!\n",
"\n",
- "# Reset the dropout RNG seed, so that next model run will be the same as the first.\n",
+ "# Reset the dropout PRNG seed, so that the next model run will be the same as the first.\n",
"nnx.reseed(model, dropout=0)\n",
"assert jnp.allclose(y1, model(x))"
]
@@ -458,45 +539,36 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### NNX to Linen\n",
+ "### PRNG keys: Flax NNX to Linen - Two handling style options\n",
"\n",
- "If you convert an NNX module to Linen, the underlying NNX module's RNG states will still be part of the top-level `variables`. On the other hand, Linen `apply()` call accepts different RNG keys on each call, which resets the internal Linen environment and allow different random data to be generated.\n",
+ "If you convert a Flax NNX `Module` to Linen, the underlying `flax.nnx.Module`'s PRNG states will still be part of the top-level `variables`. On the other hand, the `flax.linen.Module.apply()` call accepts different PRNG keys on each call, which _resets the internal Flax Linen environment and allows different random data to be generated_.\n",
"\n",
- "Now, it really depends on whether your underlying NNX module generates new random data from its RNG state, or from the passed-in argument. Fortunately, `nnx.Dropout` supports both - using passed-in keys if there is any, and use its own RNG state if not.\n",
+ "Now, it really depends on whether your underlying Flax NNX `Module` generates new random data from its PRNG state, or from the passed-in argument. Fortunately, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) supports both - using passed-in keys if there are any, and using its own PRNG state if not.\n",
"\n",
- "And this leaves you with two style options of handling the RNG keys: \n",
+ "And this leaves you with two style options of handling the PRNG keys:\n",
"\n",
- "* The NNX style (recommended): Let the underlying NNX state manage the RNG keys, no need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs `ToLinen`.\n",
+ "* The Flax NNX style (_recommended_): Let the underlying NNX state manage the PRNG keys, so that you don't need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` after every `apply()` call, but things will look simpler once your whole model no longer needs [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen).\n",
+ "* The Flax Linen style: Just pass different PRNG keys for every `apply()` call.\n",
"\n",
- "* The Linen style: Just pass different RNG keys for every `apply()` call."
+ "> **Note:** The [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html) and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) tutorials can help you better understand PRNG handling in Flax."
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
+ "id": "d175b29a",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The RNG key in state: Array((), dtype=key) overlaying:\n",
- "[1428664606 3351135085]\n",
- "Number of key splits: 0\n",
- "Number of key splits after y2: 2\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
"model = bridge.to_linen(nnx.Dropout, rate=0.5)\n",
"variables = model.init({'dropout': jax.random.key(0)}, x)\n",
"\n",
- "# The NNX RNG state was stored inside `variables`\n",
+ "# The Flax NNX PRNG state was stored inside `variables`.\n",
"print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)\n",
"print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)\n",
"\n",
- "# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`\n",
+ "# Flax NNX style: Must set `RngCount` as mutable and update the variables after every `apply`.\n",
"y1, updates = model.apply(variables, x, mutable=['RngCount'])\n",
"variables |= updates\n",
"y2, updates = model.apply(variables, x, mutable=['RngCount'])\n",
@@ -504,7 +576,7 @@
"print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)\n",
"assert not jnp.allclose(y1, y2) # Every call yields different output!\n",
"\n",
- "# Linen style: Just pass different RNG keys for every `apply()` call.\n",
+ "# Flax Linen style: Just pass different PRNG keys for every `apply()` call.\n",
"y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})\n",
"y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})\n",
"assert not jnp.allclose(y3, y4) # Every call yields different output!\n",
@@ -516,24 +588,30 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## NNX variable types vs. Linen collections\n",
- "\n",
- "When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types.\n",
- "\n",
- "Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically. \n",
- "\n",
- "Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`."
+ " The RNG key in state: Array((), dtype=key) overlaying:\n",
+ " [1428664606 3351135085]\n",
+ " Number of key splits: 0\n",
+ " Number of key splits after y2: 2"
]
},
{
"cell_type": "markdown",
+ "id": "252b478b",
"metadata": {},
"source": [
- "### Linen to NNX\n",
+ "## Flax NNX variable types vs Flax Linen collections\n",
+ "\n",
+ "When you want to group certain variables into one category, in Flax Linen you use different collections. In Flax NNX, because all variables must be top-level Python attributes, you use different variable types.\n",
+ "\n",
+ "Therefore, when mixing Flax Linen and NNX `Module`s, Flax must know the 1-to-1 mapping between Flax Linen collections and Flax NNX variable types, so that [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) can do the conversion automatically.\n",
+ "\n",
+ "Flax keeps a registry for this, and it already covers all of Flax's built-in Linen collections. You can register extra mappings between Flax NNX variable types and Flax Linen collection names using [`flax.nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html).\n",
+ "\n",
+ "### Variables and collections: Flax Linen to NNX\n",
"\n",
- "For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly. \n",
+ "For any collection of your Linen module, [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) will convert all its endpoint arrays (a.k.a. [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) [leaves](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#mistaking-pytree-nodes-for-leaves)) to a subtype of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), either taken from the registry or created automatically on the fly.\n",
"\n",
- "(However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.)"
+ "> **Note:** The entire collection is still kept as one class attribute, because `flax.linen.Module`s may have duplicate names across different collections."
]
},
{
@@ -585,7 +663,6 @@
"print(model.b) # Of type `nnx.Param`\n",
"print(model.count) # Of type `counter` - auto-created type from the collection name\n",
"print(type(model.count))\n",
- "\n",
"y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger\n",
"print(model.dot_sum) # Of type `nnx.Intermediates`"
]
@@ -594,9 +671,30 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "You can quickly separate different types of NNX variables apart using `nnx.split`.\n",
- "\n",
- "This can be handy when you only want to set some variables as trainable."
+ " Param(\n",
+ " value=Array([[ 0.35401407, 0.38010964, -0.20674096],\n",
+ " [-0.7356256 , 0.35613298, -0.5099556 ],\n",
+ " [-0.4783049 , 0.4310735 , 0.30137998],\n",
+ " [-0.6102254 , -0.2668519 , -1.053598 ]], dtype=float32)\n",
+ " )\n",
+ " Param(\n",
+ " value=Array([0., 0., 0.], dtype=float32)\n",
+ " )\n",
+ " counter(\n",
+ " value=Array(0, dtype=int32)\n",
+ " )\n",
+ " \n",
+ " (Intermediate(\n",
+ " value=Array(6.932987, dtype=float32)\n",
+ " ),)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "149f3886",
+ "metadata": {},
+ "source": [
+ "You can quickly separate different types of Flax NNX variables using [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). This can be handy when you only want to set certain variables as trainable."
]
},
{
@@ -621,7 +719,6 @@
"print('All Params:', list(params.keys()))\n",
"print('All Counters:', list(counter.keys()))\n",
"print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))\n",
- "\n",
"model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time\n",
"y = model(x, mutable=True) # still works!"
]
@@ -630,9 +727,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### NNX to Linen\n",
+ " All Params: ['b', 'w']\n",
+ " All Counters: ['count']\n",
+ " All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a694b265",
+ "metadata": {},
+ "source": [
+ "### Variables and collections: Flax NNX to Linen\n",
"\n",
- "If you define custom NNX variable types, you should register their names with `nnx.register_variable_name_type_pair` so that they go to the desired collections."
+ "If you define custom Flax NNX variable types, you should register their names with [`nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html) so that they go to the desired collections."
]
},
{
@@ -678,26 +785,37 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Partition metadata\n",
- "\n",
- "Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.\n",
- "\n",
- "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.\n",
- "\n",
- "The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX)."
+ " All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts']\n",
+ " {'w': Array([[ 0.2916921 , 0.22780475, 0.06553137],\n",
+ " [ 0.17487915, -0.34043145, 0.24764155],\n",
+ " [ 0.6420431 , 0.6220095 , -0.44769976],\n",
+ " [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Linen to NNX\n",
+ "## Partition metadata\n",
+ "\n",
+ "Flax uses a metadata wrapper box over the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) to annotate how a variable should be sharded.\n",
"\n",
- "Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within. \n",
+ "- In Flax Linen, this is an optional feature that is triggered by using [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) on initializers.\n",
+ "- In Flax NNX, since all Flax NNX variables are wrapped by the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) class anyway, that class holds the sharding annotations too.\n",
"\n",
- "If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`. \n",
+ "> **Note:** If you are new to `jax.Array`s and _data sharding_, go to [Key concepts](https://jax.readthedocs.io/en/latest/key-concepts.html#array-devices-and-sharding) and [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation) on the JAX documentation site.\n",
"\n",
- "You can then use `nnx.with_sharding_constraint` to explicitly put the arrays into the annotated partitions within a `jax.jit`-compiled function, to initialize the whole model with every array at the right sharding."
+ "Both [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) will automatically convert the sharding annotations if you use the built-in annotation methods, such as [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) or [`flax.nnx.with_partitioning`](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning).\n",
+ "\n",
+ "> **Note:** To learn more about sharding metadata in Flax and JAX, refer to Flax NNX’s [Scale up](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide, JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation), and the Flax Linen [Scale up](https://flax-linen.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) guide.\n",
+ "\n",
+ "### Partition metadata: Flax Linen to NNX\n",
+ "\n",
+ "Even if you are not using any partition metadata in your Flax Linen `Module`, the variable [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) will be converted to [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) that wrap the true `jax.Array` within.\n",
+ "\n",
+ "If you use [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) to annotate your Flax Linen `Module` variables, the annotation will be converted to the `.sharding` field in the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).\n",
+ "\n",
+ "You can then use [`nnx.with_sharding_constraint`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_sharding_constraint) to explicitly put the arrays into the annotated partitions within a [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html)-compiled function, to initialize the whole model with every array at the right sharding."
]
},
{
@@ -721,8 +839,8 @@
" out_dim: int\n",
" @nn.compact\n",
" def __call__(self, x):\n",
- " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), \n",
- " ('in', 'out')), \n",
+ " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(),\n",
+ " ('in', 'out')),\n",
" (x.shape[-1], self.out_dim))\n",
" return x @ w\n",
"\n",
@@ -737,7 +855,7 @@
"\n",
"\n",
"print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n",
- "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), \n",
+ "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)),\n",
" axis_names=('in', 'out'))\n",
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
"with mesh:\n",
@@ -752,13 +870,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### NNX to Linen\n",
+ " We have 8 fake JAX devices now to partition this model...\n",
+ " \n",
+ " ('in', 'out')\n",
+ " GSPMDSharding({devices=[2,4]<=[8]})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5c3ca5f1",
+ "metadata": {},
+ "source": [
+ "### Partition metadata: Flax NNX to Linen\n",
"\n",
- "If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it.\n",
+ "If you are not using any metadata features of the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), that is, no sharding annotations and no registered hooks, then the converted `flax.linen.Module` will not add a metadata wrapper to your Flax NNX variable, and you won't need to worry about it. (Recall that all Flax NNX variables are wrapped in an `nnx.Variable` box.)\n",
"\n",
- "But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable.\n",
+ "But if you did add sharding annotations to your Flax NNX variables, then [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) will convert them to a default Flax Linen partition metadata class called [`flax.nnx.bridge.NNXMeta`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.NNXMeta), retaining all the metadata you put into the NNX variable.\n",
"\n",
- "Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree."
+ "As with any Flax Linen metadata wrapper, you can use `flax.linen.unbox()` ([`flax.linen.meta.unbox`](https://github.com/google/flax/blob/5d31452889b8d106d7c722b5eaac14cb9784fec2/flax/core/meta.py#L160)) to get the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree)."
]
},
{
@@ -792,13 +921,13 @@
" assert type(variables['params']['w']) == bridge.NNXMeta\n",
" # The annotation coming from the `nnx.Param` => (in, out)\n",
" assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n",
- " \n",
+ "\n",
" unboxed_variables = nn.unbox(variables)\n",
" variable_pspecs = nn.get_partition_spec(variables)\n",
" assert isinstance(unboxed_variables['params']['w'], jax.Array)\n",
" assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')\n",
- " \n",
- " sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint, \n",
+ "\n",
+ " sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint,\n",
" nn.unbox(variables),\n",
" nn.get_partition_spec(variables))\n",
" return sharded_vars\n",
@@ -814,20 +943,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## Lifted transforms\n",
- "\n",
- "In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax. \n",
- "\n",
- "For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases)"
+ " GSPMDSharding({devices=[2,4]<=[8]})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Linen to NNX\n",
+ "## Lifted transformations - go ahead and do it\n",
+ "\n",
+ "In general, if you want to apply [Flax Linen-](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html) or [Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) on an [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html)-converted `Module`, just go ahead and do it in the usual Flax Linen or NNX syntax.\n",
"\n",
- "NNX style lifted transforms are similar to JAX transforms, and they work on functions."
+ "For [Flax Linen style transforms](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html), note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `flax.linen.Module` class in most cases).\n",
+ "\n",
+ "### Lifted transforms: Flax Linen to NNX\n",
+ "\n",
+ "[Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) are similar to [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations), and they too work on functions."
]
},
{
@@ -862,7 +993,7 @@
"x = jax.random.normal(jax.random.key(0), (4, 32))\n",
"model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)\n",
"\n",
- "print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped\n",
+ "print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got `vmap`ped.\n",
"y = model(x)\n",
"print(y.shape)"
]
@@ -871,11 +1002,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### NNX to Linen\n",
- "\n",
- "Note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases).\n",
+ " (4, 32, 64)\n",
+ " (4, 64)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "93419218",
+ "metadata": {},
+ "source": [
+ "### Lifted transforms: Flax NNX to Linen\n",
"\n",
- "Also, since `bridge.ToLinen` introduced this extra `nnx` collection, you need to mark it when using the axis-changing transforms (`linen.vmap`, `linen.scan`, etc) to make sure they are passed inside."
+ "As mentioned before, [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `flax.linen.Module` class in most cases). In addition, since `nnx.bridge.ToLinen` introduces an extra `nnx` collection, you need to mark that collection when using the axis-changing transforms ([`flax.linen.vmap`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.vmap), [`flax.linen.scan`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan), and so on) to make sure it is passed inside."
]
},
{
@@ -904,13 +1042,27 @@
"x = jax.random.normal(jax.random.key(42), (4, 32))\n",
"model = LinenVmapped(64)\n",
"var = model.init(jax.random.key(0), x)\n",
- "print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped\n",
+ "print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got `vmap`ped.\n",
"y = model.apply(var, x)\n",
"print(y.shape)"
]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "920324e4",
+ "metadata": {},
+ "source": [
+ " (4, 32, 64)\n",
+ " (4, 64)"
+ ]
}
],
"metadata": {
+ "jupytext": {
+ "cell_metadata_filter": "-all",
+ "formats": "ipynb,md:myst",
+ "main_language": "python"
+ },
"language_info": {
"codemirror_mode": {
"name": "ipython",
diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md
index 3f243ae2ab..d3f8faf969 100644
--- a/docs_nnx/guides/bridge_guide.md
+++ b/docs_nnx/guides/bridge_guide.md
@@ -1,22 +1,40 @@
-# Use Flax NNX and Linen together
+---
+jupytext:
+ cell_metadata_filter: -all
+ formats: ipynb,md:myst
+ main_language: python
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.13.8
+---
-This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.
+# Use Flax NNX and Linen together via `nnx.bridge`
-This will be helpful if you:
+This guide is designed to assist existing Flax users who want to mix Flax NNX and Flax Linen `Module`s in their codebase. Bridging NNX and Linen code is made possible with the help of the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) API. This document should enable you to move to and try out Flax NNX at your own pace, and leverage "the best of both worlds". This can be particularly helpful if you:
-* Want to migrate your codebase to NNX gradually, one module at a time;
-* Have external dependency that already moved to NNX but you haven't, or is still in Linen while you've moved to NNX.
+* Want to migrate your codebase to [Flax NNX](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) from [Flax Linen](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) gradually, one `Module` at a time; and/or
+* Have an external dependency that has already been moved to Flax NNX, but you have not done so. Alternatively, it may still be in Flax Linen while you've moved your code to Flax NNX.
-We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.
+You will also learn how to resolve certain caveats of interoperating between the Flax Linen and Flax NNX APIs, which stem from a few aspects in which the two are fundamentally different.
-**Note**:
+Table of contents:
-This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax-nnx.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide.
+- A sub-`Module` is all you need
+- Basics
+ - Flax Linen to NNX with `nnx.bridge.lazy_init`/`ToNNX`
+ - Flax NNX to Linen with `nnx.bridge.ToLinen`
+- Handling the JAX PRNG keys
+- Flax NNX variable types vs Flax Linen collections
+- Partition metadata
+- Lifted transformations - go ahead and do it
-And all built-in Linen layers should have equivalent NNX versions! Check out the list of [Built-in NNX layers](https://flax-nnx.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).
+**Note**: This guide describes how to glue a [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) with a [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module). If you instead need to _migrate_ an existing Linen `Module` (a.k.a. `nn.Module`) to an NNX `Module`, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/guides/linen_to_nnx.html) guide. In addition, all [built-in Flax Linen layers](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/layers.html) should have [equivalent Flax NNX versions](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/index.html).
+First, let's import some necessary dependencies:
-```python
+```{code-cell} ipython3
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
@@ -29,52 +47,54 @@ from jax.experimental import mesh_utils
from typing import *
```
-## Submodule is all you need
+## A sub-`Module` is all you need
-A Flax model is always a tree of modules - either old Linen modules (`flax.linen.Module`, usually written as `nn.Module`) or NNX modules (`nnx.Module`).
+A Flax model is a [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) of `Module`s - either an old [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) (usually written as `nn.Module`) or a new [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).
-An `nnx.bridge` wrapper glues the two types together, in both ways:
+The [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrapper API enables you to glue these two types of `Module`s together in two ways using:
-* `nnx.bridge.ToNNX`: Convert a Linen module to NNX, so that it can be a submodule of another NNX module, or stand alone to be trained in NNX-style training loops.
-* `nnx.bridge.ToLinen`: Vice versa, convert a NNX module to Linen.
+* [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX): Converts a `flax.linen.Module` to NNX, so that it can be a sub-`Module` of another `flax.nnx.Module`, or a standalone `Module` to be trained in NNX style training loops.
+* [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen): The opposite of `nnx.bridge.ToNNX` - it converts a `flax.nnx.Module` to `flax.linen.Module`.
-This means you can move in either top-down or bottom-up behavior: convert the whole Linen module to NNX, then gradually move down, or convert all the lower level modules to NNX then move up.
+Therefore, you can convert the entire `flax.linen.Module` to Flax NNX, and then gradually “move down” (the “top-down” way), or convert all the lower-level `flax.linen.Module`s to Flax NNX and then “move up” (the “bottom-up” way).
++++
-## The Basics
+## Basics
-There are two fundamental difference between Linen and NNX modules:
+There are two fundamental differences between [`flax.linen.Module`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module) and [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module):
-* **Stateless vs. stateful**: Linen module instances are stateless: variables are returned from a purely functional `.init()` call and managed separately. NNX modules, however, owns its variables as instance attributes.
+* **Stateless vs stateful**:
+ - Flax Linen `Module` instances are stateless: Variables are returned from a purely functional `Module.init()` call and managed separately.
+ - Flax NNX `Module`s, however, own their variables as instance attributes.
-* **Lazy vs. eager**: Linen modules only allocate space to create variables when they actually see their input. Whereas NNX module instances create their variables the moment they are instantiated, without seeing a sample input.
+* **Lazy vs eager**:
+ - Flax Linen `Module`s only allocate space to create variables when they actually see their input.
+ - In comparison, Flax NNX `Module` instances create their [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) the moment they are instantiated without seeing a sample input.
-With that in mind, let's look at how the `nnx.bridge` wrappers tackle the differences.
+With that in mind, let's review how the [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html) wrappers tackle these differences.
-### Linen -> NNX
+### Basics: Flax Linen to NNX with `nnx.bridge.lazy_init`/`ToNNX`
-Since Linen modules may require an input to create variables, we semi-formally supported lazy initialization in the NNX modules converted from Linen. The Linen variables are created when you give it a sample input.
+Since `flax.linen.Module`s may require an input to create variables, the Flax team semi-formally supports lazy initialization in the `flax.nnx.Module`s converted from Flax Linen. The Flax Linen variables are created when you give the converted `Module` a sample input. In practice, you call [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) (`nnx.bridge.ToNNX.lazy_init`) where you would call `module.init()` in Flax Linen code.
-For you, it's calling `nnx.bridge.lazy_init()` where you call `module.init()` in Linen code.
+> **Note:** To inspect all `nnx.Module` variables and state, you can call [`nnx.display`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/visualization.html#flax.nnx.display).
-(Note: you can call `nnx.display` upon any NNX module to inspect all its variables and state.)
-
-
-```python
+```{code-cell} ipython3
class LinenDot(nn.Module):
out_dim: int
w_init: Callable[..., Any] = nn.initializers.lecun_normal()
@nn.compact
def __call__(self, x):
- # Linen might need the input shape to create the weight!
+ # Flax Linen might need the input shape to create the weight!
w = self.param('w', self.w_init, (x.shape[-1], self.out_dim))
return x @ w
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(LinenDot(64),
- rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen
-bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen
-y = model(x) # => `y = model.apply(var, x)` in Linen
+ rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen.
+bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen.
+y = model(x) # => `y = model.apply(var, x)` in Linen.
nnx.display(model)
@@ -83,18 +103,17 @@ model.w.value = jax.random.normal(jax.random.key(1), (32, 64))
assert not jnp.allclose(y, model(x))
```
-
++++
-`nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish:
-
+The [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init) method also works even if the top-level `Module` is a pure-NNX one, so you can perform "sub-moduling" as you wish:
-```python
+```{code-cell} ipython3
class NNXOuter(nnx.Module):
def __init__(self, out_dim: int, rngs: nnx.Rngs):
self.dot = nnx.bridge.ToNNX(LinenDot(out_dim), rngs=rngs)
@@ -104,66 +123,62 @@ class NNXOuter(nnx.Module):
return self.dot(x) + self.b
x = jax.random.normal(jax.random.key(42), (4, 32))
-model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line
+model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit them on one line.
nnx.display(model)
```
-
++++
-The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module.
-
-We will talk more about different collections and types in the [NNX Variable <-> Linen Collections](#variable-types-vs-collections) section. Right now, just know that they are converted to NNX variables like native ones.
+The Flax Linen weight is already converted to a typical Flax NNX variable ([`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable)), which is a thin wrapper around the actual [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) value within. Here, `w` is an [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) because it belongs to the `params` collection of the `LinenDot` `flax.linen.Module`.
+Different collections and types are covered in more detail in the _Flax NNX variable types vs Flax Linen collections_ section. Right now, you just need to know that they are converted to Flax `nnx.Variable`s like native ones.
-```python
+```{code-cell} ipython3
assert isinstance(model.dot.w, nnx.Param)
assert isinstance(model.dot.w.value, jax.Array)
```
-If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not.
+If you create this model without using [`nnx.bridge.lazy_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX.lazy_init), the Flax [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) defined outside will be initialized as usual, but the Flax Linen part (that is wrapped inside of [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX)) will not.
-
-```python
+```{code-cell} ipython3
partial_model = NNXOuter(64, rngs=nnx.Rngs(0))
nnx.display(partial_model)
```
-
-
-
-```python
+```{code-cell} ipython3
full_model = bridge.lazy_init(partial_model, x)
nnx.display(full_model)
```
-
++++
-### NNX -> Linen
-
-To convert an NNX module to Linen, you should forward your creation arguments to `bridge.ToLinen` and let it handle the actual creation process.
+### Basics: Flax NNX to Linen with `nnx.bridge.ToLinen`
-This is because NNX module instance initializes all its variables eagerly when it is created, which consumes memory and compute. On the other hand, Linen modules are stateless, and the typical `init` and `apply` process involves multiple creation of them. So `bridge.to_linen` will handle the actual module creation and make sure no memory is allocated twice.
+To convert a `flax.nnx.Module` to Flax Linen, you should forward your creation arguments to [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) and let it handle the actual creation process.
+This is because:
+- The `flax.nnx.Module` instance initializes all its variables eagerly when it is created, which consumes memory and compute.
+- On the other hand, `flax.linen.Module`s are stateless, and the typical `init` and `apply` process creates them multiple times. Therefore, [`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) will handle the actual `Module` creation and make sure no memory is allocated twice.
-```python
+```{code-cell} ipython3
class NNXDot(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
self.w = nnx.Param(nnx.initializers.lecun_normal()(
@@ -172,7 +187,7 @@ class NNXDot(nnx.Module):
return x @ self.w
x = jax.random.normal(jax.random.key(42), (4, 32))
-# Pass in the arguments, not an actual module
+# Pass in the arguments, not an actual `Module`.
model = bridge.to_linen(NNXDot, 32, out_dim=64)
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x)
@@ -180,45 +195,43 @@ y = model.apply(variables, x)
print(list(variables.keys()))
print(variables['params']['w'].shape) # => (32, 64)
print(y.shape) # => (4, 64)
-
```
['nnx', 'params']
(32, 64)
(4, 64)
++++
-Note that `ToLinen` modules need to track an extra variable collection - `nnx` - for the static metadata of the underlying NNX module.
-
+Note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) `Module`s need to track an extra variable collection - `nnx` - for the static metadata of the underlying `nnx.Module`.
-```python
+```{code-cell} ipython3
# This new field stores the static data that defines the underlying `NNXDot`
print(type(variables['nnx']['graphdef'])) # => `nnx.graph.NodeDef`
```
++++
-`bridge.to_linen` is actually a convenience wrapper around the Linen module `bridge.ToLinen`. Most likely you won't need to use `ToLinen` directly at all, unless you are using one of the built-in arguments of `ToLinen`. For example, if your NNX module doesn't want to be initialized with RNG handling:
-
+[`nnx.bridge.to_linen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.to_linen) is actually a convenience wrapper around the Flax Linen `Module` [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen). Most likely you won't need to use `nnx.bridge.ToLinen` directly at all, unless you are using one of its built-in arguments. For example, if your `nnx.Module` should be initialized without PRNG handling:
-```python
+```{code-cell} ipython3
class NNXAddConstant(nnx.Module):
def __init__(self):
self.constant = nnx.Variable(jnp.array(1))
def __call__(self, x):
return x + self.constant
-# You have to use `skip_rng=True` because this module's `__init__` don't
-# take `rng` as argument
+# You have to use `skip_rng=True` because your module's `__init__` doesn't
+# take `rng` as an argument.
model = bridge.ToLinen(NNXAddConstant, skip_rng=True)
y, var = model.init_with_output(jax.random.key(0), x)
```
-Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module.
+Similar to [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX), you can use [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) to create a sub-`Module` of another `flax.linen.Module`.
-
-```python
+```{code-cell} ipython3
class LinenOuter(nn.Module):
out_dim: int
@nn.compact
@@ -236,55 +249,57 @@ print(w.shape, b.shape, y.shape)
(32, 64) (1, 64) (4, 64)
++++
-## Handling RNG keys
+## Handling the JAX PRNG keys
-All Flax modules, Linen or NNX, automatically handle the RNG keys for variable creation and random layers like dropouts. However, the specific logics of RNG key splitting are different, so you cannot generate the same params between Linen and NNX modules, even if you pass in same keys.
+All Flax `Module`s - in [Linen](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) or [NNX](https://flax.readthedocs.io/en/latest/guides/randomness.html) - can automatically handle the JAX [pseudorandom number generator (PRNG)](https://jax.readthedocs.io/en/latest/random-numbers.html) keys for variable creation and random layers like dropout. However, the specific PRNG key splitting logic differs between the two, so you cannot generate the same params between Linen and NNX `Module`s, even if you pass in the same keys.
Another difference is that NNX modules are stateful, so they can track and update the RNG keys within themselves.
-### Linen to NNX
+> **Note:** To refresh your memory of PRNG key handling, review [JAX PRNG 101](https://jax.readthedocs.io/en/latest/random-numbers.html), [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#jax-prng), [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html), and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html).
-If you convert a Linen module to NNX, you enjoy the stateful benefit and don't need to pass in extra RNG keys on every module call. You can use always `nnx.reseed` to reset the RNG state within.
+### PRNG keys: Flax Linen to NNX - Enjoy the stateful benefits!
+If you convert a Flax Linen `Module` to NNX, you can enjoy the stateful benefits and don't need to pass in extra PRNG keys on every `nnx.Module` call. You can always use [`nnx.reseed`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/rnglib.html#flax.nnx.reseed) to reset the PRNG state within.
-```python
+```{code-cell} ipython3
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))
-# We don't really need to call lazy_init because no extra params were created here,
+# You don't really need to call `lazy_init` because no extra params were created here,
# but it's a good practice to always add this line.
bridge.lazy_init(model, x)
y1, y2 = model(x), model(x)
assert not jnp.allclose(y1, y2) # Two runs yield different outputs!
-# Reset the dropout RNG seed, so that next model run will be the same as the first.
+# Reset the dropout PRNG seed, so that the next model run will be the same as the first.
nnx.reseed(model, dropout=0)
assert jnp.allclose(y1, model(x))
```
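
To see why resetting the seed replays the first output, it can help to picture the PRNG state as a fixed key plus a call counter. The following is a toy model under that assumption, in plain Python without JAX: `ToyRngStream` is a hypothetical class, and hashing stands in for real key derivation.

```python
import hashlib


class ToyRngStream:
    """Toy model of a stateful random stream: a fixed key plus a call counter."""

    def __init__(self, seed: int):
        self.key = seed
        self.count = 0

    def next_value(self) -> float:
        # Derive a value from (key, count), then advance the counter so that
        # the next call yields something different.
        digest = hashlib.sha256(f"{self.key}:{self.count}".encode()).digest()
        self.count += 1
        return int.from_bytes(digest[:8], "big") / 2**64

    def reseed(self, seed: int) -> None:
        # Resetting both the key and the counter replays the stream from the start.
        self.key = seed
        self.count = 0


stream = ToyRngStream(seed=0)
first, second = stream.next_value(), stream.next_value()
stream.reseed(0)
replayed = stream.next_value()
print(first != second, replayed == first)
```

In this picture the caller never passes a key per call; the object advances its own counter, which is the stateful behavior the example above relies on.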
-### NNX to Linen
-
-If you convert an NNX module to Linen, the underlying NNX module's RNG states will still be part of the top-level `variables`. On the other hand, Linen `apply()` call accepts different RNG keys on each call, which resets the internal Linen environment and allow different random data to be generated.
+### PRNG keys: Flax NNX to Linen - Two handling style options
-Now, it really depends on whether your underlying NNX module generates new random data from its RNG state, or from the passed-in argument. Fortunately, `nnx.Dropout` supports both - using passed-in keys if there is any, and use its own RNG state if not.
+If you convert a Flax NNX `Module` to Linen, the underlying `flax.nnx.Module`'s PRNG states will still be part of the top-level variables. On the other hand, the `flax.linen.Module.apply()` call accepts different PRNG keys on each call, which _resets the internal Flax Linen environment and allows different random data to be generated_.
-And this leaves you with two style options of handling the RNG keys:
+Now, it really depends on whether your underlying Flax NNX `Module` generates new random data from its PRNG state, or from the passed-in argument. Fortunately, [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) supports both - using passed-in keys if there are any, and using its own PRNG state if not.
-* The NNX style (recommended): Let the underlying NNX state manage the RNG keys, no need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs `ToLinen`.
+And this leaves you with two style options of handling the PRNG keys:
-* The Linen style: Just pass different RNG keys for every `apply()` call.
+* The Flax NNX style (_recommended_): Let the underlying NNX state manage the PRNG keys, no need to pass in extra keys in `apply()`. This means a few more lines to mutate the `variables` for every apply call, but things will look easier once your whole model no longer needs [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen).
+* The Flax Linen style: Just pass different PRNG keys for every `apply()` call.
+> **Note:** The [Flax NNX Randomness](https://flax.readthedocs.io/en/latest/guides/randomness.html) and [Flax Linen Randomness and PRNGs](https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html) tutorials can help you better understand PRNG handling in Flax.
-```python
+```{code-cell} ipython3
x = jax.random.normal(jax.random.key(42), (4, 32))
model = bridge.to_linen(nnx.Dropout, rate=0.5)
variables = model.init({'dropout': jax.random.key(0)}, x)
-# The NNX RNG state was stored inside `variables`
+# The Flax NNX PRNG state was stored inside `variables`.
print('The RNG key in state:', variables['RngKey']['rngs']['dropout']['key'].value)
print('Number of key splits:', variables['RngCount']['rngs']['dropout']['count'].value)
-# NNX style: Must set `RngCount` as mutable and update the variables after every `apply`
+# Flax NNX style: Must set `RngCount` as mutable and update the variables after every `apply`.
y1, updates = model.apply(variables, x, mutable=['RngCount'])
variables |= updates
y2, updates = model.apply(variables, x, mutable=['RngCount'])
@@ -292,7 +307,7 @@ variables |= updates
print('Number of key splits after y2:', variables['RngCount']['rngs']['dropout']['count'].value)
assert not jnp.allclose(y1, y2) # Every call yields different output!
-# Linen style: Just pass different RNG keys for every `apply()` call.
+# Flax Linen style: Just pass different PRNG keys for every `apply()` call.
y3 = model.apply(variables, x, rngs={'dropout': jax.random.key(1)})
y4 = model.apply(variables, x, rngs={'dropout': jax.random.key(2)})
assert not jnp.allclose(y3, y4) # Every call yields different output!
@@ -305,23 +320,23 @@ assert jnp.allclose(y3, y5) # When you use same top-level RNG, outputs are
Number of key splits: 0
Number of key splits after y2: 2
++++
-## NNX variable types vs. Linen collections
+## Flax NNX variable types vs Flax Linen collections
-When you want to group some variables as one category, in Linen you use different collections; in NNX, since all variables shall be top-level Python attributes, you use different variable types.
+When you want to group certain variables in one category, in Flax Linen you use different collections. In Flax NNX, because all variables shall be top-level Python attributes, you use different variable types.
-Therefore, when mixing Linen and NNX modules, Flax must know the 1-to-1 mapping between Linen collections and NNX variable types, so that `ToNNX` and `ToLinen` can do the conversion automatically.
+Therefore, when mixing Flax Linen and NNX `Module`s, Flax must know the 1-to-1 mapping between Flax Linen collections and Flax NNX variable types, so that [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) can do the conversion automatically.
-Flax keeps a registry for this, and it already covers all Flax's built-in Linen collections. You can register extra mapping of NNX variable type and Linen collection names using `nnx.register_variable_name_type_pair`.
+Flax keeps a registry for this, and it already covers all of Flax's built-in Linen collections. You can register extra mappings between Flax NNX variable types and Flax Linen collection names using [`flax.nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html).
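
The idea of a 1-to-1 registry can be sketched without Flax. The code below is a toy model, not the real implementation: `Variable`, `Param`, `BatchStat`, `register_pair`, `type_for_collection`, and `collection_for_type` are hypothetical names, and the pre-registered pairs are assumptions chosen for illustration.

```python
class Variable:
    """Minimal stand-in for a variable base type (hypothetical, not `nnx.Variable`)."""

    def __init__(self, value):
        self.value = value


class Param(Variable):
    pass


class BatchStat(Variable):
    pass


# Toy registry: collection name -> variable type.
_NAME_TO_TYPE = {"params": Param, "batch_stats": BatchStat}


def register_pair(name, typ, overwrite=False):
    """Register a collection name for a variable type."""
    existing = _NAME_TO_TYPE.get(name)
    if existing is not None and existing is not typ and not overwrite:
        raise ValueError(f"Collection {name!r} is already mapped to {existing.__name__}")
    _NAME_TO_TYPE[name] = typ


def type_for_collection(name):
    """Look up the type for a collection, creating one on the fly if unknown."""
    if name not in _NAME_TO_TYPE:
        # Auto-created type, named after the collection itself.
        _NAME_TO_TYPE[name] = type(name, (Variable,), {})
    return _NAME_TO_TYPE[name]


def collection_for_type(typ):
    """Reverse lookup: which collection does this variable type belong to?"""
    for name, registered in _NAME_TO_TYPE.items():
        if registered is typ:
            return name
    raise KeyError(typ.__name__)


class Count(Variable):
    pass


register_pair("counts", Count, overwrite=True)
counter_type = type_for_collection("counter")  # not registered, so auto-created
print(collection_for_type(Count), counter_type.__name__)
```

The forward lookup serves the collection-to-type direction and the reverse lookup serves the type-to-collection direction, which is why the mapping needs to stay 1-to-1.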
-### Linen to NNX
+### Variables and collections: Flax Linen to NNX
-For any collection of your Linen module, `ToNNX` will convert all its endpoint arrays (aka. leaves) to a subtype of `nnx.Variable`, either from registry or automatically created on-the-fly.
+For any collection of your Linen module, [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) will convert all its endpoint arrays (a.k.a. [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree) [leaves](https://jax.readthedocs.io/en/latest/working-with-pytrees.html#mistaking-pytree-nodes-for-leaves)) to a subtype of [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), either taken from the registry or created automatically on the fly.
-(However, we still keep the whole collection as one class attribute, because Linen modules may have duplicated names over different collections.)
+> **Note:** However, you still keep the entire collection(s) as one class attribute, because `flax.linen.Module`s may have duplicated names over different collections.
-
-```python
+```{code-cell} ipython3
class LinenMultiCollections(nn.Module):
out_dim: int
def setup(self):
@@ -342,7 +357,6 @@ print(model.w) # Of type `nnx.Param` - note this is still under attribute
print(model.b) # Of type `nnx.Param`
print(model.count) # Of type `counter` - auto-created type from the collection name
print(type(model.count))
-
y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger
print(model.dot_sum) # Of type `nnx.Intermediates`
```
@@ -364,20 +378,17 @@ print(model.dot_sum) # Of type `nnx.Intermediates`
value=Array(6.932987, dtype=float32)
),)
++++
-You can quickly separate different types of NNX variables apart using `nnx.split`.
+You can quickly separate different types of Flax NNX variables apart using [`nnx.split`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/graph.html#flax.nnx.split). This can be handy when you only want to set certain variables as trainable.
-This can be handy when you only want to set some variables as trainable.
-
-
-```python
+```{code-cell} ipython3
# Separate variables of different types with nnx.split
CountType = type(model.count)
static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)
print('All Params:', list(params.keys()))
print('All Counters:', list(counter.keys()))
print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))
-
model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time
y = model(x, mutable=True) # still works!
```
@@ -386,13 +397,13 @@ y = model(x, mutable=True) # still works!
All Counters: ['count']
All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']
++++
-### NNX to Linen
+### Variables and collections: Flax NNX to Linen
-If you define custom NNX variable types, you should register their names with `nnx.register_variable_name_type_pair` so that they go to the desired collections.
+If you define custom Flax NNX variable types, you should register their names with [`nnx.register_variable_name_type_pair`](https://flax.readthedocs.io/en/latest/_modules/flax/nnx/bridge/variables.html) so that they go to the desired collections.
-
-```python
+```{code-cell} ipython3
class Count(nnx.Variable): pass
nnx.register_variable_name_type_pair('counts', Count, overwrite=True)
@@ -420,31 +431,36 @@ print(var['params'])
[ 0.6420431 , 0.6220095 , -0.44769976],
[ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}
++++
## Partition metadata
-Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.
+Flax uses a metadata wrapper box over the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) to annotate how a variable should be sharded.
+
+- In Flax Linen, this is an optional feature that is triggered by using [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) on initializers.
+- In Flax NNX, since all Flax NNX variables are wrapped by the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) class anyway, that class will hold the sharding annotations too.
-In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too.
+> **Note:** If you are new to `jax.Array`s and _data sharding_, go to [Key concepts](https://jax.readthedocs.io/en/latest/key-concepts.html#array-devices-and-sharding) and [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation) on the JAX documentation site.
-The `bridge.ToNNX` and `bridge.ToLinen` API will automatically convert the sharding annotations, if you use the built-in annotation methods (aka. `nn.with_partitioning` for Linen and `nnx.with_partitioning` for NNX).
+Both [`nnx.bridge.ToNNX`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToNNX) and [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) will automatically convert the sharding annotations if you use the built-in annotation methods, such as [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) or [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning).
-### Linen to NNX
+> **Note:** To become more familiar with sharding metadata in Flax and JAX, refer to Flax NNX’s [Scale up](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html) guide, JAX’s [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html#sharded-computation), and the Flax Linen [Scale up](https://flax-linen.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html) guide.
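
The notion of a "metadata wrapper box" can be pictured with a small framework-free sketch. It is a toy model only: `ToyBox`, `with_names`, and `unbox` are hypothetical helpers, plain nested lists stand in for arrays, and none of this is the actual Flax implementation.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ToyBox:
    """A value plus the axis names that say how it should be sharded."""
    value: Any
    names: Tuple[str, ...]


def with_names(init_fn, names):
    """Wrap an initializer so that its result is boxed with axis names."""
    def wrapped(*args):
        return ToyBox(init_fn(*args), names)
    return wrapped


def unbox(tree):
    """Strip every box in a nested dict, leaving only the raw values."""
    if isinstance(tree, ToyBox):
        return tree.value
    if isinstance(tree, dict):
        return {key: unbox(sub) for key, sub in tree.items()}
    return tree


def zeros(rows, cols):
    # Plain nested lists stand in for arrays in this sketch.
    return [[0.0] * cols for _ in range(rows)]


variables = {"params": {"w": with_names(zeros, ("in", "out"))(2, 3)}}
raw = unbox(variables)
print(variables["params"]["w"].names, raw["params"]["w"])
```

Converting between the two frameworks then amounts to moving the `names` annotation from one kind of box to the other while leaving the wrapped value untouched.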
-Even if you are not using any partition metadata in your Linen module, the variable JAX arrays will be converted to `nnx.Variable`s that wraps the true JAX array within.
+### Partition metadata: Flax Linen to NNX
-If you use `nn.with_partitioning` to annotate your Linen module's variables, the annotation will be converted to a `.sharding` field in the corresponding `nnx.Variable`.
+Even if you are not using any partition metadata in your Flax Linen `Module`, the variable [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) will be converted to [`nnx.Variable`s](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) that wrap the true `jax.Array` within.
-You can then use `nnx.with_sharding_constraint` to explicitly put the arrays into the annotated partitions within a `jax.jit`-compiled function, to initialize the whole model with every array at the right sharding.
+If you use [`flax.linen.with_partitioning`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning) (`nn.with_partitioning`) to annotate your Flax Linen `Module` variables, the annotation will be converted to the `.sharding` field in the corresponding [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable).
+You can then use [`nnx.with_sharding_constraint`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_sharding_constraint) to explicitly put the arrays into the annotated partitions within a [`jax.jit`](https://jax.readthedocs.io/en/latest/jit-compilation.html)-compiled function, to initialize the whole model with every array at the right sharding.
-```python
+```{code-cell} ipython3
class LinenDotWithPartitioning(nn.Module):
out_dim: int
@nn.compact
def __call__(self, x):
w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(),
- ('in', 'out')),
+ ('in', 'out')),
(x.shape[-1], self.out_dim))
return x @ w
@@ -475,17 +491,17 @@ print(model.w.value.sharding) # The underlying JAX array is sharded across the
('in', 'out')
GSPMDSharding({devices=[2,4]<=[8]})
++++
-### NNX to Linen
+### Partition metadata: Flax NNX to Linen
-If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it.
+If you are not using any metadata features of the [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable), that is, no sharding annotation and no registered hooks, then the converted `flax.linen.Module` will not add a metadata wrapper to your Flax NNX variable, and you won't need to worry about it. (Recall that all Flax NNX variables are wrapped in an `nnx.Variable` box.)
-But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable.
+But if you did add sharding annotations to your Flax NNX variables, then [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) will convert them to a default Flax Linen partition metadata class called [`flax.nnx.bridge.NNXMeta`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.NNXMeta), retaining all the metadata you put into the NNX variable.
-Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree.
+Similar to any Flax Linen metadata wrappers, you can use `flax.linen.unbox()` ([`flax.linen.meta.unbox`](https://github.com/google/flax/blob/5d31452889b8d106d7c722b5eaac14cb9784fec2/flax/core/meta.py#L160)) to get the raw [`jax.Array`](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) [pytree](https://jax.readthedocs.io/en/latest/glossary.html#term-pytree).
-
-```python
+```{code-cell} ipython3
class NNXDotWithParititioning(nnx.Module):
def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
init_fn = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out'))
@@ -523,19 +539,19 @@ print(variables['params']['w'].sharding)
GSPMDSharding({devices=[2,4]<=[8]})
++++
-## Lifted transforms
-
-In general, if you want to apply Linen/NNX-style lifted transforms upon an `nnx.bridge`-converted module, just go ahead and do it in the usual Linen/NNX syntax.
+## Lifted transformations - go ahead and do it
-For Linen-style transforms, note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases)
+In general, if you want to apply [Flax Linen-](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html) or [Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) on an [`nnx.bridge`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html)-converted `Module`, just go ahead and do it in the usual Flax Linen or NNX syntax.
-### Linen to NNX
+For [Flax Linen style transforms](https://flax-linen.readthedocs.io/en/latest/developer_notes/lift.html), note that [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `flax.linen.Module` class in most cases).
-NNX style lifted transforms are similar to JAX transforms, and they work on functions.
+### Lifted transforms: Flax Linen to NNX
+[Flax NNX style lifted transforms](https://flax.readthedocs.io/en/latest/guides/transforms.html) are similar to [JAX transforms](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations), and they too work on functions.
-```python
+```{code-cell} ipython3
class NNXVmapped(nnx.Module):
def __init__(self, out_dim: int, vmap_axis_size: int, rngs: nnx.Rngs):
self.linen_dot = nnx.bridge.ToNNX(nn.Dense(out_dim, use_bias=False), rngs=rngs)
@@ -553,7 +569,7 @@ class NNXVmapped(nnx.Module):
x = jax.random.normal(jax.random.key(0), (4, 32))
model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)
-print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped
+print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got `vmap`ped.
y = model(x)
print(y.shape)
```
@@ -561,15 +577,13 @@ print(y.shape)
(4, 32, 64)
(4, 64)
++++
-### NNX to Linen
+### Lifted transforms: Flax NNX to Linen
-Note that `bridge.ToLinen` is the top level module class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases).
+As mentioned before, [`nnx.bridge.ToLinen`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/bridge.html#flax.nnx.bridge.ToLinen) is the top-level `Module` class, so you may want to just use it as the first argument of your transforms (which needs to be a `linen.Module` class in most cases). And, since `nnx.bridge.ToLinen` introduces this extra `nnx` collection, you need to mark it when using the axis-changing transforms ([`flax.linen.vmap`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.vmap), [`flax.linen.scan`](https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan), and so on) to make sure they are passed inside.
-Also, since `bridge.ToLinen` introduced this extra `nnx` collection, you need to mark it when using the axis-changing transforms (`linen.vmap`, `linen.scan`, etc) to make sure they are passed inside.
-
-
-```python
+```{code-cell} ipython3
class LinenVmapped(nn.Module):
dout: int
@nn.compact
@@ -581,11 +595,10 @@ class LinenVmapped(nn.Module):
x = jax.random.normal(jax.random.key(42), (4, 32))
model = LinenVmapped(64)
var = model.init(jax.random.key(0), x)
-print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped
+print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got `vmap`ped.
y = model.apply(var, x)
print(y.shape)
```
(4, 32, 64)
(4, 64)
-