Skip to content

Commit

Permalink
fix: fixed nbviewer link in readme
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Jul 24, 2024
1 parent 0e1a2d7 commit a8bc316
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 67 deletions.
12 changes: 4 additions & 8 deletions Diffusion flax linen.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -488,9 +488,9 @@
" scale = max(scale, 1e-10)\n",
" return nn.initializers.variance_scaling(scale=scale, mode=\"fan_avg\", distribution=\"truncated_normal\", dtype=dtype)\n",
"\n",
"class FlashAttention(nn.Module):\n",
"class EfficientAttention(nn.Module):\n",
" \"\"\"\n",
" Based on the flash attention implementation.\n",
" Based on the pallas attention implementation.\n",
" \"\"\"\n",
" query_dim: int\n",
" heads: int = 4\n",
Expand Down Expand Up @@ -521,10 +521,6 @@
" key = self.key(context)\n",
" value = self.value(context)\n",
" \n",
" # hidden_states = flash_mha(\n",
" # query.astype(jnp.float16), key.astype(jnp.float16), value.astype(jnp.float16),\n",
" # )\n",
" \n",
" hidden_states = jax.experimental.pallas.ops.attention.mha_reference(\n",
" query, key, value, None\n",
" )\n",
Expand Down Expand Up @@ -603,7 +599,7 @@
" \n",
" def setup(self):\n",
" if self.use_flash_attention:\n",
" attenBlock = FlashAttention\n",
" attenBlock = EfficientAttention\n",
" else:\n",
" attenBlock = NormalAttention\n",
" \n",
Expand Down Expand Up @@ -714,7 +710,7 @@
" use_cross_only=False\n",
" )(projected_x, context)\n",
" elif self.use_flash_attention == True:\n",
" projected_x = FlashAttention(\n",
" projected_x = EfficientAttention(\n",
" query_dim=inner_dim,\n",
" heads=self.heads,\n",
" dim_head=self.dim_head,\n",
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ In the `example notebooks` folder, you will find comprehensive notebooks for var

### Available Notebooks

- **[Diffusion explained (nbviewer link)](https://github.com/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/edm%20tutorial.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**
- **[Diffusion explained (nbviewer link)](https://nbviewer.org/github/AshishKumar4/FlaxDiff/blob/main/tutorial%20notebooks/simple%20diffusion%20flax.ipynb) [(local link)](tutorial%20notebooks/simple%20diffusion%20flax.ipynb)**

- **WORK IN PROGRESS** An in-depth exploration of the concept of Diffusion based generative models, DDPM (Denoising Diffusion Probabilistic Models), DDIM (Denoising Diffusion Implicit Models), and the SDE/ODE generalizations of diffusion, with step-by-step explainations and code.

Expand Down
230 changes: 172 additions & 58 deletions tutorial notebooks/simple diffusion flax.ipynb

Large diffs are not rendered by default.

0 comments on commit a8bc316

Please sign in to comment.