From e23e6f45c7fba67ec43d3e1923f2be2393701d93 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Wed, 17 Jul 2024 12:41:28 +0530 Subject: [PATCH] feat: proper transformer blocks in attention --- Diffusion flax linen.ipynb | 344 ++++++++++++++++++++++++------------- 1 file changed, 226 insertions(+), 118 deletions(-) diff --git a/Diffusion flax linen.ipynb b/Diffusion flax linen.ipynb index efee46d..42fb534 100644 --- a/Diffusion flax linen.ipynb +++ b/Diffusion flax linen.ipynb @@ -57,14 +57,15 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "There was a problem when trying to write in your cache folder (/home/mrwhite0racle/.cache/huggingface/hub). You should set the environment variable TRANSFORMERS_CACHE to a writable directory.\n" + "The dotenv extension is already loaded. To reload it, use:\n", + " %reload_ext dotenv\n" ] } ], @@ -111,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -144,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -195,7 +196,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -225,7 +226,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -354,7 +355,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -422,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -460,7 +461,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -502,7 +503,7 @@ " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_v\")\n", " self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, \n", " kernel_init=self.kernel_init(), dtype=self.dtype, name=\"to_out_0\")\n", - " self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)\n", + " # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)\n", "\n", " @nn.compact\n", " def __call__(self, x:jax.Array, context=None):\n", @@ -516,17 +517,13 @@ " # query.astype(jnp.float16), key.astype(jnp.float16), value.astype(jnp.float16),\n", " # )\n", " \n", - " # hidden_states = jax.experimental.pallas.ops.attention.mha_reference(\n", - " # query, key, value, None\n", - " # )\n", + " hidden_states = jax.experimental.pallas.ops.attention.mha_reference(\n", + " query, key, value, None\n", + " )\n", " \n", " # hidden_states = self.attnfn(\n", " # query, key, value, None\n", " # )\n", - "\n", - " hidden_states = nn.dot_product_attention(\n", - " query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision\n", - " )\n", " \n", " proj = self.proj_attn(hidden_states)\n", " return proj\n", @@ -583,22 +580,100 @@ " )\n", " proj = self.proj_attn(hidden_states)\n", " return proj\n", - "\n", + " \n", "class AttentionBlock(nn.Module):\n", + " # Has self and cross attention\n", + " query_dim: int\n", + " heads: int = 4\n", + " dim_head: int = 64\n", + " dtype: Any = jnp.float32\n", + " precision: Any = jax.lax.Precision.HIGHEST\n", + " use_bias: bool = True\n", + " 
kernel_init: Callable = lambda : kernel_init(1.0)\n", + " use_flash_attention:bool = False\n", + " use_cross_only:bool = False\n", + " \n", + " def setup(self):\n", + " if self.use_flash_attention:\n", + " attenBlock = FlashAttention\n", + " else:\n", + " attenBlock = NormalAttention\n", + " \n", + " self.attention1 = attenBlock(\n", + " query_dim=self.query_dim,\n", + " heads=self.heads,\n", + " dim_head=self.dim_head,\n", + " name=f'Attention1',\n", + " precision=self.precision,\n", + " use_bias=self.use_bias,\n", + " dtype=self.dtype,\n", + " kernel_init=self.kernel_init\n", + " )\n", + " self.attention2 = attenBlock(\n", + " query_dim=self.query_dim,\n", + " heads=self.heads,\n", + " dim_head=self.dim_head,\n", + " name=f'Attention2',\n", + " precision=self.precision,\n", + " use_bias=self.use_bias,\n", + " dtype=self.dtype,\n", + " kernel_init=self.kernel_init\n", + " )\n", + " \n", + " self.ff = nn.DenseGeneral(\n", + " features=self.query_dim,\n", + " use_bias=self.use_bias,\n", + " precision=self.precision,\n", + " dtype=self.dtype,\n", + " kernel_init=self.kernel_init(),\n", + " name=\"ff\"\n", + " )\n", + " self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", + " self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", + " self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", + " self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)\n", + " \n", + " @nn.compact\n", + " def __call__(self, hidden_states, context=None):\n", + " # self attention\n", + " residual = hidden_states\n", + " hidden_states = self.norm1(hidden_states)\n", + " if self.use_cross_only:\n", + " hidden_states = self.attention1(hidden_states, context)\n", + " else:\n", + " hidden_states = self.attention1(hidden_states)\n", + " hidden_states = hidden_states + residual\n", + "\n", + " # cross attention\n", + " residual = hidden_states\n", + " hidden_states = self.norm2(hidden_states)\n", + " hidden_states = self.attention2(hidden_states, context)\n", + " hidden_states = hidden_states + residual\n", + "\n", + " # feed forward\n", + " residual = hidden_states\n", + " hidden_states = self.norm3(hidden_states)\n", + " hidden_states = nn.gelu(hidden_states)\n", + " hidden_states = self.ff(hidden_states)\n", + " hidden_states = hidden_states + residual\n", + " \n", + " return hidden_states\n", + "\n", + "class TransformerBlock(nn.Module):\n", " heads: int = 4\n", " dim_head: int = 32\n", " use_linear_attention: bool = True\n", " dtype: Any = jnp.float32\n", " precision: Any = jax.lax.Precision.HIGH\n", " use_projection: bool = False\n", - " use_flash_attention:bool = False\n", + " use_flash_attention:bool = True\n", + " use_self_and_cross:bool = False\n", "\n", " @nn.compact\n", " def __call__(self, x, context=None):\n", " inner_dim = self.heads * self.dim_head\n", " B, H, W, C = x.shape\n", " normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)\n", - " # normed_x = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)(x)\n", " if self.use_projection == True:\n", " if self.use_linear_attention:\n", " projected_x = nn.Dense(features=inner_dim, \n", @@ -618,7 +693,19 @@ " \n", " context = projected_x if context is None else context\n", "\n", - " if self.use_flash_attention == True:\n", + " if self.use_self_and_cross:\n", + " projected_x = AttentionBlock(\n", + " query_dim=inner_dim,\n", + " heads=self.heads,\n", + " dim_head=self.dim_head,\n", + " name=f'Attention',\n", + " precision=self.precision,\n", + " use_bias=False,\n", + " dtype=self.dtype,\n", + " use_flash_attention=self.use_flash_attention,\n", + " 
use_cross_only=False\n", + " )(projected_x, context)\n", + " elif self.use_flash_attention == True:\n", " projected_x = FlashAttention(\n", " query_dim=inner_dim,\n", " heads=self.heads,\n", @@ -637,16 +724,7 @@ " precision=self.precision,\n", " use_bias=False,\n", " )(projected_x, context)\n", - "\n", - " # projected_x = nn.MultiHeadAttention(num_heads=self.heads, use_bias=False, precision='high', decode=False)(projected_x)\n", - "\n", - " # projected_x = BasicTransformerBlock(\n", - " # query_dim=inner_dim,\n", - " # heads=self.heads,\n", - " # dim_head=self.dim_head,\n", - " # name=f'Attention',\n", - " # precision=self.precision,\n", - " # )(projected_x, projected_x)\n", + " \n", "\n", " if self.use_projection == True:\n", " if self.use_linear_attention:\n", @@ -675,17 +753,32 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 33, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "14.1 ms ± 178 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" + ] + } + ], "source": [ "x = jnp.ones((16, 16, 16, 64))\n", "# context = jnp.ones((8, 16, 16, 64))\n", - "attention_block = AttentionBlock(heads=4, dim_head=64//4, dtype=jnp.float16, use_flash_attention=True)\n", + "attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.float16, use_flash_attention=True, use_projection=True, use_self_and_cross=True)\n", "params = attention_block.init(jax.random.PRNGKey(0), x)\n", - "# %timeit attention_block.apply(params, x)" + "%timeit attention_block.apply(params, x)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -698,7 +791,7 @@ "context = jnp.pad(context, ((0, 0), (0, 4), (0, 0)), mode='constant', constant_values=0)\n", "print(context.shape)\n", "context = jnp.reshape(context, (1, 1, 16, 768))\n", - "attention_block = AttentionBlock(heads=4, dim_head=64//4, dtype=jnp.float16, use_flash_attention=True)\n", + "attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.float16, use_flash_attention=True, use_projection=True, use_self_and_cross=True)\n", "params = attention_block.init(jax.random.PRNGKey(0), x, context)\n", "out = attention_block.apply(params, x, context)\n", "print(\"Output :\", out.shape)\n", @@ -717,7 +810,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ @@ -796,12 +889,6 @@ " temb = FourierEmbedding(features=self.emb_features)(temb)\n", " temb = TimeProjection(features=self.emb_features)(temb)\n", " \n", - " # textemb = textcontext\n", - " # textemb = nn.DenseGeneral(features=self.emb_features, name=\"textemb_projection\")(textemb)\n", - " # textemb = nn.gelu(textemb)\n", - " \n", - " # B, S = textemb.shape\n", - " # textcontext = textemb.reshape((B, 1, S))\n", " _, TS, TC = textcontext.shape\n", " \n", " # print(\"time embedding\", temb.shape)\n", @@ -841,9 +928,11 @@ " padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))\n", " else:\n", " padded_context = None\n", - " x = AttentionBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),\n", + " x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),\n", " dim_head=dim_in // attention_config['heads'],\n", " use_flash_attention=attention_config.get(\"flash_attention\", True),\n", + 
" use_projection=attention_config.get(\"use_projection\", False),\n", + " use_self_and_cross=attention_config.get(\"use_self_and_cross\", True),\n", " name=f\"down_{i}_attention_{j}\")(x, padded_context)\n", " # print(\"down residual for feature level\", i, \"is of shape\", x.shape, \"features\", dim_in)\n", " downs.append(x)\n", @@ -859,15 +948,6 @@ " # Middle Blocks\n", " middle_dim_out = self.feature_depths[-1]\n", " middle_attention = self.attention_configs[-1]\n", - " # x = nn.GroupNorm(8)(x)\n", - " # x = ConvLayer(\n", - " # conv_type,\n", - " # features=middle_dim_out,\n", - " # kernel_size=(3, 3),\n", - " # strides=(1, 1),\n", - " # kernel_init=kernel_init(1.0),\n", - " # name=\"middle_conv\"\n", - " # )(x)\n", " for j in range(self.num_middle_res_blocks):\n", " x = ResidualBlock(\n", " middle_conv_type,\n", @@ -880,11 +960,13 @@ " norm_groups=self.norm_groups\n", " )(x, temb)\n", " if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block\n", - " x = AttentionBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32), \n", - " dim_head=middle_dim_out // middle_attention['heads'],\n", + " x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32), \n", + " dim_head=middle_dim_out // middle_attention['heads'],\n", " use_flash_attention=middle_attention.get(\"flash_attention\", True),\n", - " use_linear_attention=False,\n", - " name=f\"middle_attention_{j}\")(x)\n", + " use_linear_attention=False,\n", + " use_projection=middle_attention.get(\"use_projection\", False),\n", + " use_self_and_cross=False,\n", + " name=f\"middle_attention_{j}\")(x)\n", " x = ResidualBlock(\n", " middle_conv_type,\n", " name=f\"middle_res2_{j}\",\n", @@ -920,9 +1002,11 @@ " padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))\n", " else:\n", " padded_context = None\n", - " x = AttentionBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), \n", + " x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), \n", " dim_head=dim_out // attention_config['heads'],\n", " use_flash_attention=attention_config.get(\"flash_attention\", True),\n", + " use_projection=attention_config.get(\"use_projection\", False),\n", + " use_self_and_cross=attention_config.get(\"use_self_and_cross\", True),\n", " name=f\"up_{i}_attention_{j}\")(x, padded_context)\n", " # print(\"Upscaling \", i, x.shape)\n", " if i != len(feature_depths) - 1:\n", @@ -978,7 +1062,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -992,7 +1076,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -1238,7 +1322,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1252,28 +1336,26 @@ "name": "stdout", "output_type": "stream", "text": [ - "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-07-16_02:16:07\n", + "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-07-17_12:36:20\n", "Gpu Device: cuda:0, Cpu Device: TFRT_CPU_0\n", - "Loading labels from cache\n", - "Loading model from checkpoint 1368\n", - "Loaded model from checkpoint at step 1368 0.06490325\n" + "Loading labels from cache\n" ] } ], "source": [ - "# experiment_name = \"{name}_{date}\".format(\n", - "# 
name=\"Diffusion_SDE_VE_TEXT\", date=datetime.now().strftime(\"%Y-%m-%d_%H:%M:%S\")\n", - "# )\n", - "experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-16_02:16:07'\n", + "experiment_name = \"{name}_{date}\".format(\n", + " name=\"Diffusion_SDE_VE_TEXT\", date=datetime.now().strftime(\"%Y-%m-%d_%H:%M:%S\")\n", + ")\n", + "# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-16_02:16:07'\n", "print(\"Experiment_Name:\", experiment_name)\n", "unet = Unet(emb_features=256, \n", " feature_depths=[64, 64, 128, 256, 512],\n", " attention_configs=[\n", " None, #{\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":False}\n", + " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n", + " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n", + " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":True, \"use_projection\":True, \"use_self_and_cross\":True}, \n", + " {\"heads\":8, \"dtype\":jnp.float16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}\n", " ],\n", " num_res_blocks=2,\n", " num_middle_res_blocks=1\n", @@ -1305,8 +1387,8 @@ "\n", "# solver = optax.adamw(learning_rate=learning_rate_schedule)\n", "# solver = optax.radam(2e-4)\n", - "# solver = optax.adam(2e-4)\n", - "solver = optax.adamw(1e-5)\n", + "solver = optax.adam(2e-4)\n", + "# solver = optax.adamw(1e-5)\n", "\n", "# solver = optax.lookahead(solver, sync_period=6, slow_step_size=0.5)\n", "# params_transform = lambda x: optax.LookaheadParams.init_synced(x)\n", @@ -1320,14 +1402,23 @@ " # train_state=trainer.best_state,\n", " # loss_fn=lambda x, y: jnp.abs(x - y),\n", " # param_transforms=params_transform,\n", - " load_from_checkpoint=True,\n", + " # load_from_checkpoint=True,\n", " )\n", "#trainer.summary()" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.summary()" + ] + }, + { + "cell_type": "code", + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -1337,14 +1428,69 @@ "Gpu Device: cuda:0, Cpu Device: TFRT_CPU_0\n", "Loading labels from cache\n", "\n", - "Epoch 1369/2000\n" + "Epoch 1/2000\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 1: 600step [01:34, 6.37step/s, loss=0.1298] \n", + "WARNING:absl:SaveArgs.aggregate is deprecated, please use custom TypeHandler (https://orbax.readthedocs.io/en/latest/custom_handlers.html#typehandler) or contact Orbax team to migrate before August 1st, 2024.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving model at epoch 1\n", + "\n", + "\tEpoch 1 completed. Avg Loss: 0.269815593957901, Time: 94.22s, Best Loss: 0.269815593957901\n", + "\n", + "Epoch 2/2000\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 2: 600step [00:51, 11.73step/s, loss=0.1322] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving model at epoch 2\n", + "\n", + "\tEpoch 2 completed. 
Avg Loss: 0.12731926143169403, Time: 51.17s, Best Loss: 0.12731926143169403\n", + "\n", + "Epoch 3/2000\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\t\tEpoch 1369: 0%| | 0/511 [00:00
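The new AttentionBlock introduced in this patch follows a pre-norm transformer layout: RMSNorm then self-attention with a residual, RMSNorm then cross-attention against the (padded) text context with a residual, and RMSNorm, GELU and a dense projection with a final residual. TransformerBlock wraps it behind the use_self_and_cross flag and handles the optional linear projection into and out of the attention width. Below is a minimal standalone sketch of that layout, not code from the patch: it assumes flax.linen.MultiHeadDotProductAttention as a stand-in for the notebook's NormalAttention / FlashAttention modules, and the shapes and hyperparameters in the usage lines are illustrative only.

from typing import Any

import jax
import jax.numpy as jnp
import flax.linen as nn


class PreNormSelfCrossBlock(nn.Module):
    # Illustrative module, not part of the patch: same sub-layer order as the
    # patch's AttentionBlock (self-attention, cross-attention, feed-forward),
    # but with flax.linen.MultiHeadDotProductAttention standing in for the
    # notebook's NormalAttention / FlashAttention implementations.
    features: int
    heads: int = 4
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x, context=None):
        context = x if context is None else context

        # self-attention: pre-norm, attention, residual add
        residual = x
        h = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.heads, dtype=self.dtype, name="self_attn")(h, h)
        x = h + residual

        # cross-attention against the context sequence
        residual = x
        h = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
        h = nn.MultiHeadDotProductAttention(
            num_heads=self.heads, dtype=self.dtype, name="cross_attn")(h, context)
        x = h + residual

        # feed-forward: pre-norm, GELU, dense projection, residual add
        # (the patch applies GELU before the dense layer, mirrored here)
        residual = x
        h = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
        h = nn.Dense(self.features, dtype=self.dtype, name="ff")(nn.gelu(h))
        return h + residual


# Usage with assumed shapes: a (batch, tokens, features) activation and an
# optional context sequence of the same feature width.
x = jnp.ones((2, 64, 128))
ctx = jnp.ones((2, 16, 128))
block = PreNormSelfCrossBlock(features=128, heads=4)
params = block.init(jax.random.PRNGKey(0), x, ctx)
y = block.apply(params, x, ctx)
print(y.shape)  # (2, 64, 128)

As in the patch, normalisation is applied before each sub-layer rather than after, so every residual path carries the unnormalised activations straight through the block.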