From 3da74f0dcfe52504cbe5d2dd93edc9e359eb74e9 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 18 Jul 2024 13:11:24 +0100 Subject: [PATCH] [nnx] add call section to nnx_basics --- docs/nnx/nnx_basics.ipynb | 132 ++++++++++++++++++++++++++------------ docs/nnx/nnx_basics.md | 38 ++++++++--- 2 files changed, 121 insertions(+), 49 deletions(-) diff --git a/docs/nnx/nnx_basics.ipynb b/docs/nnx/nnx_basics.ipynb index 1df062735c..d44ca0c918 100644 --- a/docs/nnx/nnx_basics.ipynb +++ b/docs/nnx/nnx_basics.ipynb @@ -19,21 +19,29 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 16, "metadata": { "tags": [ "skip-execution" ] }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/cgarciae/.pyenv/versions/3.10.13/lib/python3.10/pty.py:89: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n", + " pid, fd = os.forkpty()\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Requirement already satisfied: flax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.8.5)\n", - "Requirement already satisfied: penzai in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.1.3)\n", + "Requirement already satisfied: flax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.8.6)\n", + "Requirement already satisfied: penzai in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (0.1.5)\n", "Requirement already satisfied: numpy>=1.22 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.26.4)\n", - "Requirement already satisfied: jax>=0.4.27 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.4.31.dev20240621+0428a1509)\n", + "Requirement already satisfied: jax>=0.4.27 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.4.31.dev20240712+5cce39442)\n", "Requirement already satisfied: msgpack in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (1.0.8)\n", "Requirement already satisfied: optax in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.2.2)\n", "Requirement already satisfied: orbax-checkpoint in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from flax) (0.5.20)\n", @@ -46,10 +54,10 @@ "Requirement already satisfied: ordered_set>=4.1.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from penzai) (4.1.0)\n", "Requirement already satisfied: jaxtyping>=0.2.20 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from equinox>=0.11.3->penzai) (0.2.31)\n", "Requirement already satisfied: jaxlib<=0.4.31,>=0.4.30 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.30)\n", - "Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.4.0)\n", + "Requirement already satisfied: ml-dtypes>=0.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (0.3.2)\n", "Requirement already satisfied: opt-einsum in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (3.3.0)\n", "Requirement already satisfied: scipy>=1.9 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jax>=0.4.27->flax) (1.14.0)\n", - "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.2.0)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from rich>=11.1->flax) (2.18.0)\n", "Requirement already satisfied: chex>=0.1.86 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from optax->flax) (0.1.86)\n", "Requirement already satisfied: etils[epath,epy] in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from orbax-checkpoint->flax) (1.7.0)\n", @@ -58,22 +66,22 @@ "Requirement already satisfied: toolz>=0.9.0 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from chex>=0.1.86->optax->flax) (0.12.1)\n", "Requirement already satisfied: typeguard==2.13.3 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from jaxtyping>=0.2.20->equinox>=0.11.3->penzai) (2.13.3)\n", "Requirement already satisfied: mdurl~=0.1 in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax) (0.1.2)\n", - "Requirement already satisfied: fsspec in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (2024.6.0)\n", + "Requirement already satisfied: fsspec in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (2024.6.1)\n", "Requirement already satisfied: importlib_resources in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (6.4.0)\n", "Requirement already satisfied: zipp in /Users/cgarciae/repos/flax/.venv/lib/python3.10/site-packages (from etils[epath,epy]->orbax-checkpoint->flax) (3.19.2)\n", "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.0\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.1.2\u001b[0m\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], "source": [ - "! pip install -U flax penzai" + "# ! pip install -U flax penzai" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -103,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -136,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -149,7 +157,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -187,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -237,13 +245,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -296,13 +304,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -360,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 23, "metadata": {}, "outputs": [ { @@ -424,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -437,7 +445,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -495,13 +503,20 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 31, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count.value = Array(1, dtype=uint32)\n" + ] + }, { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -527,6 +542,7 @@ "model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n", "y = model(jnp.ones((1, 3)))\n", "\n", + "print(f'{model.count.value = }')\n", "nnx.display(model)" ] }, @@ -544,13 +560,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -562,7 +578,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -593,21 +609,18 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "model.count.value = Array(1, dtype=uint32)\n", "model.count.value = Array(2, dtype=uint32)\n" ] } ], "source": [ - "print(f'{model.count.value = }')\n", - "\n", "# 1. Use split to create a pytree representation of the Module\n", "graphdef, state = nnx.split(model)\n", "\n", @@ -624,7 +637,6 @@ "y, state = forward(graphdef, state, x=jnp.ones((1, 3)))\n", "# 5. Update the state of the original Module\n", "nnx.update(model, state)\n", - "\n", "print(f'{model.count.value = }')" ] }, @@ -636,11 +648,51 @@ "fine within a transform context (including the base eager interpreter)\n", "but its necessary to use the Functional API when crossing boundaries.\n", "\n", - "**Why aren't Module's just Pytrees?** The main reason is that it is very\n", - "easy to lose track of shared references by accident this way, for example\n", - "if you pass two Module that have a shared Module through a JAX boundary\n", - "you will silently lose that sharing. The Functional API makes this\n", - "behavior explicit, and thus it is much easier to reason about." + "#### Using call for pure computation\n", + "\n", + "To simplify functional workflows, the `call` function can be used to run\n", + "a computation on a `(GraphDef, State)` as returned by `split`. `call` will\n", + "internally `merge` the state, call the desired method, and then `split` to\n", + "return a new `(GraphDef, State)` pair along with the result of the computation." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "model.count.value = Array(3, dtype=uint32)\n" + ] + } + ], + "source": [ + "model_state = nnx.split(model)\n", + "\n", + "@jax.jit\n", + "def forward(model_state: tuple, x: jax.Array):\n", + " # same as merge + call + split\n", + " y, model_state = nnx.call(model_state)(x)\n", + " return y, model_state\n", + "\n", + "y, model_state = forward(model_state, x=jnp.ones((1, 3)))\n", + "\n", + "model = nnx.merge(*model_state) # or nnx.update(model, ...)\n", + "print(f'{model.count.value = }')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Why aren't Module's just Pytrees?\n", + "The main reason is that it is very easy to lose track of shared references\n", + "by accident this way, for example if you pass two Module that have a shared\n", + "Module through a JAX boundary you will silently lose that sharing. The Functional\n", + "API makes this behavior explicit, and thus it is much easier to reason about." ] }, { @@ -664,13 +716,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -682,7 +734,7 @@ { "data": { "text/html": [ - "
(Loading...)
" + "
" ], "text/plain": [ "" @@ -710,7 +762,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ diff --git a/docs/nnx/nnx_basics.md b/docs/nnx/nnx_basics.md index ca838042a4..3274188182 100644 --- a/docs/nnx/nnx_basics.md +++ b/docs/nnx/nnx_basics.md @@ -23,7 +23,7 @@ that have allowed Linen to scale effectively to large codebases. ```{code-cell} ipython3 :tags: [skip-execution] -! pip install -U flax penzai +# ! pip install -U flax penzai ``` ```{code-cell} ipython3 @@ -288,6 +288,7 @@ class StatefulLinear(nnx.Module): model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0)) y = model(jnp.ones((1, 3))) +print(f'{model.count.value = }') nnx.display(model) ``` @@ -313,8 +314,6 @@ update an object inplace with the content of a given State. This pattern is used propagate the state from a transform back to the source object outside. ```{code-cell} ipython3 -print(f'{model.count.value = }') - # 1. Use split to create a pytree representation of the Module graphdef, state = nnx.split(model) @@ -331,7 +330,6 @@ def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax y, state = forward(graphdef, state, x=jnp.ones((1, 3))) # 5. Update the state of the original Module nnx.update(model, state) - print(f'{model.count.value = }') ``` @@ -339,11 +337,33 @@ The key insight of this pattern is that using mutable references is fine within a transform context (including the base eager interpreter) but its necessary to use the Functional API when crossing boundaries. -**Why aren't Module's just Pytrees?** The main reason is that it is very -easy to lose track of shared references by accident this way, for example -if you pass two Module that have a shared Module through a JAX boundary -you will silently lose that sharing. The Functional API makes this -behavior explicit, and thus it is much easier to reason about. +#### Using call for pure computation + +To simplify functional workflows, the `call` function can be used to run +a computation on a `(GraphDef, State)` as returned by `split`. `call` will +internally `merge` the state, call the desired method, and then `split` to +return a new `(GraphDef, State)` pair along with the result of the computation. + +```{code-cell} ipython3 +model_state = nnx.split(model) + +@jax.jit +def forward(model_state: tuple, x: jax.Array): + # same as merge + call + split + y, model_state = nnx.call(model_state)(x) + return y, model_state + +y, model_state = forward(model_state, x=jnp.ones((1, 3))) + +model = nnx.merge(*model_state) # or nnx.update(model, ...) +print(f'{model.count.value = }') +``` + +#### Why aren't Module's just Pytrees? +The main reason is that it is very easy to lose track of shared references +by accident this way, for example if you pass two Module that have a shared +Module through a JAX boundary you will silently lose that sharing. The Functional +API makes this behavior explicit, and thus it is much easier to reason about. +++