Skip to content

Commit

Permalink
feat: training script completed
Browse files Browse the repository at this point in the history
  • Loading branch information
AshishKumar4 committed Aug 1, 2024
1 parent 659f604 commit 3168c40
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 1,234 deletions.
133 changes: 105 additions & 28 deletions Diffusion flax linen on TPUs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@
" data_source=data_source,\n",
" sampler=sampler,\n",
" operations=transformations,\n",
" worker_count=120,\n",
" worker_count=32,\n",
" read_options=pygrain.ReadOptions(64, 50),\n",
" worker_buffer_size=20\n",
" )\n",
Expand Down Expand Up @@ -1898,20 +1898,20 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 111,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34\n"
"Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_16:53:30\n"
]
},
{
"data": {
"text/html": [
"Finishing last run (ID:6rmw3wuq) before initializing another..."
"Finishing last run (ID:2wlp36d3) before initializing another..."
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -1928,7 +1928,7 @@
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
" </style>\n",
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train/loss</td><td>█▄▄▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train/loss</td><td>0.07254</td></tr></table><br/></div></div>"
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>train/loss</td><td>█▅▆▅▅▅▂▃▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▆▅▆▅▆▃▅▅</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>train/loss</td><td>0.36975</td></tr></table><br/></div></div>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -1940,7 +1940,7 @@
{
"data": {
"text/html": [
" View run <strong style=\"color:#cdcd00\">Diffusion_SDE_VE_TEXT_2024-08-01_15:17:19</strong> at: <a href='https://wandb.ai/ashishkumar4/flaxdiff/runs/6rmw3wuq' target=\"_blank\">https://wandb.ai/ashishkumar4/flaxdiff/runs/6rmw3wuq</a><br/> View project at: <a href='https://wandb.ai/ashishkumar4/flaxdiff' target=\"_blank\">https://wandb.ai/ashishkumar4/flaxdiff</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
" View run <strong style=\"color:#cdcd00\">Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34</strong> at: <a href='https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3' target=\"_blank\">https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3</a><br/> View project at: <a href='https://wandb.ai/ashishkumar4/flaxdiff' target=\"_blank\">https://wandb.ai/ashishkumar4/flaxdiff</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -1952,7 +1952,7 @@
{
"data": {
"text/html": [
"Find logs at: <code>./wandb/run-20240801_151719-6rmw3wuq/logs</code>"
"Find logs at: <code>./wandb/run-20240801_154234-2wlp36d3/logs</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -1976,7 +1976,7 @@
{
"data": {
"text/html": [
"Successfully finished last run (ID:6rmw3wuq). Initializing new run:<br/>"
"Successfully finished last run (ID:2wlp36d3). Initializing new run:<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -2000,7 +2000,7 @@
{
"data": {
"text/html": [
"Run data is saved locally in <code>/home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_154234-2wlp36d3</code>"
"Run data is saved locally in <code>/home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_165330-auhv65gw</code>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -2012,7 +2012,7 @@
{
"data": {
"text/html": [
"Syncing run <strong><a href='https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3' target=\"_blank\">Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34</a></strong> to <a href='https://wandb.ai/ashishkumar4/flaxdiff' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
"Syncing run <strong><a href='https://wandb.ai/ashishkumar4/flaxdiff/runs/auhv65gw' target=\"_blank\">Diffusion_SDE_VE_TEXT_2024-08-01_16:53:30</a></strong> to <a href='https://wandb.ai/ashishkumar4/flaxdiff' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -2036,7 +2036,7 @@
{
"data": {
"text/html": [
" View run at <a href='https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3' target=\"_blank\">https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3</a>"
" View run at <a href='https://wandb.ai/ashishkumar4/flaxdiff/runs/auhv65gw' target=\"_blank\">https://wandb.ai/ashishkumar4/flaxdiff/runs/auhv65gw</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
Expand All @@ -2054,7 +2054,7 @@
}
],
"source": [
"BATCH_SIZE = 32\n",
"BATCH_SIZE = 64\n",
"IMAGE_SIZE = 128\n",
"\n",
"cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)\n",
Expand Down Expand Up @@ -2142,7 +2142,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -2151,12 +2151,101 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 113,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Epoch 1/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t\tEpoch 1: 100%|█████████████████████████████████| 1000/1000 [07:26<00:00, 2.24step/s, loss=0.0959]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\tEpoch done\n",
"Saving model at epoch 1\n",
"\n",
"\tEpoch 1 completed. Avg Loss: 0.21474403142929077, Time: 446.66s, Best Loss: 0.21474403142929077 \n",
"\n",
"Epoch 2/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t\tEpoch 2: 100%|█████████████████████████████████| 1000/1000 [05:37<00:00, 2.96step/s, loss=0.0794]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\tEpoch done\n",
"Saving model at epoch 2\n",
"\n",
"\tEpoch 2 completed. Avg Loss: 0.08710786700248718, Time: 337.44s, Best Loss: 0.08710786700248718 \n",
"\n",
"Epoch 3/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t\tEpoch 3: 100%|█████████████████████████████████| 1000/1000 [05:38<00:00, 2.95step/s, loss=0.0765]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\tEpoch done\n",
"Saving model at epoch 3\n",
"\n",
"\tEpoch 3 completed. Avg Loss: 0.0747244730591774, Time: 338.51s, Best Loss: 0.0747244730591774 \n",
"\n",
"Epoch 4/3\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\t\tEpoch 4: 100%|█████████████████████████████████| 1000/1000 [05:37<00:00, 2.96step/s, loss=0.0638]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\tEpoch done\n",
"Saving model at epoch 4\n",
"\n",
"\tEpoch 4 completed. Avg Loss: 0.06880037486553192, Time: 337.86s, Best Loss: 0.06880037486553192 \n",
"Saving model at epoch 3\n",
"Error saving checkpoint Checkpoint for step 3 already exists.\n"
]
}
],
"source": [
"# jax.profiler.start_server(6009)\n",
"final_state = trainer.fit(data, 1000, epochs=1)"
"final_state = trainer.fit(data, 1000, epochs=3)"
]
},
{
Expand Down Expand Up @@ -7061,18 +7150,6 @@
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
Expand Down
2 changes: 2 additions & 0 deletions setup_tpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ pip install jax[tpu] flax[all] -f https://storage.googleapis.com/jax-releases/li
# Install CPU version of tensorflow
pip install tensorflow[cpu] keras orbax optax clu grain augmax transformers opencv-python pandas tensorflow-datasets jupyterlab python-dotenv scikit-learn termcolor wrapt wandb

pip install flaxdiff

wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb
sudo dpkg -i knot-resolver-release.deb
sudo apt update
Expand Down
Loading

0 comments on commit 3168c40

Please sign in to comment.