diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index e364045..95e9117 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -25,7 +25,7 @@ jobs: python-version: "3.11" - name: Install dependencies run: | - python -m pip install --upgrade pip + python -m pip install --upgrade pip build pip install flake8 pytest if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 diff --git a/.gitignore b/.gitignore index 7751bc1..e18a1c6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +dist +*.egg-info +*.egg __pycache__ *.pyc .ipynb_checkpoints @@ -11,4 +14,4 @@ wandb gcs_mount datacache *.deb -gcsfuse.yml \ No newline at end of file +gcsfuse.yml diff --git a/Diffusion flax linen on TPUs.ipynb b/Diffusion flax linen on TPUs.ipynb index 98026db..a13bdfd 100644 --- a/Diffusion flax linen on TPUs.ipynb +++ b/Diffusion flax linen on TPUs.ipynb @@ -1898,20 +1898,20 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 104, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_14:38:34\n" + "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34\n" ] }, { "data": { "text/html": [ - "Finishing last run (ID:x90ddjrq) before initializing another..." + "Finishing last run (ID:6rmw3wuq) before initializing another..." ], "text/plain": [ "" @@ -1928,7 +1928,7 @@ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " \n", - "

Run history:


train/loss█▄▄▃▂▁▁▁▁▁

Run summary:


train/loss0.07343

" + "

Run history:


train/loss█▄▄▄▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁

Run summary:


train/loss0.07254

" ], "text/plain": [ "" @@ -1940,7 +1940,7 @@ { "data": { "text/html": [ - " View run Diffusion_SDE_VE_TEXT_2024-08-01_14:30:55 at: https://wandb.ai/ashishkumar4/flaxdiff/runs/x90ddjrq
View project at: https://wandb.ai/ashishkumar4/flaxdiff
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + " View run Diffusion_SDE_VE_TEXT_2024-08-01_15:17:19 at: https://wandb.ai/ashishkumar4/flaxdiff/runs/6rmw3wuq
View project at: https://wandb.ai/ashishkumar4/flaxdiff
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" @@ -1952,7 +1952,7 @@ { "data": { "text/html": [ - "Find logs at: ./wandb/run-20240801_143055-x90ddjrq/logs" + "Find logs at: ./wandb/run-20240801_151719-6rmw3wuq/logs" ], "text/plain": [ "" @@ -1976,7 +1976,7 @@ { "data": { "text/html": [ - "Successfully finished last run (ID:x90ddjrq). Initializing new run:
" + "Successfully finished last run (ID:6rmw3wuq). Initializing new run:
" ], "text/plain": [ "" @@ -2000,7 +2000,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_143834-fjuigbav" + "Run data is saved locally in /home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_154234-2wlp36d3" ], "text/plain": [ "" @@ -2012,7 +2012,7 @@ { "data": { "text/html": [ - "Syncing run Diffusion_SDE_VE_TEXT_2024-08-01_14:38:34 to Weights & Biases (docs)
" + "Syncing run Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -2036,7 +2036,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/ashishkumar4/flaxdiff/runs/fjuigbav" + " View run at https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3" ], "text/plain": [ "" @@ -2054,7 +2054,7 @@ } ], "source": [ - "BATCH_SIZE = 64\n", + "BATCH_SIZE = 32\n", "IMAGE_SIZE = 128\n", "\n", "cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)\n", @@ -2085,8 +2085,8 @@ " # None,\n", " # None,\n", " # {\"heads\":32, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":False, \"use_self_and_cross\":True}, \n", - " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}, \n", - " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}, \n", + " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":False, \"use_self_and_cross\":False}, \n", + " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":True, \"use_projection\":False, \"use_self_and_cross\":False}, \n", " {\"heads\":8, \"dtype\":jnp.bfloat16, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False},\n", " ],\n", " \"num_res_blocks\":2,\n", @@ -2142,7 +2142,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -2151,7 +2151,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 106, + "metadata": {}, + "outputs": [], + "source": [ + "# jax.profiler.start_server(6009)\n", + "final_state = trainer.fit(data, 1000, epochs=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 103, "metadata": {}, "outputs": [ { @@ -2166,7 +2176,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\t\tEpoch 1: 100%|█████████████████████████████████| 1000/1000 [05:20<00:00, 3.12step/s, loss=0.0734]\n" + "\t\tEpoch 1: 100%|█████████████████████████████████| 1000/1000 [13:07<00:00, 1.27step/s, loss=0.0964]\n" ] }, { @@ -2175,31 +2185,31 @@ "text": [ "\n", "\tEpoch done\n", - "Saving model at epoch 1\n" + "Saving model at epoch 1\n", + "\n", + "\tEpoch 1 completed. 
Avg Loss: 0.21079306304454803, Time: 787.50s, Best Loss: 0.21079306304454803 \n", + "\n", + "Epoch 2/1\n" ] }, { - "ename": "IndexError", - "evalue": "Too many indices for array: 1 non-None/Ellipsis indices for dim 0.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[92], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# jax.profiler.start_server(6009)\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m final_state \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[89], line 431\u001b[0m, in \u001b[0;36mDiffusionTrainer.fit\u001b[0;34m(self, data, steps_per_epoch, epochs)\u001b[0m\n\u001b[1;32m 429\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 430\u001b[0m text_embedder \u001b[38;5;241m=\u001b[39m data[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmodel\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m--> 431\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbatch_size\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnull_labels_seq\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mnull_labels_full\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtext_embedder\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43mtext_embedder\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[89], line 258\u001b[0m, in \u001b[0;36mSimpleTrainer.fit\u001b[0;34m(self, data, steps_per_epoch, epochs, train_step_args)\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbest_loss \u001b[38;5;241m=\u001b[39m avg_loss\n\u001b[1;32m 257\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbest_state \u001b[38;5;241m=\u001b[39m state\n\u001b[0;32m--> 258\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 260\u001b[0m \u001b[38;5;66;03m# Compute Metrics\u001b[39;00m\n\u001b[1;32m 261\u001b[0m metrics_str \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m'\u001b[39m\n", - "Cell \u001b[0;32mIn[89], line 141\u001b[0m, in \u001b[0;36mSimpleTrainer.save\u001b[0;34m(self, epoch)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m 
\u001b[38;5;21msave\u001b[39m(\u001b[38;5;28mself\u001b[39m, epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSaving model at epoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 139\u001b[0m ckpt \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 140\u001b[0m \u001b[38;5;66;03m# 'model': self.model,\u001b[39;00m\n\u001b[0;32m--> 141\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mstate\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_state\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 142\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbest_state\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_best_state(),\n\u001b[1;32m 143\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbest_loss\u001b[39m\u001b[38;5;124m'\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbest_loss\n\u001b[1;32m 144\u001b[0m }\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 146\u001b[0m save_args \u001b[38;5;241m=\u001b[39m orbax_utils\u001b[38;5;241m.\u001b[39msave_args_from_target(ckpt)\n", - "Cell \u001b[0;32mIn[89], line 107\u001b[0m, in \u001b[0;36mSimpleTrainer.get_state\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_state\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 107\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mflax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjax_utils\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munreplicate\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstate\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/flax/jax_utils.py:50\u001b[0m, in \u001b[0;36munreplicate\u001b[0;34m(tree)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21munreplicate\u001b[39m(tree):\n\u001b[1;32m 49\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a single instance of a replicated array.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtree_util\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtree_map\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43;01mlambda\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtree\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/tree_util.py:343\u001b[0m, in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf, *rest)\u001b[0m\n\u001b[1;32m 341\u001b[0m leaves, treedef \u001b[38;5;241m=\u001b[39m tree_flatten(tree, is_leaf)\n\u001b[1;32m 342\u001b[0m all_leaves \u001b[38;5;241m=\u001b[39m [leaves] \u001b[38;5;241m+\u001b[39m [treedef\u001b[38;5;241m.\u001b[39mflatten_up_to(r) \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m rest]\n\u001b[0;32m--> 343\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mtreedef\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munflatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mxs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mxs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mzip\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mall_leaves\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/tree_util.py:343\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 341\u001b[0m leaves, treedef \u001b[38;5;241m=\u001b[39m tree_flatten(tree, is_leaf)\n\u001b[1;32m 342\u001b[0m all_leaves \u001b[38;5;241m=\u001b[39m [leaves] \u001b[38;5;241m+\u001b[39m [treedef\u001b[38;5;241m.\u001b[39mflatten_up_to(r) \u001b[38;5;28;01mfor\u001b[39;00m r \u001b[38;5;129;01min\u001b[39;00m rest]\n\u001b[0;32m--> 343\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m treedef\u001b[38;5;241m.\u001b[39munflatten(\u001b[43mf\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mxs\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m xs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\u001b[38;5;241m*\u001b[39mall_leaves))\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/flax/jax_utils.py:50\u001b[0m, in \u001b[0;36munreplicate..\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21munreplicate\u001b[39m(tree):\n\u001b[1;32m 49\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Returns a single instance of a replicated array.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 50\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m jax\u001b[38;5;241m.\u001b[39mtree_util\u001b[38;5;241m.\u001b[39mtree_map(\u001b[38;5;28;01mlambda\u001b[39;00m x: \u001b[43mx\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m, tree)\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/array.py:355\u001b[0m, in \u001b[0;36mArrayImpl.__getitem__\u001b[0;34m(self, idx)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 353\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m out\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mlax_numpy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_rewriting_take\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:7856\u001b[0m, in \u001b[0;36m_rewriting_take\u001b[0;34m(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)\u001b[0m\n\u001b[1;32m 7853\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m lax\u001b[38;5;241m.\u001b[39mdynamic_index_in_dim(arr, idx, keepdims\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 7855\u001b[0m treedef, static_idx, dynamic_idx \u001b[38;5;241m=\u001b[39m _split_index_for_jit(idx, arr\u001b[38;5;241m.\u001b[39mshape)\n\u001b[0;32m-> 7856\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_gather\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtreedef\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43mstatic_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdynamic_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices_are_sorted\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7857\u001b[0m \u001b[43m \u001b[49m\u001b[43munique_indices\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfill_value\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:7865\u001b[0m, in \u001b[0;36m_gather\u001b[0;34m(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)\u001b[0m\n\u001b[1;32m 7862\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_gather\u001b[39m(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,\n\u001b[1;32m 7863\u001b[0m unique_indices, mode, fill_value):\n\u001b[1;32m 7864\u001b[0m idx \u001b[38;5;241m=\u001b[39m _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)\n\u001b[0;32m-> 7865\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[43m_index_to_gather\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshape\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# shared with _scatter_update\u001b[39;00m\n\u001b[1;32m 7866\u001b[0m y \u001b[38;5;241m=\u001b[39m arr\n\u001b[1;32m 7868\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fill_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:7973\u001b[0m, in \u001b[0;36m_index_to_gather\u001b[0;34m(x_shape, idx, normalize_indices)\u001b[0m\n\u001b[1;32m 7970\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_index_to_gather\u001b[39m(x_shape: Sequence[\u001b[38;5;28mint\u001b[39m], idx: Sequence[Any],\n\u001b[1;32m 7971\u001b[0m normalize_indices: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m _Indexer:\n\u001b[1;32m 7972\u001b[0m \u001b[38;5;66;03m# Remove ellipses and add trailing slice(None)s.\u001b[39;00m\n\u001b[0;32m-> 7973\u001b[0m idx \u001b[38;5;241m=\u001b[39m \u001b[43m_canonicalize_tuple_index\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx_shape\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7975\u001b[0m \u001b[38;5;66;03m# Check for scalar boolean indexing: this requires inserting extra dimensions\u001b[39;00m\n\u001b[1;32m 7976\u001b[0m \u001b[38;5;66;03m# before performing the rest of the logic.\u001b[39;00m\n\u001b[1;32m 7977\u001b[0m scalar_bool_dims: Sequence[\u001b[38;5;28mint\u001b[39m] \u001b[38;5;241m=\u001b[39m [n \u001b[38;5;28;01mfor\u001b[39;00m n, i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(idx) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(i, \u001b[38;5;28mbool\u001b[39m)]\n", - "File \u001b[0;32m~/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:8293\u001b[0m, in \u001b[0;36m_canonicalize_tuple_index\u001b[0;34m(arr_ndim, idx, array_name)\u001b[0m\n\u001b[1;32m 8291\u001b[0m num_dimensions_consumed \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28msum\u001b[39m(\u001b[38;5;129;01mnot\u001b[39;00m (e \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m e \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28mEllipsis\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e, \u001b[38;5;28mbool\u001b[39m)) \u001b[38;5;28;01mfor\u001b[39;00m e \u001b[38;5;129;01min\u001b[39;00m idx)\n\u001b[1;32m 8292\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m num_dimensions_consumed \u001b[38;5;241m>\u001b[39m arr_ndim:\n\u001b[0;32m-> 8293\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\n\u001b[1;32m 8294\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mToo many indices for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marray_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mnum_dimensions_consumed\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 8295\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnon-None/Ellipsis indices for dim \u001b[39m\u001b[38;5;132;01m{\u001b[39;00marr_ndim\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 8296\u001b[0m ellipses \u001b[38;5;241m=\u001b[39m (i \u001b[38;5;28;01mfor\u001b[39;00m i, elt \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(idx) \u001b[38;5;28;01mif\u001b[39;00m elt \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28mEllipsis\u001b[39m)\n\u001b[1;32m 8297\u001b[0m ellipsis_index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mnext\u001b[39m(ellipses, \u001b[38;5;28;01mNone\u001b[39;00m)\n", - "\u001b[0;31mIndexError\u001b[0m: Too many indices for array: 1 non-None/Ellipsis indices for dim 0." + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 2: 100%|█████████████████████████████████| 1000/1000 [11:07<00:00, 1.50step/s, loss=0.0725]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 2\n", + "\n", + "\tEpoch 2 completed. 
Avg Loss: 0.08001875877380371, Time: 667.77s, Best Loss: 0.08001875877380371 \n", + "Saving model at epoch 1\n", + "Error saving checkpoint Checkpoint for step 1 already exists.\n" ] } ], diff --git a/setup.py b/setup.py index e76291e..88f9ac6 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,12 @@ from setuptools import find_packages, setup required_packages=[ - 'flax==0.8.4', + 'flax>=0.8.4', 'optax>=0.2.2', - 'jax==0.4.28', + 'jax>=0.4.28', + 'orbax', + 'clu', + 'mlcommons' ] setup( diff --git a/training_tpu.py b/training_tpu.py new file mode 100644 index 0000000..61d4b0e --- /dev/null +++ b/training_tpu.py @@ -0,0 +1,2132 @@ + +%load_ext dotenv +%dotenv + +import flax +import tqdm +from flax import linen as nn +import jax +from typing import Dict, Callable, Sequence, Any, Union +from dataclasses import field +import jax.numpy as jnp +import tensorflow_datasets as tfds +import grain.python as pygrain +import tensorflow as tf +import numpy as np +import augmax + +import matplotlib.pyplot as plt +from clu import metrics +from flax.training import train_state # Useful dataclass to keep train state +import optax +from flax import struct # Flax dataclasses +import time +import os +from datetime import datetime +from flax.training import orbax_utils +import functools +from tensorflow_datasets.core.utils import gcs_utils +gcs_utils._is_gcs_disabled = True +import json +# For CLIP +from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel +import wandb + +# %% [markdown] +# # Global Variables +##################################################################################################################### +############################################## Globasl Variables #################################################### +##################################################################################################################### + +GRAIN_WORKER_COUNT = 16 +GRAIN_READ_THREAD_COUNT = 64 +GRAIN_READ_BUFFER_SIZE = 50 +GRAIN_WORKER_BUFFER_SIZE = 20 + +# %% [markdown] +# # Initialization +##################################################################################################################### +################################################# Initialization #################################################### +##################################################################################################################### + +# %% +jax.distributed.initialize() + +# %% +print(f"Number of devices: {jax.device_count()}") +print(f"Local devices: {jax.local_devices()}") + +# %% +normalizeImage = lambda x: jax.nn.standardize(x, mean=[127.5], std=[127.5]) +denormalizeImage = lambda x: (x + 1.0) * 127.5 + + +def plotImages(imgs, fig_size=(8, 8), dpi=100): + fig = plt.figure(figsize=fig_size, dpi=dpi) + imglen = imgs.shape[0] + for i in range(imglen): + plt.subplot(fig_size[0], fig_size[1], i + 1) + plt.imshow(jnp.astype(denormalizeImage(imgs[i, :, :, :]), jnp.uint8)) + plt.axis("off") + plt.show() + +class RandomClass(): + def __init__(self, rng: jax.random.PRNGKey): + self.rng = rng + + def get_random_key(self): + self.rng, subkey = jax.random.split(self.rng) + return subkey + + def get_sigmas(self, steps): + return jnp.tan(self.theta_min + steps * (self.theta_max - self.theta_min)) / self.kappa + + def reset_random_key(self): + self.rng = jax.random.PRNGKey(42) + +class MarkovState(struct.PyTreeNode): + pass + +class RandomMarkovState(MarkovState): + rng: jax.random.PRNGKey + + def get_random_key(self): + rng, subkey = jax.random.split(self.rng) + return 
RandomMarkovState(rng), subkey + +# %% [markdown] +# # Data Pipeline + +# %% +def defaultTextEncodeModel(backend="jax"): + modelname = "openai/clip-vit-large-patch14" + if backend == "jax": + model = FlaxCLIPTextModel.from_pretrained(modelname, dtype=jnp.bfloat16) + else: + model = CLIPTextModel.from_pretrained(modelname) + tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16) + return model, tokenizer + +def encodePrompts(prompts, model, tokenizer=None): + if model == None: + model, tokenizer = defaultTextEncodeModel() + if tokenizer == None: + tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + # inputs = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="np") + inputs = tokenizer(prompts, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="np") + outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask']) + # outputs = infer(inputs['input_ids'], inputs['attention_mask']) + + last_hidden_state = outputs.last_hidden_state + pooler_output = outputs.pooler_output # pooled (EOS token) states + embed_pooled = pooler_output#.astype(jnp.float16) + embed_labels_full = last_hidden_state#.astype(jnp.float16) + + return embed_pooled, embed_labels_full + +class CaptionProcessor: + def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"): + self.tokenizer = AutoTokenizer.from_pretrained(modelname) + self.tensor_type = tensor_type + + def __call__(self, caption): + # print(caption) + tokens = self.tokenizer(caption, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors=self.tensor_type) + # print(tokens.keys()) + return { + "input_ids": tokens["input_ids"], + "attention_mask": tokens["attention_mask"], + "caption": caption, + } + + def __repr__(self): + return self.__class__.__name__ + '()' + +# %% +def data_source_tfds(name): + def data_source(): + return tfds.load(name, split="all", shuffle_files=True) + return data_source + +def data_source_cc12m(source="/home/mrwhite0racle/research/FlaxDiff/datasets/gcs_mount/arrayrecord/cc12m/"): + def data_source(): + cc12m_records_path = source + cc12m_records = [os.path.join(cc12m_records_path, i) for i in os.listdir(cc12m_records_path) if 'array_record' in i] + ds = pygrain.ArrayRecordDataSource(cc12m_records[:-1]) + return ds + return data_source + +def labelizer_oxford_flowers102(path): + with open(path, "r") as f: + textlabels = [i.strip() for i in f.readlines()] + textlabels = tf.convert_to_tensor(textlabels) + def load_labels(sample): + return textlabels[sample['label']] + return load_labels + +def labelizer_cc12m(sample): + return sample['txt'] + +# Configure the following for your datasets +datasetMap = { + "oxford_flowers102": { + "source":data_source_tfds("oxford_flowers102"), + "labelizer":labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"), + }, + "cc12m": { + "source":data_source_cc12m(), + "labelizer":labelizer_cc12m, + } +} + +# %% +import struct as st + +def unpack_dict_of_byte_arrays(packed_data): + unpacked_dict = {} + offset = 0 + while offset < len(packed_data): + # Unpack the key length + key_length = st.unpack_from('I', packed_data, offset)[0] + offset += st.calcsize('I') + # Unpack the key bytes and convert to string + key = packed_data[offset:offset+key_length].decode('utf-8') + offset += key_length + # Unpack the byte array length + 
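+        # Note (added for reference, not in the original source): each packed
+        # record is a repeated [uint32 key_len][key, utf-8][uint32 value_len]
+        # [value bytes] sequence until the buffer is exhausted; a matching
+        # packer would be the mirror image of this loop, using st.pack('I', ...)
+        # and byte concatenation.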
byte_array_length = st.unpack_from('I', packed_data, offset)[0] + offset += st.calcsize('I') + # Unpack the byte array + byte_array = packed_data[offset:offset+byte_array_length] + offset += byte_array_length + unpacked_dict[key] = byte_array + return unpacked_dict + +def get_dataset_grain(data_name="oxford_flowers102", + batch_size=64, image_scale=256, + count=None, num_epochs=None, + text_encoders=defaultTextEncodeModel(), + method=jax.image.ResizeMethod.LANCZOS3): + dataset = datasetMap[data_name] + data_source = dataset["source"]() + labelizer = dataset["labelizer"] + + import cv2 + + model, tokenizer = text_encoders + + null_labels, null_labels_full = encodePrompts([""], model, tokenizer) + null_labels = np.array(null_labels[0], dtype=np.float16) + null_labels_full = np.array(null_labels_full[0], dtype=np.float16) + + class augmenter(pygrain.MapTransform): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.caption_processor = CaptionProcessor(tensor_type="np") + + def map(self, element) -> Dict[str, jnp.array]: + element = unpack_dict_of_byte_arrays(element) + image = np.asarray(bytearray(element['jpg']), dtype="uint8") + image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED) + image = cv2.cvtColor(image , cv2.COLOR_BGR2RGB) + image = cv2.resize(image, (image_scale, image_scale), interpolation=cv2.INTER_AREA) + # image = (image - 127.5) / 127.5 + caption = labelizer(element).decode('utf-8') + results = self.caption_processor(caption) + return { + "image": image, + "input_ids": results['input_ids'][0], + "attention_mask": results['attention_mask'][0], + } + + sampler = pygrain.IndexSampler( + num_records=len(data_source) if count is None else count, + shuffle=True, + seed=0, + num_epochs=num_epochs, + shard_options=pygrain.NoSharding(), + ) + + transformations = [augmenter(), pygrain.Batch(batch_size, drop_remainder=True)] + + loader = pygrain.DataLoader( + data_source=data_source, + sampler=sampler, + operations=transformations, + worker_count=GRAIN_WORKER_COUNT, + read_options=pygrain.ReadOptions(GRAIN_READ_THREAD_COUNT, GRAIN_READ_BUFFER_SIZE), + worker_buffer_size=GRAIN_WORKER_BUFFER_SIZE + ) + + def get_trainset(): + return loader + + return { + "train": get_trainset, + "train_len": len(data_source), + "batch_size": batch_size, + "null_labels": null_labels, + "null_labels_full": null_labels_full, + "model": model, + "tokenizer": tokenizer, + } + +# %% +from flaxdiff.schedulers import CosineNoiseSchedule, NoiseScheduler, GeneralizedNoiseScheduler, KarrasVENoiseScheduler, EDMNoiseScheduler +from flaxdiff.predictors import VPredictionTransform, EpsilonPredictionTransform, DiffusionPredictionTransform, DirectPredictionTransform, KarrasPredictionTransform + +# %% [markdown] +# # Modeling + +# %% [markdown] +# ## Metrics + +# %% [markdown] +# ## Callbacks + +# %% [markdown] +# ## Model Generator + +# %% +import jax.experimental.pallas.ops.tpu.flash_attention +from flaxdiff.models.simple_unet import l2norm, ConvLayer, TimeEmbedding, TimeProjection, Upsample, Downsample, ResidualBlock, PixelShuffle +from flaxdiff.models.simple_unet import FourierEmbedding + +from flaxdiff.models.attention import kernel_init +# from flash_attn_jax import flash_mha +# from flaxdiff.models.favor_fastattn import make_fast_generalized_attention, make_fast_softmax_attention + +# Kernel initializer to use +def kernel_init(scale, dtype=jnp.float32): + scale = max(scale, 1e-10) + return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", 
dtype=dtype) + +class EfficientAttention(nn.Module): + """ + Based on the pallas attention implementation. + """ + query_dim: int + heads: int = 4 + dim_head: int = 64 + dtype: Any = jnp.float32 + precision: Any = jax.lax.Precision.HIGHEST + use_bias: bool = True + kernel_init: Callable = lambda : kernel_init(1.0) + + def setup(self): + inner_dim = self.dim_head * self.heads + # Weights were exported with old names {to_q, to_k, to_v, to_out} + dense = functools.partial( + nn.Dense, + self.heads * self.dim_head, + precision=self.precision, + use_bias=self.use_bias, + kernel_init=self.kernel_init(), + dtype=self.dtype + ) + self.query = dense(name="to_q") + self.key = dense(name="to_k") + self.value = dense(name="to_v") + + self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, + kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0") + # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16) + + def _reshape_tensor_to_head_dim(self, tensor): + batch_size, _, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + return tensor + + def _reshape_tensor_from_head_dim(self, tensor): + batch_size, _, seq_len, dim = tensor.shape + head_size = self.heads + tensor = jnp.transpose(tensor, (0, 2, 1, 3)) + tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size) + return tensor + + @nn.compact + def __call__(self, x:jax.Array, context=None): + # print(x.shape) + # x has shape [B, H * W, C] + context = x if context is None else context + + B, H, W, C = x.shape + x = x.reshape((B, 1, H * W, C)) + + B, _H, _W, _C = context.shape + context = context.reshape((B, 1, _H * _W, _C)) + + query = self.query(x) + key = self.key(context) + value = self.value(context) + + query = self._reshape_tensor_to_head_dim(query) + key = self._reshape_tensor_to_head_dim(key) + value = self._reshape_tensor_to_head_dim(value) + + hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention( + query, key, value, None + ) + + hidden_states = self._reshape_tensor_from_head_dim(hidden_states) + + + # hidden_states = nn.dot_product_attention( + # query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision + # ) + + proj = self.proj_attn(hidden_states) + + proj = proj.reshape((B, H, W, C)) + + return proj + + +class NormalAttention(nn.Module): + """ + Simple implementation of the normal attention. 
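+    Projects q/k/v with nn.DenseGeneral into per-head features and calls
+    nn.dot_product_attention directly on the [B, H, W, C] feature map; this
+    is the fallback path when use_flash_attention is False.
+
+    Shape sketch (illustrative, not from the original source):
+        attn = NormalAttention(query_dim=64, heads=4, dim_head=16)
+        x = jnp.ones((1, 16, 16, 64))
+        params = attn.init(jax.random.PRNGKey(0), x)
+        y = attn.apply(params, x)  # -> (1, 16, 16, 64)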
+ """ + query_dim: int + heads: int = 4 + dim_head: int = 64 + dtype: Any = jnp.float32 + precision: Any = jax.lax.Precision.HIGHEST + use_bias: bool = True + kernel_init: Callable = lambda : kernel_init(1.0) + + def setup(self): + inner_dim = self.dim_head * self.heads + dense = functools.partial( + nn.DenseGeneral, + features=[self.heads, self.dim_head], + axis=-1, + precision=self.precision, + use_bias=self.use_bias, + kernel_init=self.kernel_init(), + dtype=self.dtype + ) + self.query = dense(name="to_q") + self.key = dense(name="to_k") + self.value = dense(name="to_v") + + self.proj_attn = nn.DenseGeneral( + self.query_dim, + axis=(-2, -1), + precision=self.precision, + use_bias=self.use_bias, + dtype=self.dtype, + name="to_out_0", + kernel_init=self.kernel_init() + # kernel_init=jax.nn.initializers.xavier_uniform() + ) + + @nn.compact + def __call__(self, x, context=None): + # x has shape [B, H, W, C] + context = x if context is None else context + query = self.query(x) + key = self.key(context) + value = self.value(context) + + hidden_states = nn.dot_product_attention( + query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision + ) + + proj = self.proj_attn(hidden_states) + return proj + +class AttentionBlock(nn.Module): + # Has self and cross attention + query_dim: int + heads: int = 4 + dim_head: int = 64 + dtype: Any = jnp.float32 + precision: Any = jax.lax.Precision.HIGHEST + use_bias: bool = True + kernel_init: Callable = lambda : kernel_init(1.0) + use_flash_attention:bool = False + use_cross_only:bool = False + + def setup(self): + if self.use_flash_attention: + attenBlock = EfficientAttention + else: + attenBlock = NormalAttention + + self.attention1 = attenBlock( + query_dim=self.query_dim, + heads=self.heads, + dim_head=self.dim_head, + name=f'Attention1', + precision=self.precision, + use_bias=self.use_bias, + dtype=self.dtype, + kernel_init=self.kernel_init + ) + self.attention2 = attenBlock( + query_dim=self.query_dim, + heads=self.heads, + dim_head=self.dim_head, + name=f'Attention2', + precision=self.precision, + use_bias=self.use_bias, + dtype=self.dtype, + kernel_init=self.kernel_init + ) + + self.ff = nn.DenseGeneral( + features=self.query_dim, + use_bias=self.use_bias, + precision=self.precision, + dtype=self.dtype, + kernel_init=self.kernel_init(), + name="ff" + ) + self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) + self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) + self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) + self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) + + @nn.compact + def __call__(self, hidden_states, context=None): + # self attention + residual = hidden_states + hidden_states = self.norm1(hidden_states) + if self.use_cross_only: + hidden_states = self.attention1(hidden_states, context) + else: + hidden_states = self.attention1(hidden_states) + hidden_states = hidden_states + residual + + # cross attention + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.attention2(hidden_states, context) + hidden_states = hidden_states + residual + + # feed forward + residual = hidden_states + hidden_states = self.norm3(hidden_states) + hidden_states = nn.gelu(hidden_states) + hidden_states = self.ff(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + +class TransformerBlock(nn.Module): + heads: int = 4 + dim_head: int = 32 + use_linear_attention: bool = True + dtype: Any = jnp.bfloat16 + precision: Any = 
jax.lax.Precision.HIGH + use_projection: bool = False + use_flash_attention:bool = True + use_self_and_cross:bool = False + + @nn.compact + def __call__(self, x, context=None): + inner_dim = self.heads * self.dim_head + B, H, W, C = x.shape + normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x) + if self.use_projection == True: + if self.use_linear_attention: + projected_x = nn.Dense(features=inner_dim, + use_bias=False, precision=self.precision, + kernel_init=kernel_init(1.0), + dtype=self.dtype, name=f'project_in')(normed_x) + else: + projected_x = nn.Conv( + features=inner_dim, kernel_size=(1, 1), + kernel_init=kernel_init(1.0), + strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, + precision=self.precision, name=f'project_in_conv', + )(normed_x) + else: + projected_x = normed_x + inner_dim = C + + context = projected_x if context is None else context + + if self.use_self_and_cross: + projected_x = AttentionBlock( + query_dim=inner_dim, + heads=self.heads, + dim_head=self.dim_head, + name=f'Attention', + precision=self.precision, + use_bias=False, + dtype=self.dtype, + use_flash_attention=self.use_flash_attention, + use_cross_only=False + )(projected_x, context) + elif self.use_flash_attention == True: + projected_x = EfficientAttention( + query_dim=inner_dim, + heads=self.heads, + dim_head=self.dim_head, + name=f'Attention', + precision=self.precision, + use_bias=False, + dtype=self.dtype, + )(projected_x, context) + else: + projected_x = NormalAttention( + query_dim=inner_dim, + heads=self.heads, + dim_head=self.dim_head, + name=f'Attention', + precision=self.precision, + use_bias=False, + )(projected_x, context) + + + if self.use_projection == True: + if self.use_linear_attention: + projected_x = nn.Dense(features=C, precision=self.precision, + dtype=self.dtype, use_bias=False, + kernel_init=kernel_init(1.0), + name=f'project_out')(projected_x) + else: + projected_x = nn.Conv( + features=C, kernel_size=(1, 1), + kernel_init=kernel_init(1.0), + strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, + precision=self.precision, name=f'project_out_conv', + )(projected_x) + + out = x + projected_x + return out + + +# %% [markdown] +# ## Attention and other prototyping + +# %% +x = jnp.ones((16, 1, 16*16, 64)) +batch_size, _, seq_len, dim = x.shape +head_size = 4 +dim_head = dim // head_size +k = nn.Dense(dim_head * head_size, precision=jax.lax.Precision.HIGHEST, use_bias=True) +param = k.init(jax.random.PRNGKey(42), x) +tensor = k.apply(param, x) +print(tensor.shape) +tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) +tensor = jnp.transpose(tensor, (0, 2, 1, 3)) +print(tensor.shape) + + + +# %% +x = jnp.ones((16, 64, 64, 128)) +context = jnp.ones((16, 64, 64, 128)) +attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.bfloat16, use_flash_attention=False, use_projection=False, use_self_and_cross=False) +params = attention_block.init(jax.random.PRNGKey(0), x, context) +@jax.jit +def apply(params, x, context): + return attention_block.apply(params, x, context) + +apply(params, x, context) + +%timeit -n 1 apply(params, x, context) + +# %% +x = jnp.ones((1, 16, 16, 64)) +context = jnp.ones((1, 12, 768)) +# pad the context +context = jnp.pad(context, ((0, 0), (0, 4), (0, 0)), mode='constant', constant_values=0) +print(context.shape) +context = None#jnp.reshape(context, (1, 1, 16, 768)) +attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.bfloat16, use_flash_attention=True, use_projection=False, 
use_self_and_cross=False) +params = attention_block.init(jax.random.PRNGKey(0), x, context) +out = attention_block.apply(params, x, context) +print("Output :", out.shape) +print(attention_block.tabulate(jax.random.key(0), x, context, console_kwargs={"width": 200, "force_jupyter":True, })) +print(jnp.mean(out), jnp.std(out)) +# plt.hist(out.flatten(), bins=100) +# %timeit attention_block.apply(params, x) + +# %% [markdown] +# ## Main Model + +# %% +class ResidualBlock(nn.Module): + conv_type:str + features:int + kernel_size:tuple=(3, 3) + strides:tuple=(1, 1) + padding:str="SAME" + activation:Callable=jax.nn.swish + direction:str=None + res:int=2 + norm_groups:int=8 + kernel_init:Callable=kernel_init(1.0) + dtype: Any = jnp.float32 + precision: Any = jax.lax.Precision.HIGHEST + + @nn.compact + def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None): + residual = x + out = nn.GroupNorm(self.norm_groups)(x) + out = self.activation(out) + + out = ConvLayer( + self.conv_type, + features=self.features, + kernel_size=self.kernel_size, + strides=self.strides, + kernel_init=self.kernel_init, + name="conv1", + dtype=self.dtype, + precision=self.precision + )(out) + + temb = nn.DenseGeneral( + features=self.features, + name="temb_projection", + dtype=self.dtype, + precision=self.precision)(temb) + temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) + # scale, shift = jnp.split(temb, 2, axis=-1) + # out = out * (1 + scale) + shift + out = out + temb + + out = nn.GroupNorm(self.norm_groups)(out) + out = self.activation(out) + + out = ConvLayer( + self.conv_type, + features=self.features, + kernel_size=self.kernel_size, + strides=self.strides, + kernel_init=self.kernel_init, + name="conv2", + dtype=self.dtype, + precision=self.precision + )(out) + + if residual.shape != out.shape: + residual = ConvLayer( + self.conv_type, + features=self.features, + kernel_size=(1, 1), + strides=1, + kernel_init=self.kernel_init, + name="residual_conv", + dtype=self.dtype, + precision=self.precision + )(residual) + out = out + residual + + out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out + + return out + +class Unet(nn.Module): + emb_features:int=64*4, + feature_depths:list=[64, 128, 256, 512], + attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}], + num_res_blocks:int=2, + num_middle_res_blocks:int=1, + activation:Callable = jax.nn.swish + norm_groups:int=8 + dtype: Any = jnp.bfloat16 + precision: Any = jax.lax.Precision.HIGH + + @nn.compact + def __call__(self, x, temb, textcontext=None): + # print("embedding features", self.emb_features) + temb = FourierEmbedding(features=self.emb_features)(temb) + temb = TimeProjection(features=self.emb_features)(temb) + + _, TS, TC = textcontext.shape + + # print("time embedding", temb.shape) + feature_depths = self.feature_depths + attention_configs = self.attention_configs + + conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv" + # middle_conv_type = "separable" + + x = ConvLayer( + conv_type, + features=self.feature_depths[0], + kernel_size=(3, 3), + strides=(1, 1), + kernel_init=kernel_init(1.0), + dtype=self.dtype, + precision=self.precision + )(x) + downs = [x] + + # Downscaling blocks + for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)): + dim_in = x.shape[-1] + # dim_in = dim_out + for j in range(self.num_res_blocks): + x = ResidualBlock( + down_conv_type, + name=f"down_{i}_residual_{j}", + features=dim_in, 
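+                    # note: down-path residual blocks keep the incoming
+                    # channel count (dim_in); the width only changes in the
+                    # Downsample layer below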
+ kernel_init=kernel_init(1.0), + kernel_size=(3, 3), + strides=(1, 1), + activation=self.activation, + norm_groups=self.norm_groups, + dtype=self.dtype, + precision=self.precision + )(x, temb) + if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block + B, H, W, _ = x.shape + if H > TS: + padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC)) + else: + padded_context = None + x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), + dim_head=dim_in // attention_config['heads'], + use_flash_attention=attention_config.get("flash_attention", True), + use_projection=attention_config.get("use_projection", False), + use_self_and_cross=attention_config.get("use_self_and_cross", True), + precision=attention_config.get("precision", self.precision), + name=f"down_{i}_attention_{j}")(x, padded_context) + # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in) + downs.append(x) + if i != len(feature_depths) - 1: + # print("Downsample", i, x.shape) + x = Downsample( + features=dim_out, + scale=2, + activation=self.activation, + name=f"down_{i}_downsample", + dtype=self.dtype, + precision=self.precision + )(x) + + # Middle Blocks + middle_dim_out = self.feature_depths[-1] + middle_attention = self.attention_configs[-1] + for j in range(self.num_middle_res_blocks): + x = ResidualBlock( + middle_conv_type, + name=f"middle_res1_{j}", + features=middle_dim_out, + kernel_init=kernel_init(1.0), + kernel_size=(3, 3), + strides=(1, 1), + activation=self.activation, + norm_groups=self.norm_groups, + dtype=self.dtype, + precision=self.precision + )(x, temb) + if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block + x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32), + dim_head=middle_dim_out // middle_attention['heads'], + use_flash_attention=middle_attention.get("flash_attention", True), + use_linear_attention=False, + use_projection=middle_attention.get("use_projection", False), + use_self_and_cross=False, + precision=middle_attention.get("precision", self.precision), + name=f"middle_attention_{j}")(x) + x = ResidualBlock( + middle_conv_type, + name=f"middle_res2_{j}", + features=middle_dim_out, + kernel_init=kernel_init(1.0), + kernel_size=(3, 3), + strides=(1, 1), + activation=self.activation, + norm_groups=self.norm_groups, + dtype=self.dtype, + precision=self.precision + )(x, temb) + + # Upscaling Blocks + for i, (dim_out, attention_config) in enumerate(zip(reversed(feature_depths), reversed(attention_configs))): + # print("Upscaling", i, "features", dim_out) + for j in range(self.num_res_blocks): + residual = downs.pop() + x = jnp.concatenate([x, residual], axis=-1) + # print("concat==> ", i, "concat", x.shape) + # kernel_size = (1 + 2 * (j + 1), 1 + 2 * (j + 1)) + kernel_size = (3, 3) + x = ResidualBlock( + up_conv_type,# if j == 0 else "separable", + name=f"up_{i}_residual_{j}", + features=dim_out, + kernel_init=kernel_init(1.0), + kernel_size=kernel_size, + strides=(1, 1), + activation=self.activation, + norm_groups=self.norm_groups, + dtype=self.dtype, + precision=self.precision + )(x, temb) + if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block + # B, H, W, _ = x.shape + # if H > TS: + # padded_context = jnp.pad(textcontext, ((0, 0), (0, 
H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC)) + # else: + # padded_context = None + x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32), + dim_head=dim_out // attention_config['heads'], + use_flash_attention=attention_config.get("flash_attention", True), + use_projection=attention_config.get("use_projection", False), + use_self_and_cross=attention_config.get("use_self_and_cross", True), + precision=attention_config.get("precision", self.precision), + name=f"up_{i}_attention_{j}")(x, residual) + # print("Upscaling ", i, x.shape) + if i != len(feature_depths) - 1: + x = Upsample( + features=feature_depths[-i], + scale=2, + activation=self.activation, + name=f"up_{i}_upsample", + dtype=self.dtype, + precision=self.precision + )(x) + + # x = nn.GroupNorm(8)(x) + x = ConvLayer( + conv_type, + features=self.feature_depths[0], + kernel_size=(3, 3), + strides=(1, 1), + kernel_init=kernel_init(0.0), + dtype=self.dtype, + precision=self.precision + )(x) + + x = jnp.concatenate([x, downs.pop()], axis=-1) + + x = ResidualBlock( + conv_type, + name="final_residual", + features=self.feature_depths[0], + kernel_init=kernel_init(1.0), + kernel_size=(3,3), + strides=(1, 1), + activation=self.activation, + norm_groups=self.norm_groups, + dtype=self.dtype, + precision=self.precision + )(x, temb) + + x = nn.GroupNorm(self.norm_groups)(x) + x = self.activation(x) + + noise_out = ConvLayer( + conv_type, + features=3, + kernel_size=(3, 3), + strides=(1, 1), + # activation=jax.nn.mish + kernel_init=kernel_init(0.0), + dtype=self.dtype, + precision=self.precision + )(x) + return noise_out#, attentions + +# %% +unet = Unet(emb_features=512, + feature_depths=[128, 256, 512, 1024], + attention_configs=[ + None, + # None, + # {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":True}, + {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":True, "use_self_and_cross":True}, + {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":True, "use_self_and_cross":True}, + {"heads":32, "dtype":jnp.bfloat16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False} + ], + num_res_blocks=4, + num_middle_res_blocks=1 +) + +inp = jnp.ones((1, 128, 128, 3)) +temb = jnp.ones((1,)) +textcontext = jnp.ones((1, 77, 768)) + +params = unet.init(jax.random.PRNGKey(0), inp, temb, textcontext) + +# %% +unet.tabulate(jax.random.key(0), inp, temb, textcontext, console_kwargs={"width": 200, "force_jupyter":True, }) + +# %% [markdown] +# # Training + +# %% +import flax.jax_utils +import orbax.checkpoint +import orbax +from typing import Any, Tuple, Mapping,Callable,List,Dict +from flax.metrics import tensorboard +from functools import partial + +@struct.dataclass +class Metrics(metrics.Collection): + accuracy: metrics.Accuracy + loss: metrics.Average.from_output('loss') + +# Define the TrainState +class SimpleTrainState(train_state.TrainState): + rngs: jax.random.PRNGKey + metrics: Metrics + + def get_random_key(self): + rngs, subkey = jax.random.split(self.rngs) + return self.replace(rngs=rngs), subkey + +class SimpleTrainer: + state : SimpleTrainState + best_state : SimpleTrainState + best_loss : float + model : nn.Module + ema_decay:float = 0.999 + + def __init__(self, + model:nn.Module, + input_shapes:Dict[str, Tuple[int]], + optimizer: optax.GradientTransformation, + rngs:jax.random.PRNGKey, + train_state:SimpleTrainState=None, + name:str="Simple", 
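+                 # the next flag, when True, resumes from the newest orbax
+                 # checkpoint under ./checkpoints/<name> + checkpoint_suffix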
+ load_from_checkpoint:bool=False, + checkpoint_suffix:str="", + loss_fn=optax.l2_loss, + param_transforms:Callable=None, + wandb_config:Dict[str, Any]=None + ): + self.model = model + self.name = name + self.loss_fn = loss_fn + self.input_shapes = input_shapes + + if wandb_config is not None: + run = wandb.init(**wandb_config) + self.wandb = run + + checkpointer = orbax.checkpoint.PyTreeCheckpointer() + options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True) + self.checkpointer = orbax.checkpoint.CheckpointManager(self.checkpoint_path() + checkpoint_suffix, checkpointer, options) + + if load_from_checkpoint: + latest_epoch, old_state, old_best_state = self.load() + else: + latest_epoch, old_state, old_best_state = 0, None, None + + self.latest_epoch = latest_epoch + + if train_state == None: + self.init_state(optimizer, rngs, existing_state=old_state, existing_best_state=old_best_state, model=model, param_transforms=param_transforms) + else: + self.state = train_state + self.best_state = train_state + self.best_loss = 1e9 + + def get_input_ones(self): + return {k:jnp.ones((1, *v)) for k,v in self.input_shapes.items()} + + def init_state(self, + optimizer: optax.GradientTransformation, + rngs:jax.random.PRNGKey, + existing_state:dict=None, + existing_best_state:dict=None, + model:nn.Module=None, + param_transforms:Callable=None + ): + @partial(jax.pmap, axis_name="device") + def init_fn(rngs): + rngs, subkey = jax.random.split(rngs) + + if existing_state == None: + input_vars = self.get_input_ones() + params = model.init(subkey, **input_vars) + + # if param_transforms is not None: + # params = param_transforms(params) + + state = SimpleTrainState.create( + apply_fn=model.apply, + params=params, + tx=optimizer, + rngs=rngs, + metrics=Metrics.empty() + ) + return state + self.state = init_fn(jax.device_put_replicated(rngs, jax.devices())) + self.best_loss = 1e9 + if existing_best_state is not None: + self.best_state = self.state.replace(params=existing_best_state['params'], ema_params=existing_best_state['ema_params']) + else: + self.best_state = self.state + + def get_state(self): + return flax.jax_utils.unreplicate(self.state) + + def get_best_state(self): + return flax.jax_utils.unreplicate(self.best_state) + + def checkpoint_path(self): + experiment_name = self.name + path = os.path.join(os.path.abspath('./checkpoints'), experiment_name) + if not os.path.exists(path): + os.makedirs(path) + return path + + def tensorboard_path(self): + experiment_name = self.name + path = os.path.join(os.path.abspath('./tensorboard'), experiment_name) + if not os.path.exists(path): + os.makedirs(path) + return path + + def load(self): + epoch = self.checkpointer.latest_step() + print("Loading model from checkpoint", epoch) + ckpt = self.checkpointer.restore(epoch) + state = ckpt['state'] + best_state = ckpt['best_state'] + # Convert the state to a TrainState + self.best_loss = ckpt['best_loss'] + print(f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss']) + return epoch, state, best_state + + def save(self, epoch=0): + print(f"Saving model at epoch {epoch}") + ckpt = { + # 'model': self.model, + 'state': self.get_state(), + 'best_state': self.get_best_state(), + 'best_loss': self.best_loss + } + try: + save_args = orbax_utils.save_args_from_target(ckpt) + self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args}, force=True) + pass + except Exception as e: + print("Error saving checkpoint", e) + + def _define_train_step(self, **kwargs): + model = 
self.model + loss_fn = self.loss_fn + + @partial(jax.pmap, axis_name="device") + def train_step(state:SimpleTrainState, batch): + """Train for a single step.""" + images = batch['image'] + labels= batch['label'] + + def model_loss(params): + preds = model.apply(params, images) + expected_output = labels + nloss = loss_fn(preds, expected_output) + loss = jnp.mean(nloss) + return loss + loss, grads = jax.value_and_grad(model_loss)(state.params) + grads = jax.lax.pmean(grads, "device") + state = state.apply_gradients(grads=grads) + return state, loss + return train_step + + def _define_compute_metrics(self): + model = self.model + loss_fn = self.loss_fn + + @jax.jit + def compute_metrics(state:SimpleTrainState, batch): + preds = model.apply(state.params, batch['image']) + expected_output = batch['label'] + loss = jnp.mean(loss_fn(preds, expected_output)) + metric_updates = state.metrics.single_from_model_output(loss=loss, logits=preds, labels=expected_output) + metrics = state.metrics.merge(metric_updates) + state = state.replace(metrics=metrics) + return state + return compute_metrics + + def summary(self): + input_vars = self.get_input_ones() + print(self.model.tabulate(jax.random.key(0), **input_vars, console_kwargs={"width": 200, "force_jupyter":True, })) + + def config(self): + return { + "model": self.model, + "state": self.state, + "name": self.name, + "input_shapes": self.input_shapes + } + + def init_tensorboard(self, batch_size, steps_per_epoch, epochs): + summary_writer = tensorboard.SummaryWriter(self.tensorboard_path()) + summary_writer.hparams({ + **self.config(), + "steps_per_epoch": steps_per_epoch, + "epochs": epochs, + "batch_size": batch_size + }) + return summary_writer + + def fit(self, data, steps_per_epoch, epochs, train_step_args={}): + train_ds = iter(data['train']()) + if 'test' in data: + test_ds = data['test'] + else: + test_ds = None + train_step = self._define_train_step(**train_step_args) + compute_metrics = self._define_compute_metrics() + state = self.state + device_count = jax.device_count() + # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices()) + + summary_writer = self.init_tensorboard(data['batch_size'], steps_per_epoch, epochs) + + while self.latest_epoch <= epochs: + self.latest_epoch += 1 + current_epoch = self.latest_epoch + print(f"\nEpoch {current_epoch}/{epochs}") + start_time = time.time() + epoch_loss = 0 + + with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar: + for i in range(steps_per_epoch): + batch = next(train_ds) + batch = jax.tree.map(lambda x: x.reshape((device_count, -1, *x.shape[1:])), batch) + # print(batch['image'].shape) + state, loss = train_step(state, batch) + loss = jnp.mean(loss) + # print("==>", loss) + epoch_loss += loss + if i % 100 == 0: + pbar.set_postfix(loss=f'{loss:.4f}') + pbar.update(100) + current_step = current_epoch*steps_per_epoch + i + summary_writer.scalar('Train Loss', loss, step=current_step) + if self.wandb is not None: + self.wandb.log({"train/loss": loss}) + + print(f"\n\tEpoch done") + end_time = time.time() + self.state = state + total_time = end_time - start_time + avg_time_per_step = total_time / steps_per_epoch + avg_loss = epoch_loss / steps_per_epoch + if avg_loss < self.best_loss: + self.best_loss = avg_loss + self.best_state = state + self.save(current_epoch) + + # Compute Metrics + metrics_str = '' + # if test_ds is not None: + # for test_batch in iter(test_ds()): + # state = compute_metrics(state, test_batch) + # metrics = 
+# Define the TrainState with EMA parameters
+class TrainState(SimpleTrainState):
+    rngs: jax.random.PRNGKey
+    ema_params: dict
+
+    def get_random_key(self):
+        rngs, subkey = jax.random.split(self.rngs)
+        return self.replace(rngs=rngs), subkey
+
+    def apply_ema(self, decay: float = 0.999):
+        new_ema_params = jax.tree_util.tree_map(
+            lambda ema, param: decay * ema + (1 - decay) * param,
+            self.ema_params,
+            self.params,
+        )
+        return self.replace(ema_params=new_ema_params)
+
+class DiffusionTrainer(SimpleTrainer):
+    noise_schedule: NoiseScheduler
+    model_output_transform: DiffusionPredictionTransform
+    ema_decay: float = 0.999
+
+    def __init__(self,
+                 model:nn.Module,
+                 input_shapes:Dict[str, Tuple[int]],
+                 optimizer: optax.GradientTransformation,
+                 noise_schedule:NoiseScheduler,
+                 rngs:jax.random.PRNGKey,
+                 unconditional_prob:float=0.2,
+                 name:str="Diffusion",
+                 model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
+                 **kwargs
+                 ):
+        super().__init__(
+            model=model,
+            input_shapes=input_shapes,
+            optimizer=optimizer,
+            rngs=rngs,
+            name=name,
+            **kwargs
+        )
+        self.noise_schedule = noise_schedule
+        self.model_output_transform = model_output_transform
+        self.unconditional_prob = unconditional_prob
+
+    def init_state(self,
+                   optimizer: optax.GradientTransformation,
+                   rngs:jax.random.PRNGKey,
+                   existing_state:dict=None,
+                   existing_best_state:dict=None,
+                   model:nn.Module=None,
+                   param_transforms:Callable=None,
+                   ):
+        def init_fn(rngs):
+            rngs, subkey = jax.random.split(rngs)
+
+            if existing_state is None:
+                input_vars = self.get_input_ones()
+                params = model.init(subkey, **input_vars)
+                new_state = {"params": params, "ema_params": params}
+            else:
+                new_state = existing_state
+
+            if param_transforms is not None:
+                # Apply the transforms to the parameters actually used below
+                new_state['params'] = param_transforms(new_state['params'])
+                new_state['ema_params'] = param_transforms(new_state['ema_params'])
+
+            state = TrainState.create(
+                apply_fn=model.apply,
+                params=new_state['params'],
+                ema_params=new_state['ema_params'],
+                tx=optimizer,
+                rngs=rngs,
+                metrics=Metrics.empty()
+            )
+            return state
+
+        self.best_loss = 1e9
+        state = init_fn(rngs)
+        if existing_best_state is not None:
+            best_state = state.replace(params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
+        else:
+            best_state = state
+
+        self.state = flax.jax_utils.replicate(state, jax.devices())
+        self.best_state = flax.jax_utils.replicate(best_state, jax.devices())
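`apply_ema` above is the usual exponential moving average, ema <- decay * ema + (1 - decay) * param, applied leaf-wise over the parameter pytree. A small numeric check of the same `tree_map` expression:

```python
import jax
import jax.numpy as jnp

decay = 0.999
ema_params = {'w': jnp.zeros(3)}
params = {'w': jnp.ones(3)}

# one EMA update step, leaf-wise over the pytree
new_ema = jax.tree_util.tree_map(
    lambda ema, param: decay * ema + (1 - decay) * param,
    ema_params,
    params,
)
print(new_ema['w'])  # [0.001 0.001 0.001]
```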
+    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+        noise_schedule = self.noise_schedule
+        model = self.model
+        model_output_transform = self.model_output_transform
+        loss_fn = self.loss_fn
+        unconditional_prob = self.unconditional_prob
+
+        # Determine the number of unconditional samples
+        num_unconditional = int(batch_size * unconditional_prob)
+
+        nS, nC = null_labels_seq.shape
+        null_labels_seq = jnp.broadcast_to(null_labels_seq, (batch_size, nS, nC))
+
+        @partial(jax.pmap, axis_name="device")
+        def train_step(state:TrainState, batch):
+            """Train for a single step."""
+            images = batch['image']
+            # Normalize images to [-1, 1]
+            images = (images - 127.5) / 127.5
+
+            output = text_embedder(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+            label_seq = output.last_hidden_state
+
+            # Swap the first `num_unconditional` label sequences for the null
+            # embedding so the model also learns an unconditional score
+            label_seq = jnp.concatenate([null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
+
+            noise_level, state = noise_schedule.generate_timesteps(images.shape[0], state)
+            state, rngs = state.get_random_key()
+            noise:jax.Array = jax.random.normal(rngs, shape=images.shape)
+            rates = noise_schedule.get_rates(noise_level)
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
+
+            def model_loss(params):
+                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
+                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
+                nloss = loss_fn(preds, expected_output)
+                nloss *= noise_schedule.get_weights(noise_level)
+                loss = jnp.mean(nloss)
+                return loss
+            loss, grads = jax.value_and_grad(model_loss)(state.params)
+            grads = jax.lax.pmean(grads, "device")
+            state = state.apply_gradients(grads=grads)
+            state = state.apply_ema(self.ema_decay)
+            return state, loss
+        return train_step
+
+    def _define_compute_metrics(self):
+        @jax.jit
+        def compute_metrics(state:TrainState, expected, pred):
+            loss = jnp.mean(jnp.square(pred - expected))
+            metric_updates = state.metrics.single_from_model_output(loss=loss)
+            metrics = state.metrics.merge(metric_updates)
+            state = state.replace(metrics=metrics)
+            return state
+        return compute_metrics
+
+    def fit(self, data, steps_per_epoch, epochs):
+        null_labels_full = data['null_labels_full']
+        batch_size = data['batch_size']
+        text_embedder = data['model']
+        # Return the final state so callers can capture it
+        return super().fit(data, steps_per_epoch, epochs, {"batch_size": batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+
+# %%
+BATCH_SIZE = 64
+IMAGE_SIZE = 128
+
+cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)
+karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
+edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
+
+experiment_name = "{name}_{date}".format(
+    name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
+)
+# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-16_02:16:07'
+# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-21_02:12:40'
+# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-30_05:48:22'
+# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-08-01_08:59:00'
+print("Experiment_Name:", experiment_name)
+
+dataset_name = "cc12m"
+datalen = len(datasetMap[dataset_name]['source'])
+batches = datalen // BATCH_SIZE
+
+config = {
+    "model": {
+        "emb_features": 256,
+        "feature_depths": [64, 128, 256, 512],
+        "attention_configs": [
+            None,
+            # {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":True},
+            {"heads":8, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":False},
+            {"heads":8, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":False},
+            {"heads":8, "dtype":jnp.bfloat16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False},
+        ],
+        "num_res_blocks": 2,
+        "num_middle_res_blocks": 1,
+    },
+
+    "dataset": {
+        "name": dataset_name,
+        "length": datalen,
+        "batches": batches
+    },
+    "learning_rate": 2e-4,
+
+    "input_shapes": {
+        "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+        "temb": (),
+        "textcontext": (77, 768)
+    },
+}
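The `input_shapes` entry in the config above is what `get_input_ones` turns into dummy inputs for `model.init`: each per-example shape gains a leading batch axis of 1. A quick check with the shapes from this config:

```python
import jax.numpy as jnp

input_shapes = {"x": (128, 128, 3), "temb": (), "textcontext": (77, 768)}

# same comprehension as SimpleTrainer.get_input_ones
input_vars = {k: jnp.ones((1, *v)) for k, v in input_shapes.items()}
print({k: v.shape for k, v in input_vars.items()})
# {'x': (1, 128, 128, 3), 'temb': (1,), 'textcontext': (1, 77, 768)}
```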
+ }, + + "dataset": { + "name" : dataset_name, + "length" : datalen, + "batches": batches + }, + "learning_rate": 2e-4, + + "input_shapes": { + "x": (IMAGE_SIZE, IMAGE_SIZE, 3), + "temb": (), + "textcontext": (77, 768) + }, +} + +unet = Unet(**config['model']) + +learning_rate = config['learning_rate'] +solver = optax.adam(learning_rate) +# solver = optax.adamw(2e-6) + +trainer = DiffusionTrainer(unet, optimizer=solver, + input_shapes=config['input_shapes'], + noise_schedule=edm_schedule, + rngs=jax.random.PRNGKey(4), + name=experiment_name, + model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data), + # train_state=trainer.best_state, + # loss_fn=lambda x, y: jnp.abs(x - y), + # param_transforms=params_transform, + # load_from_checkpoint=True, + wandb_config={ + "project": "flaxdiff", + "config": config, + "name": experiment_name, + }, + ) + + +# %% +trainer.summary() + +# %% +data = get_dataset_grain(config['dataset']['name'], batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE) + +# %% +# jax.profiler.start_server(6009) +final_state = trainer.fit(data, 1000, epochs=3) + +# %% +# jax.profiler.start_server(6009) +final_state = trainer.fit(data, 1000, epochs=1) + +# %% +# jax.profiler.start_server(6009) +final_state = trainer.fit(data, 1000, epochs=1) + +# %% +data = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE) +final_state = trainer.fit(data, batches, epochs=4000) + +# %% +from flaxdiff.utils import clip_images + +def clip_images(images, clip_min=-1, clip_max=1): + return jnp.clip(images, clip_min, clip_max) + +class DiffusionSampler(): + model:nn.Module + noise_schedule:NoiseScheduler + params:dict + model_output_transform:DiffusionPredictionTransform + + def __init__(self, model:nn.Module, params:dict, + noise_schedule:NoiseScheduler, + model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(), + guidance_scale:float = 0.0, + null_labels_seq:jax.Array=None + ): + self.model = model + self.noise_schedule = noise_schedule + self.params = params + self.model_output_transform = model_output_transform + self.guidance_scale = guidance_scale + if self.guidance_scale > 0: + # Classifier free guidance + assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance" + print("Using classifier-free guidance") + @jax.jit + def sample_model(x_t, t, *additional_inputs): + # Concatenate unconditional and conditional inputs + x_t_cat = jnp.concatenate([x_t] * 2, axis=0) + t_cat = jnp.concatenate([t] * 2, axis=0) + rates_cat = self.noise_schedule.get_rates(t_cat) + c_in_cat = self.model_output_transform.get_input_scale(rates_cat) + + text_labels_seq, = additional_inputs + text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0) + model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq) + # Split model output into unconditional and conditional parts + model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0) + model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond) + + x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule) + return x_0, eps, model_output + + self.sample_model = sample_model + else: + # Unconditional sampling + @jax.jit + def sample_model(x_t, t, *additional_inputs): + rates = self.noise_schedule.get_rates(t) + c_in = 
+    # Used to sample from the diffusion model
+    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+        step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
+        current_step = step_ones * current_step
+        next_step = step_ones * next_step
+        pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
+        # Clip the predicted x_0 to the valid image range before stepping
+        pred_images = clip_images(pred_images)
+        new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
+                                                 pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+                                                 model_conditioning_inputs=model_conditioning_inputs
+                                                 )
+        return new_samples, state
+
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+        # Estimate q(x_{t-1} | x_t, x_0): reconstructed_samples is x_0, current_samples is x_t
+        raise NotImplementedError
+
+    def scale_steps(self, steps):
+        scale_factor = self.noise_schedule.max_timesteps / 1000
+        return steps * scale_factor
+
+    def get_steps(self, start_step, end_step, diffusion_steps):
+        step_range = start_step - end_step
+        if diffusion_steps is None or diffusion_steps == 0:
+            diffusion_steps = start_step - end_step
+        diffusion_steps = min(diffusion_steps, step_range)
+        steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
+        return steps
+
+    def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step):
+        start_step = self.scale_steps(start_step)
+        alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
+        variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
+        return jax.random.normal(rngs, (num_images, IMAGE_SIZE, IMAGE_SIZE, 3)) * variance
+
+    def generate_images(self,
+                        num_images=16,
+                        diffusion_steps=1000,
+                        start_step:int = None,
+                        end_step:int = 0,
+                        steps_override=None,
+                        priors=None,
+                        rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
+                        model_conditioning_inputs:tuple=()
+                        ) -> jnp.ndarray:
+        # Default the start step before it is used to draw the initial noise
+        if start_step is None:
+            start_step = self.noise_schedule.max_timesteps
+
+        if priors is None:
+            rngstate, newrngs = rngstate.get_random_key()
+            samples = self.get_initial_samples(num_images, newrngs, start_step)
+        else:
+            print("Using priors")
+            samples = priors
+
+        # @jax.jit
+        def sample_step(state:RandomMarkovState, samples, current_step, next_step):
+            samples, state = self.sample_step(current_samples=samples,
+                                              current_step=current_step,
+                                              model_conditioning_inputs=model_conditioning_inputs,
+                                              state=state, next_step=next_step)
+            return samples, state
+
+        if steps_override is not None:
+            steps = steps_override
+        else:
+            steps = self.get_steps(start_step, end_step, diffusion_steps)
+
+        # print("Sampling steps", steps)
+        for i in tqdm.tqdm(range(0, len(steps))):
+            current_step = self.scale_steps(steps[i])
+            next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
+            if i != len(steps) - 1:
+                samples, rngstate 
= sample_step(rngstate, samples, current_step, next_step) + else: + # print("last step") + step_ones = jnp.ones((num_images, ), dtype=jnp.int32) + samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs) + samples = clip_images(samples) + return samples + +class DDPMSampler(DiffusionSampler): + def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: + mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step) + variance = self.noise_schedule.get_posterior_variance(steps=current_step) + + state, rng = state.get_random_key() + # Now sample from the posterior + noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32) + + return mean + noise * variance, state + + def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int = None, *args, **kwargs): + return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs) + +class SimpleDDPMSampler(DiffusionSampler): + def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: + state, rng = state.get_random_key() + noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32) + + # Compute noise rates and signal rates only once + current_signal_rate, current_noise_rate = self.noise_schedule.get_rates(current_step) + next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step) + + pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate) + + noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2) + signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2) + gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared)) + + next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma + return next_samples, state + +class DDIMSampler(DiffusionSampler): + def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: + next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step) + return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state + +class EulerSampler(DiffusionSampler): + # Basically a DDIM Sampler but parameterized as an ODE + def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: + current_alpha, current_sigma = self.noise_schedule.get_rates(current_step) + next_alpha, next_sigma = self.noise_schedule.get_rates(next_step) + + dt = next_sigma - current_sigma + + x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (dt) + dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma + next_samples = current_samples + dx * dt + return next_samples, state + +class SimplifiedEulerSampler(DiffusionSampler): + """ + This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t + """ + def take_next_step(self, current_samples, reconstructed_samples, 
model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: + _, current_sigma = self.noise_schedule.get_rates(current_step) + _, next_sigma = self.noise_schedule.get_rates(next_step) + + dt = next_sigma - current_sigma + + dx = (current_samples - reconstructed_samples) / current_sigma + next_samples = current_samples + dx * dt + return next_samples, state + +class HeunSampler(DiffusionSampler): + def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: + # Get the noise and signal rates for the current and next steps + current_alpha, current_sigma = self.noise_schedule.get_rates(current_step) + next_alpha, next_sigma = self.noise_schedule.get_rates(next_step) + + dt = next_sigma - current_sigma + x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt + + dx_0 = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma + next_samples_0 = current_samples + dx_0 * dt + + # Recompute x_0 and eps at the first estimate to refine the derivative + estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs) + + # Estimate the refined derivative using the midpoint (Heun's method) + dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma + # Compute the final next samples by averaging the initial and refined derivatives + final_next_samples = current_samples + 0.5 * (dx_0 + dx_1) * dt + + return final_next_samples, state + +class RK4Sampler(DiffusionSampler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler" + @jax.jit + def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]: + t = self.noise_schedule.get_timesteps(sigma) + x_0, eps, _ = self.sample_model(x_t, t, *model_conditioning_inputs) + return eps, state + + self.get_derivative = get_derivative + + def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]: + step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32) + current_step = step_ones * current_step + next_step = step_ones * next_step + _, current_sigma = self.noise_schedule.get_rates(current_step) + _, next_sigma = self.noise_schedule.get_rates(next_step) + + dt = next_sigma - current_sigma + + k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs) + k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs) + k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs) + k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs) + + next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6) + return next_samples, state + +class MultiStepDPM(DiffusionSampler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.history = [] + + def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, + pred_noise, current_step, state:RandomMarkovState, next_step=1) -> 
tuple[jnp.ndarray, RandomMarkovState]:
+        # Get the noise and signal rates for the current and next steps
+        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
+        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
+
+        dt = next_sigma - current_sigma
+
+        def first_order(current_noise, current_sigma):
+            dx = current_noise
+            return dx
+
+        def second_order(current_noise, current_sigma, last_noise, last_sigma):
+            dx_2 = (current_noise - last_noise) / (current_sigma - last_sigma)
+            return dx_2
+
+        def third_order(current_noise, current_sigma, last_noise, last_sigma, second_last_noise, second_last_sigma):
+            dx_2 = second_order(current_noise, current_sigma, last_noise, last_sigma)
+            dx_2_last = second_order(last_noise, last_sigma, second_last_noise, second_last_sigma)
+
+            dx_3 = (dx_2 - dx_2_last) / (0.5 * ((current_sigma + last_sigma) - (last_sigma + second_last_sigma)))
+
+            return dx_3
+
+        if len(self.history) == 0:
+            # First order only
+            dx_1 = first_order(pred_noise, current_sigma)
+            next_samples = current_samples + dx_1 * dt
+        elif len(self.history) == 1:
+            # First + Second order
+            dx_1 = first_order(pred_noise, current_sigma)
+            last_step = self.history[-1]
+            dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
+            next_samples = current_samples + dx_1 * dt + 0.5 * dx_2 * dt**2
+        else:
+            # First + Second + Third order
+            last_step = self.history[-1]
+            second_last_step = self.history[-2]
+
+            dx_1 = first_order(pred_noise, current_sigma)
+            dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'])
+            dx_3 = third_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'], second_last_step['eps'], second_last_step['sigma'])
+            next_samples = current_samples + (dx_1 * dt) + (0.5 * dx_2 * dt**2) + ((1/6) * dx_3 * dt**3)
+
+        # NOTE: the history persists across generate_images calls; create a fresh
+        # MultiStepDPM per generation to avoid mixing derivative estimates
+        self.history.append({
+            "eps": pred_noise,
+            "sigma": current_sigma,
+        })
+        return next_samples, state
+
+class EulerAncestralSampler(DiffusionSampler):
+    def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
+                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+        current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
+        next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
+
+        sigma_up = (next_sigma**2 * (current_sigma**2 - next_sigma**2) / current_sigma**2) ** 0.5
+        sigma_down = (next_sigma**2 - sigma_up**2) ** 0.5
+
+        dt = sigma_down - current_sigma
+
+        x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (next_sigma - current_sigma)
+        dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma
+
+        state, subkey = state.get_random_key()
+        dW = jax.random.normal(subkey, current_samples.shape) * sigma_up
+
+        next_samples = current_samples + dx * dt + dW
+        return next_samples, state
+
+# %%
+batch = next(iter(data['train']()))
+images = (batch['image'] - 127.5) / 127.5  # normalize to [-1, 1] as in training
+plotImages(images, dpi=300)
+print(images.shape)
+noise_schedule = karas_ve_schedule
+predictor = trainer.model_output_transform
+
+rng = jax.random.PRNGKey(4)
+noise = jax.random.normal(rng, shape=images.shape, dtype=jnp.float32)
+noise_level = 0.9999
+noise_levels = jnp.ones((images.shape[0],)) * noise_level
+
+rates = noise_schedule.get_rates(noise_levels)
+noisy_images, c_in, expected_output = predictor.forward_diffusion(images, noise, rates=rates)
+plotImages(noisy_images)
+print(jnp.mean(noisy_images), jnp.std(noisy_images))
+regenerated_images = 
noise_schedule.remove_all_noise(noisy_images, noise, noise_levels) +plotImages(regenerated_images) + +sampler = EulerSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=16, diffusion_steps=20, start_step=int(noise_level*1000), end_step=0, priors=None) +plotImages(samples, dpi=300) + +# %% +textEncoderModel, textTokenizer = defaultTextEncodeModel() + +# %% +prompts = [ + 'water tulip', + 'a water lily', + 'a water lily', + 'a photo of a rose' + ] +pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) + +sampler = EulerAncestralSampler(trainer.model, trainer.get_state().ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) +samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) +plotImages(samples, dpi=300) + + +# %% +prompts = [ + 'water tulip', + 'a water lily', + 'a water lily', + 'a photo of a rose' + ] +pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) + +sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) +samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) +plotImages(samples, dpi=300) + + +# %% +prompts = [ + 'water tulip', + 'a water lily', + 'a water lily', + 'a photo of a rose' + ] +pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) + +sampler = EulerAncestralSampler(trainer.model, trainer.best_state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) +samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) +plotImages(samples, dpi=300) + + +# %% +prompts = [ + 'water tulip', + 'a water lily', + 'a water lily', + 'a water lily', + 'a photo of a marigold', + 'a water lily', + 'a water lily', + 'a photo of a lotus', + 'a photo of a lotus', + 'a photo of a lotus', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + ] +pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) + +sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) +samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) +plotImages(samples, dpi=300) + + +# %% +prompts = [ + 'water tulip', + 'a water lily', + 'a water lily', + 'a water lily', + 'a photo of a marigold', + 'a water lily', + 'a water lily', + 'a photo of a lotus', + 'a photo of a lotus', + 'a photo of a lotus', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + ] +pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) + +sampler = 
EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full'])
+samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,))
+plotImages(samples, dpi=300)
+
+# %%
+prompts = [
+    'water tulip',
+    'a water lily',
+    'a water lily',
+    'a water lily',
+    'a photo of a marigold',
+    'a water lily',
+    'a water lily',
+    'a photo of a lotus',
+    'a photo of a lotus',
+    'a photo of a lotus',
+    'a photo of a rose',
+    'a photo of a rose',
+    'a photo of a rose',
+    'a photo of a rose',
+    'a photo of a rose',
+    ]
+pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer)
+
+sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=4, null_labels_seq=data['null_labels_full'])
+samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,))
+plotImages(samples, dpi=500, fig_size=(4, 5))
+
+# %%
+prompts = [
+    'water tulip',
+    'a green water rose',
+    'a green water rose',
+    'a green water rose',
+    'a water red rose',
+    'a marigold and rose hybrid',
+    'a marigold and rose hybrid',
+    'a marigold and rose hybrid',
+    'a water lily and a marigold',
+    'a water lily and a marigold',
+    'a water lily and a marigold',
+    'a water lily and a marigold',
+    ]
+pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer)
+
+sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=3, null_labels_seq=data['null_labels_full'])
+samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,))
+plotImages(samples, dpi=300)
+
+# %%
+prompts = [
+    'water tulip',
+    'a water lily',
+    'a water lily',
+    'a photo of a rose',
+    'a photo of a rose',
+    'a water lily',
+    'a water lily',
+    'a photo of a marigold',
+    'a photo of a marigold',
+    'a photo of a marigold',
+    'a water lily',
+    'a photo of a sunflower',
+    'a photo of a lotus',
+    "columbine",
+    "columbine",
+    "an orchid",
+    "an orchid",
+    "an orchid",
+    'a water lily',
+    'a water lily',
+    'a water lily',
+    "columbine",
+    "columbine",
+    'a photo of a sunflower',
+    'a photo of a sunflower',
+    'a photo of a sunflower',
+    'a photo of a lotus',
+    'a photo of a lotus',
+    'a photo of a 
marigold', + 'a photo of a marigold', + 'a photo of a rose', + 'a photo of a rose', + 'a photo of a rose', + "orange dahlia", + "orange dahlia", + "a lenten rose", + "a lenten rose", + 'a water lily', + 'a water lily', + 'a water lily', + 'a water lily', + "an orchid", + "an orchid", + "an orchid", + 'hard-leaved pocket orchid', + "bird of paradise", + "bird of paradise", + "a photo of a lovely rose", + "a photo of a lovely rose", + "a photo of a globe-flower", + "a photo of a globe-flower", + "a photo of a lovely rose", + "a photo of a lovely rose", + "a photo of a ruby-lipped cattleya", + "a photo of a ruby-lipped cattleya", + "a photo of a lovely rose", + 'a water lily', + 'a osteospermum', + 'a osteospermum', + 'a water lily', + 'a water lily', + 'a water lily', + "a red rose", + "a red rose", + ] +pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) + +sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=4, null_labels_seq=data['null_labels_full']) +samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) +plotImages(samples, dpi=300) + +# %% +dataToLabelGenMap["oxford_flowers102"]() + +# %% + + +# %% +sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=64, diffusion_steps=200, start_step=1000, end_step=0, priors=None) +plotImages(samples, dpi=300) + +# %% +sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=64, diffusion_steps=200, start_step=1000, end_step=0, priors=None) +plotImages(samples, dpi=300) + +# %% +sampler = RK4Sampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=64, diffusion_steps=6, start_step=1000, end_step=0, priors=None) +plotImages(samples, dpi=300) + +# %% +sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=64, diffusion_steps=300, start_step=1000, end_step=0, priors=None) +plotImages(samples, dpi=300) + +# %% +sampler = DDPMSampler(trainer.model, trainer.state.params, trainer.noise_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=16, start_step=1000, priors=None) +plotImages(samples, dpi=300) + +# %% +sampler = DDPMSampler(trainer.model, trainer.best_state.params, trainer.noise_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=16, start_step=1000, priors=None) +plotImages(samples, dpi=300) + +# %% +sampler = DDPMSampler(trainer.model, trainer.best_state.params, trainer.noise_schedule, model_output_transform=trainer.model_output_transform) +samples = sampler.generate_images(num_images=64, start_step=1000, priors=None) +plotImages(samples) + +# %% [markdown] +# + +
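A closing sanity check on the `EulerAncestralSampler` used throughout the cells above: its ancestral split satisfies `sigma_down**2 + sigma_up**2 == next_sigma**2`, so the deterministic step plus the injected `dW` noise lands exactly on the scheduled noise level. With illustrative sigmas:

```python
import jax.numpy as jnp

current_sigma = jnp.array(10.0)
next_sigma = jnp.array(7.0)

# same split as EulerAncestralSampler.take_next_step
sigma_up = (next_sigma**2 * (current_sigma**2 - next_sigma**2) / current_sigma**2) ** 0.5
sigma_down = (next_sigma**2 - sigma_up**2) ** 0.5

print(sigma_up, sigma_down)
print(jnp.sqrt(sigma_down**2 + sigma_up**2))  # 7.0 == next_sigma
```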