From 3168c4027d3a0131c73630bc7eceb194d2708473 Mon Sep 17 00:00:00 2001 From: Ashish Kumar Singh Date: Thu, 1 Aug 2024 20:43:30 +0000 Subject: [PATCH] feat: training script completed --- Diffusion flax linen on TPUs.ipynb | 133 ++- setup_tpu.sh | 2 + training_tpu.py | 1332 +++------------------------- 3 files changed, 233 insertions(+), 1234 deletions(-) diff --git a/Diffusion flax linen on TPUs.ipynb b/Diffusion flax linen on TPUs.ipynb index a13bdfd..3658a40 100644 --- a/Diffusion flax linen on TPUs.ipynb +++ b/Diffusion flax linen on TPUs.ipynb @@ -488,7 +488,7 @@ " data_source=data_source,\n", " sampler=sampler,\n", " operations=transformations,\n", - " worker_count=120,\n", + " worker_count=32,\n", " read_options=pygrain.ReadOptions(64, 50),\n", " worker_buffer_size=20\n", " )\n", @@ -1898,20 +1898,20 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 111, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34\n" + "Experiment_Name: Diffusion_SDE_VE_TEXT_2024-08-01_16:53:30\n" ] }, { "data": { "text/html": [ - "Finishing last run (ID:6rmw3wuq) before initializing another..." + "Finishing last run (ID:2wlp36d3) before initializing another..." ], "text/plain": [ "" @@ -1928,7 +1928,7 @@ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " \n", - "

Run history: train/loss █▄▄▄▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁
Run summary: train/loss 0.07254
" + "

Run history:


train/loss█▅▆▅▅▅▂▃▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▆▅▆▅▆▃▅▅

Run summary:


train/loss0.36975

" ], "text/plain": [ "" @@ -1940,7 +1940,7 @@ { "data": { "text/html": [ - " View run Diffusion_SDE_VE_TEXT_2024-08-01_15:17:19 at: https://wandb.ai/ashishkumar4/flaxdiff/runs/6rmw3wuq
View project at: https://wandb.ai/ashishkumar4/flaxdiff
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + " View run Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34 at: https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3
View project at: https://wandb.ai/ashishkumar4/flaxdiff
Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)" ], "text/plain": [ "" @@ -1952,7 +1952,7 @@ { "data": { "text/html": [ - "Find logs at: ./wandb/run-20240801_151719-6rmw3wuq/logs" + "Find logs at: ./wandb/run-20240801_154234-2wlp36d3/logs" ], "text/plain": [ "" @@ -1976,7 +1976,7 @@ { "data": { "text/html": [ - "Successfully finished last run (ID:6rmw3wuq). Initializing new run:
" + "Successfully finished last run (ID:2wlp36d3). Initializing new run:
" ], "text/plain": [ "" @@ -2000,7 +2000,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_154234-2wlp36d3" + "Run data is saved locally in /home/mrwhite0racle/research/FlaxDiff/wandb/run-20240801_165330-auhv65gw" ], "text/plain": [ "" @@ -2012,7 +2012,7 @@ { "data": { "text/html": [ - "Syncing run Diffusion_SDE_VE_TEXT_2024-08-01_15:42:34 to Weights & Biases (docs)
" + "Syncing run Diffusion_SDE_VE_TEXT_2024-08-01_16:53:30 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -2036,7 +2036,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/ashishkumar4/flaxdiff/runs/2wlp36d3" + " View run at https://wandb.ai/ashishkumar4/flaxdiff/runs/auhv65gw" ], "text/plain": [ "" @@ -2054,7 +2054,7 @@ } ], "source": [ - "BATCH_SIZE = 32\n", + "BATCH_SIZE = 64\n", "IMAGE_SIZE = 128\n", "\n", "cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)\n", @@ -2142,7 +2142,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 112, "metadata": {}, "outputs": [], "source": [ @@ -2151,12 +2151,101 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 113, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch 1/3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 1: 100%|█████████████████████████████████| 1000/1000 [07:26<00:00, 2.24step/s, loss=0.0959]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 1\n", + "\n", + "\tEpoch 1 completed. Avg Loss: 0.21474403142929077, Time: 446.66s, Best Loss: 0.21474403142929077 \n", + "\n", + "Epoch 2/3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 2: 100%|█████████████████████████████████| 1000/1000 [05:37<00:00, 2.96step/s, loss=0.0794]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 2\n", + "\n", + "\tEpoch 2 completed. Avg Loss: 0.08710786700248718, Time: 337.44s, Best Loss: 0.08710786700248718 \n", + "\n", + "Epoch 3/3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 3: 100%|█████████████████████████████████| 1000/1000 [05:38<00:00, 2.95step/s, loss=0.0765]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 3\n", + "\n", + "\tEpoch 3 completed. Avg Loss: 0.0747244730591774, Time: 338.51s, Best Loss: 0.0747244730591774 \n", + "\n", + "Epoch 4/3\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\t\tEpoch 4: 100%|█████████████████████████████████| 1000/1000 [05:37<00:00, 2.96step/s, loss=0.0638]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\tEpoch done\n", + "Saving model at epoch 4\n", + "\n", + "\tEpoch 4 completed. 
Avg Loss: 0.06880037486553192, Time: 337.86s, Best Loss: 0.06880037486553192 \n", + "Saving model at epoch 3\n", + "Error saving checkpoint Checkpoint for step 3 already exists.\n" + ] + } + ], "source": [ "# jax.profiler.start_server(6009)\n", - "final_state = trainer.fit(data, 1000, epochs=1)" + "final_state = trainer.fit(data, 1000, epochs=3)" ] }, { @@ -7061,18 +7150,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" } }, "nbformat": 4, diff --git a/setup_tpu.sh b/setup_tpu.sh index 9b8734c..1b28949 100755 --- a/setup_tpu.sh +++ b/setup_tpu.sh @@ -6,6 +6,8 @@ pip install jax[tpu] flax[all] -f https://storage.googleapis.com/jax-releases/li # Install CPU version of tensorflow pip install tensorflow[cpu] keras orbax optax clu grain augmax transformers opencv-python pandas tensorflow-datasets jupyterlab python-dotenv scikit-learn termcolor wrapt wandb +pip install flaxdiff + wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb sudo dpkg -i knot-resolver-release.deb sudo apt update diff --git a/training_tpu.py b/training_tpu.py index 61d4b0e..1560ca9 100644 --- a/training_tpu.py +++ b/training_tpu.py @@ -1,7 +1,3 @@ - -%load_ext dotenv -%dotenv - import flax import tqdm from flax import linen as nn @@ -9,7 +5,6 @@ from typing import Dict, Callable, Sequence, Any, Union from dataclasses import field import jax.numpy as jnp -import tensorflow_datasets as tfds import grain.python as pygrain import tensorflow as tf import numpy as np @@ -25,23 +20,13 @@ from datetime import datetime from flax.training import orbax_utils import functools -from tensorflow_datasets.core.utils import gcs_utils -gcs_utils._is_gcs_disabled = True + import json # For CLIP from transformers import AutoTokenizer, FlaxCLIPTextModel, CLIPTextModel import wandb -# %% [markdown] -# # Global Variables -##################################################################################################################### -############################################## Globasl Variables #################################################### -##################################################################################################################### - -GRAIN_WORKER_COUNT = 16 -GRAIN_READ_THREAD_COUNT = 64 -GRAIN_READ_BUFFER_SIZE = 50 -GRAIN_WORKER_BUFFER_SIZE = 20 +import argparse # %% [markdown] # # Initialization @@ -145,6 +130,7 @@ def __repr__(self): # %% def data_source_tfds(name): + import tensorflow_datasets as tfds def data_source(): return tfds.load(name, split="all", shuffle_files=True) return data_source @@ -172,11 +158,11 @@ def labelizer_cc12m(sample): datasetMap = { "oxford_flowers102": { "source":data_source_tfds("oxford_flowers102"), - "labelizer":labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"), + "labelizer":lambda : labelizer_oxford_flowers102("/home/mrwhite0racle/tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt"), }, "cc12m": { "source":data_source_cc12m(), - "labelizer":labelizer_cc12m, + "labelizer":lambda : labelizer_cc12m, } } @@ -206,10 +192,12 @@ def get_dataset_grain(data_name="oxford_flowers102", batch_size=64, image_scale=256, count=None, num_epochs=None, text_encoders=defaultTextEncodeModel(), - 
method=jax.image.ResizeMethod.LANCZOS3): + method=jax.image.ResizeMethod.LANCZOS3, + grain_worker_count=32, grain_read_thread_count=64, + grain_read_buffer_size=50, grain_worker_buffer_size=20): dataset = datasetMap[data_name] data_source = dataset["source"]() - labelizer = dataset["labelizer"] + labelizer = dataset["labelizer"]() import cv2 @@ -253,9 +241,9 @@ def map(self, element) -> Dict[str, jnp.array]: data_source=data_source, sampler=sampler, operations=transformations, - worker_count=GRAIN_WORKER_COUNT, - read_options=pygrain.ReadOptions(GRAIN_READ_THREAD_COUNT, GRAIN_READ_BUFFER_SIZE), - worker_buffer_size=GRAIN_WORKER_BUFFER_SIZE + worker_count=grain_worker_count, + read_options=pygrain.ReadOptions(grain_read_thread_count, grain_read_buffer_size), + worker_buffer_size=grain_worker_buffer_size ) def get_trainset(): @@ -275,454 +263,18 @@ def get_trainset(): from flaxdiff.schedulers import CosineNoiseSchedule, NoiseScheduler, GeneralizedNoiseScheduler, KarrasVENoiseScheduler, EDMNoiseScheduler from flaxdiff.predictors import VPredictionTransform, EpsilonPredictionTransform, DiffusionPredictionTransform, DirectPredictionTransform, KarrasPredictionTransform -# %% [markdown] -# # Modeling - -# %% [markdown] -# ## Metrics - -# %% [markdown] -# ## Callbacks - -# %% [markdown] -# ## Model Generator - # %% import jax.experimental.pallas.ops.tpu.flash_attention -from flaxdiff.models.simple_unet import l2norm, ConvLayer, TimeEmbedding, TimeProjection, Upsample, Downsample, ResidualBlock, PixelShuffle +from flaxdiff.models.simple_unet import ConvLayer, TimeProjection, Upsample, Downsample, ResidualBlock from flaxdiff.models.simple_unet import FourierEmbedding -from flaxdiff.models.attention import kernel_init -# from flash_attn_jax import flash_mha -# from flaxdiff.models.favor_fastattn import make_fast_generalized_attention, make_fast_softmax_attention +from flaxdiff.models.attention import kernel_init, TransformerBlock # Kernel initializer to use def kernel_init(scale, dtype=jnp.float32): scale = max(scale, 1e-10) return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype) -class EfficientAttention(nn.Module): - """ - Based on the pallas attention implementation. 
- """ - query_dim: int - heads: int = 4 - dim_head: int = 64 - dtype: Any = jnp.float32 - precision: Any = jax.lax.Precision.HIGHEST - use_bias: bool = True - kernel_init: Callable = lambda : kernel_init(1.0) - - def setup(self): - inner_dim = self.dim_head * self.heads - # Weights were exported with old names {to_q, to_k, to_v, to_out} - dense = functools.partial( - nn.Dense, - self.heads * self.dim_head, - precision=self.precision, - use_bias=self.use_bias, - kernel_init=self.kernel_init(), - dtype=self.dtype - ) - self.query = dense(name="to_q") - self.key = dense(name="to_k") - self.value = dense(name="to_v") - - self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, - kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0") - # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16) - - def _reshape_tensor_to_head_dim(self, tensor): - batch_size, _, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - return tensor - - def _reshape_tensor_from_head_dim(self, tensor): - batch_size, _, seq_len, dim = tensor.shape - head_size = self.heads - tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size) - return tensor - - @nn.compact - def __call__(self, x:jax.Array, context=None): - # print(x.shape) - # x has shape [B, H * W, C] - context = x if context is None else context - - B, H, W, C = x.shape - x = x.reshape((B, 1, H * W, C)) - - B, _H, _W, _C = context.shape - context = context.reshape((B, 1, _H * _W, _C)) - - query = self.query(x) - key = self.key(context) - value = self.value(context) - - query = self._reshape_tensor_to_head_dim(query) - key = self._reshape_tensor_to_head_dim(key) - value = self._reshape_tensor_to_head_dim(value) - - hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention( - query, key, value, None - ) - - hidden_states = self._reshape_tensor_from_head_dim(hidden_states) - - - # hidden_states = nn.dot_product_attention( - # query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision - # ) - - proj = self.proj_attn(hidden_states) - - proj = proj.reshape((B, H, W, C)) - - return proj - - -class NormalAttention(nn.Module): - """ - Simple implementation of the normal attention. 
- """ - query_dim: int - heads: int = 4 - dim_head: int = 64 - dtype: Any = jnp.float32 - precision: Any = jax.lax.Precision.HIGHEST - use_bias: bool = True - kernel_init: Callable = lambda : kernel_init(1.0) - - def setup(self): - inner_dim = self.dim_head * self.heads - dense = functools.partial( - nn.DenseGeneral, - features=[self.heads, self.dim_head], - axis=-1, - precision=self.precision, - use_bias=self.use_bias, - kernel_init=self.kernel_init(), - dtype=self.dtype - ) - self.query = dense(name="to_q") - self.key = dense(name="to_k") - self.value = dense(name="to_v") - - self.proj_attn = nn.DenseGeneral( - self.query_dim, - axis=(-2, -1), - precision=self.precision, - use_bias=self.use_bias, - dtype=self.dtype, - name="to_out_0", - kernel_init=self.kernel_init() - # kernel_init=jax.nn.initializers.xavier_uniform() - ) - - @nn.compact - def __call__(self, x, context=None): - # x has shape [B, H, W, C] - context = x if context is None else context - query = self.query(x) - key = self.key(context) - value = self.value(context) - - hidden_states = nn.dot_product_attention( - query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision - ) - - proj = self.proj_attn(hidden_states) - return proj - -class AttentionBlock(nn.Module): - # Has self and cross attention - query_dim: int - heads: int = 4 - dim_head: int = 64 - dtype: Any = jnp.float32 - precision: Any = jax.lax.Precision.HIGHEST - use_bias: bool = True - kernel_init: Callable = lambda : kernel_init(1.0) - use_flash_attention:bool = False - use_cross_only:bool = False - - def setup(self): - if self.use_flash_attention: - attenBlock = EfficientAttention - else: - attenBlock = NormalAttention - - self.attention1 = attenBlock( - query_dim=self.query_dim, - heads=self.heads, - dim_head=self.dim_head, - name=f'Attention1', - precision=self.precision, - use_bias=self.use_bias, - dtype=self.dtype, - kernel_init=self.kernel_init - ) - self.attention2 = attenBlock( - query_dim=self.query_dim, - heads=self.heads, - dim_head=self.dim_head, - name=f'Attention2', - precision=self.precision, - use_bias=self.use_bias, - dtype=self.dtype, - kernel_init=self.kernel_init - ) - - self.ff = nn.DenseGeneral( - features=self.query_dim, - use_bias=self.use_bias, - precision=self.precision, - dtype=self.dtype, - kernel_init=self.kernel_init(), - name="ff" - ) - self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) - self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) - self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) - self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype) - - @nn.compact - def __call__(self, hidden_states, context=None): - # self attention - residual = hidden_states - hidden_states = self.norm1(hidden_states) - if self.use_cross_only: - hidden_states = self.attention1(hidden_states, context) - else: - hidden_states = self.attention1(hidden_states) - hidden_states = hidden_states + residual - - # cross attention - residual = hidden_states - hidden_states = self.norm2(hidden_states) - hidden_states = self.attention2(hidden_states, context) - hidden_states = hidden_states + residual - - # feed forward - residual = hidden_states - hidden_states = self.norm3(hidden_states) - hidden_states = nn.gelu(hidden_states) - hidden_states = self.ff(hidden_states) - hidden_states = hidden_states + residual - - return hidden_states - -class TransformerBlock(nn.Module): - heads: int = 4 - dim_head: int = 32 - use_linear_attention: bool = True - dtype: Any = jnp.bfloat16 - precision: Any = 
jax.lax.Precision.HIGH - use_projection: bool = False - use_flash_attention:bool = True - use_self_and_cross:bool = False - - @nn.compact - def __call__(self, x, context=None): - inner_dim = self.heads * self.dim_head - B, H, W, C = x.shape - normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x) - if self.use_projection == True: - if self.use_linear_attention: - projected_x = nn.Dense(features=inner_dim, - use_bias=False, precision=self.precision, - kernel_init=kernel_init(1.0), - dtype=self.dtype, name=f'project_in')(normed_x) - else: - projected_x = nn.Conv( - features=inner_dim, kernel_size=(1, 1), - kernel_init=kernel_init(1.0), - strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, - precision=self.precision, name=f'project_in_conv', - )(normed_x) - else: - projected_x = normed_x - inner_dim = C - - context = projected_x if context is None else context - - if self.use_self_and_cross: - projected_x = AttentionBlock( - query_dim=inner_dim, - heads=self.heads, - dim_head=self.dim_head, - name=f'Attention', - precision=self.precision, - use_bias=False, - dtype=self.dtype, - use_flash_attention=self.use_flash_attention, - use_cross_only=False - )(projected_x, context) - elif self.use_flash_attention == True: - projected_x = EfficientAttention( - query_dim=inner_dim, - heads=self.heads, - dim_head=self.dim_head, - name=f'Attention', - precision=self.precision, - use_bias=False, - dtype=self.dtype, - )(projected_x, context) - else: - projected_x = NormalAttention( - query_dim=inner_dim, - heads=self.heads, - dim_head=self.dim_head, - name=f'Attention', - precision=self.precision, - use_bias=False, - )(projected_x, context) - - - if self.use_projection == True: - if self.use_linear_attention: - projected_x = nn.Dense(features=C, precision=self.precision, - dtype=self.dtype, use_bias=False, - kernel_init=kernel_init(1.0), - name=f'project_out')(projected_x) - else: - projected_x = nn.Conv( - features=C, kernel_size=(1, 1), - kernel_init=kernel_init(1.0), - strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, - precision=self.precision, name=f'project_out_conv', - )(projected_x) - - out = x + projected_x - return out - - -# %% [markdown] -# ## Attention and other prototyping - -# %% -x = jnp.ones((16, 1, 16*16, 64)) -batch_size, _, seq_len, dim = x.shape -head_size = 4 -dim_head = dim // head_size -k = nn.Dense(dim_head * head_size, precision=jax.lax.Precision.HIGHEST, use_bias=True) -param = k.init(jax.random.PRNGKey(42), x) -tensor = k.apply(param, x) -print(tensor.shape) -tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) -tensor = jnp.transpose(tensor, (0, 2, 1, 3)) -print(tensor.shape) - - - -# %% -x = jnp.ones((16, 64, 64, 128)) -context = jnp.ones((16, 64, 64, 128)) -attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.bfloat16, use_flash_attention=False, use_projection=False, use_self_and_cross=False) -params = attention_block.init(jax.random.PRNGKey(0), x, context) -@jax.jit -def apply(params, x, context): - return attention_block.apply(params, x, context) - -apply(params, x, context) - -%timeit -n 1 apply(params, x, context) - -# %% -x = jnp.ones((1, 16, 16, 64)) -context = jnp.ones((1, 12, 768)) -# pad the context -context = jnp.pad(context, ((0, 0), (0, 4), (0, 0)), mode='constant', constant_values=0) -print(context.shape) -context = None#jnp.reshape(context, (1, 1, 16, 768)) -attention_block = TransformerBlock(heads=4, dim_head=64//4, dtype=jnp.bfloat16, use_flash_attention=True, use_projection=False, 
use_self_and_cross=False) -params = attention_block.init(jax.random.PRNGKey(0), x, context) -out = attention_block.apply(params, x, context) -print("Output :", out.shape) -print(attention_block.tabulate(jax.random.key(0), x, context, console_kwargs={"width": 200, "force_jupyter":True, })) -print(jnp.mean(out), jnp.std(out)) -# plt.hist(out.flatten(), bins=100) -# %timeit attention_block.apply(params, x) - -# %% [markdown] -# ## Main Model - -# %% -class ResidualBlock(nn.Module): - conv_type:str - features:int - kernel_size:tuple=(3, 3) - strides:tuple=(1, 1) - padding:str="SAME" - activation:Callable=jax.nn.swish - direction:str=None - res:int=2 - norm_groups:int=8 - kernel_init:Callable=kernel_init(1.0) - dtype: Any = jnp.float32 - precision: Any = jax.lax.Precision.HIGHEST - - @nn.compact - def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None): - residual = x - out = nn.GroupNorm(self.norm_groups)(x) - out = self.activation(out) - - out = ConvLayer( - self.conv_type, - features=self.features, - kernel_size=self.kernel_size, - strides=self.strides, - kernel_init=self.kernel_init, - name="conv1", - dtype=self.dtype, - precision=self.precision - )(out) - - temb = nn.DenseGeneral( - features=self.features, - name="temb_projection", - dtype=self.dtype, - precision=self.precision)(temb) - temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1) - # scale, shift = jnp.split(temb, 2, axis=-1) - # out = out * (1 + scale) + shift - out = out + temb - - out = nn.GroupNorm(self.norm_groups)(out) - out = self.activation(out) - - out = ConvLayer( - self.conv_type, - features=self.features, - kernel_size=self.kernel_size, - strides=self.strides, - kernel_init=self.kernel_init, - name="conv2", - dtype=self.dtype, - precision=self.precision - )(out) - - if residual.shape != out.shape: - residual = ConvLayer( - self.conv_type, - features=self.features, - kernel_size=(1, 1), - strides=1, - kernel_init=self.kernel_init, - name="residual_conv", - dtype=self.dtype, - precision=self.precision - )(residual) - out = out + residual - - out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out - - return out - class Unet(nn.Module): emb_features:int=64*4, feature_depths:list=[64, 128, 256, 512], @@ -730,7 +282,7 @@ class Unet(nn.Module): num_res_blocks:int=2, num_middle_res_blocks:int=1, activation:Callable = jax.nn.swish - norm_groups:int=8 + norm_groups:int=32 dtype: Any = jnp.bfloat16 precision: Any = jax.lax.Precision.HIGH @@ -912,7 +464,7 @@ def __call__(self, x, temb, textcontext=None): precision=self.precision )(x, temb) - x = nn.GroupNorm(self.norm_groups)(x) + x = nn.RMSNorm()(x) x = self.activation(x) noise_out = ConvLayer( @@ -927,33 +479,6 @@ def __call__(self, x, temb, textcontext=None): )(x) return noise_out#, attentions -# %% -unet = Unet(emb_features=512, - feature_depths=[128, 256, 512, 1024], - attention_configs=[ - None, - # None, - # {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":True}, - {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":True, "use_self_and_cross":True}, - {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":True, "use_self_and_cross":True}, - {"heads":32, "dtype":jnp.bfloat16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False} - ], - num_res_blocks=4, - num_middle_res_blocks=1 -) - -inp = jnp.ones((1, 128, 128, 3)) -temb = jnp.ones((1,)) -textcontext = jnp.ones((1, 77, 
768)) - -params = unet.init(jax.random.PRNGKey(0), inp, temb, textcontext) - -# %% -unet.tabulate(jax.random.key(0), inp, temb, textcontext, console_kwargs={"width": 200, "force_jupyter":True, }) - -# %% [markdown] -# # Training - # %% import flax.jax_utils import orbax.checkpoint @@ -1216,14 +741,6 @@ def fit(self, data, steps_per_epoch, epochs, train_step_args={}): # Compute Metrics metrics_str = '' - # if test_ds is not None: - # for test_batch in iter(test_ds()): - # state = compute_metrics(state, test_batch) - # metrics = state.metrics.compute() - # for metric,value in metrics.items(): - # summary_writer.scalar(f'Test {metric}', value, step=current_epoch) - # metrics_str += f', Test {metric}: {value:.4f}' - # state = state.replace(metrics=Metrics.empty()) print(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}") @@ -1386,67 +903,123 @@ def fit(self, data, steps_per_epoch, epochs): super().fit(data, steps_per_epoch, epochs, {"batch_size":batch_size, "null_labels_seq":null_labels_full, "text_embedder":text_embedder}) # %% -BATCH_SIZE = 64 -IMAGE_SIZE = 128 +# Parse command-line arguments +parser = argparse.ArgumentParser(description='Train a diffusion model') +parser.add_argument('--GRAIN_WORKER_COUNT', type=int, default=16, help='Number of grain workers') +parser.add_argument('--GRAIN_READ_THREAD_COUNT', type=int, default=64, help='Number of grain read threads') +parser.add_argument('--GRAIN_READ_BUFFER_SIZE', type=int, default=50, help='Grain read buffer size') +parser.add_argument('--GRAIN_WORKER_BUFFER_SIZE', type=int, default=20, help='Grain worker buffer size') + +parser.add_argument('--BATCH_SIZE', type=int, default=64, help='Batch size') +parser.add_argument('--IMAGE_SIZE', type=int, default=128, help='Image size') +parser.add_argument('--epochs', type=int, default=3, help='Number of epochs') +parser.add_argument('--steps_per_epoch', type=int, default=None, help='Steps per epoch') +parser.add_argument('--dataset', type=str, default='cc12m', help='Dataset to use') + +parser.add_argument('--learning_rate', type=float, default=2e-4, help='Learning rate') +parser.add_argument('--noise_schedule', type=str, default='edm', choices=['cosine', 'karras', 'edm'], help='Noise schedule') + +parser.add_argument('--emb_features', type=int, default=256, help='Embedding features') +parser.add_argument('--feature_depths', type=int, nargs='+', default=[64, 128, 256, 512], help='Feature depths') +parser.add_argument('--attention_heads', type=int, default=8, help='Number of attention heads') +parser.add_argument('--flash_attention', type=bool, default=False, help='Use Flash Attention') +parser.add_argument('--use_projection', type=bool, default=False, help='Use projection') +parser.add_argument('--use_self_and_cross', type=bool, default=False, help='Use self and cross attention') +parser.add_argument('--num_res_blocks', type=int, default=2, help='Number of residual blocks') +parser.add_argument('--num_middle_res_blocks', type=int, default=1, help='Number of middle residual blocks') +parser.add_argument('--activation', type=str, default='swish', help='activation to use') + +parser.add_argument('--dtype', type=str, default='bfloat16', help='dtype to use') +parser.add_argument('--precision', type=str, default='high', help='precision to use') + + +args = parser.parse_args() + +DTYPE_MAP = { + 'bfloat16': jnp.bfloat16, + 'float32': jnp.float32 +} -cosine_schedule = CosineNoiseSchedule(1000, beta_end=1) -karas_ve_schedule = 
KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
-edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
+PRECISION_MAP = {
+    'high': jax.lax.Precision.HIGH,
+    'default': jax.lax.Precision.DEFAULT,
+    'highest': jax.lax.Precision.HIGHEST
}

-experiment_name = "{name}_{date}".format(
-    name="Diffusion_SDE_VE_TEXT", date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
-)
-# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-16_02:16:07'
-# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-21_02:12:40'
-# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-07-30_05:48:22'
-# experiment_name = 'Diffusion_SDE_VE_TEXT_2024-08-01_08:59:00'
-print("Experiment_Name:", experiment_name)
+ACTIVATION_MAP = {
+    'swish': jax.nn.swish,
+    'mish': jax.nn.mish,
+}
+
+DTYPE = DTYPE_MAP[args.dtype]
+PRECISION = PRECISION_MAP[args.precision]
+
+GRAIN_WORKER_COUNT = args.GRAIN_WORKER_COUNT
+GRAIN_READ_THREAD_COUNT = args.GRAIN_READ_THREAD_COUNT
+GRAIN_READ_BUFFER_SIZE = args.GRAIN_READ_BUFFER_SIZE
+GRAIN_WORKER_BUFFER_SIZE = args.GRAIN_WORKER_BUFFER_SIZE
+
+BATCH_SIZE = args.BATCH_SIZE
+IMAGE_SIZE = args.IMAGE_SIZE

-dataset_name = "cc12m"
-datalen = len(datasetMap[dataset_name]['source'])
+dataset_name = args.dataset
+datalen = len(datasetMap[dataset_name]['source']())
 batches = datalen // BATCH_SIZE

-config = {
-    "model" : {
-        "emb_features":256,
-        "feature_depths":[64, 128, 256, 512],
-        "attention_configs":[
-            None,
-            # None,
-            # None,
-            # None,
-            # None,
-            # {"heads":32, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":True},
-            {"heads":8, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":False},
-            {"heads":8, "dtype":jnp.bfloat16, "flash_attention":True, "use_projection":False, "use_self_and_cross":False},
-            {"heads":8, "dtype":jnp.bfloat16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False},
-        ],
-        "num_res_blocks":2,
-        "num_middle_res_blocks":1,
+# Define the configuration using the command-line arguments
+attention_configs = [
+    None,
+]
+
+attention_configs += [
+    {"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": args.flash_attention, "use_projection": args.use_projection, "use_self_and_cross": args.use_self_and_cross},
+] * (len(args.feature_depths) - 2)
+
+attention_configs += [
+    {"heads": args.attention_heads, "dtype": DTYPE, "flash_attention": False, "use_projection": False, "use_self_and_cross": False},
+]
+
+CONFIG = {
+    "model": {
+        "emb_features": args.emb_features,
+        "feature_depths": args.feature_depths,
+        "attention_configs": attention_configs,
+        "num_res_blocks": args.num_res_blocks,
+        "num_middle_res_blocks": args.num_middle_res_blocks,
+        "dtype": DTYPE,
+        "precision": PRECISION,
+        "activation": ACTIVATION_MAP[args.activation],
     },
-    "dataset": {
-        "name" : dataset_name,
-        "length" : datalen,
-        "batches": batches
+    "dataset": {
+        "name": dataset_name,
+        "length": datalen,
+        "batches": datalen // BATCH_SIZE,
     },
-    "learning_rate": 2e-4,
-
+    "learning_rate": args.learning_rate,
     "input_shapes": {
-        "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+        "x": (args.IMAGE_SIZE, args.IMAGE_SIZE, 3),
         "temb": (),
         "textcontext": (77, 768)
     },
 }

-unet = Unet(**config['model'])
+cosine_schedule = CosineNoiseSchedule(1000, beta_end=1)
+karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
+edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
+
+experiment_name = "{name}_{date}".format(
+    name="Diffusion_SDE_VE_TEXT", 
date=datetime.now().strftime("%Y-%m-%d_%H:%M:%S") +) +print("Experiment_Name:", experiment_name) -learning_rate = config['learning_rate'] +unet = Unet(**CONFIG['model']) + +learning_rate = CONFIG['learning_rate'] solver = optax.adam(learning_rate) # solver = optax.adamw(2e-6) trainer = DiffusionTrainer(unet, optimizer=solver, - input_shapes=config['input_shapes'], + input_shapes=CONFIG['input_shapes'], noise_schedule=edm_schedule, rngs=jax.random.PRNGKey(4), name=experiment_name, @@ -1457,676 +1030,23 @@ def fit(self, data, steps_per_epoch, epochs): # load_from_checkpoint=True, wandb_config={ "project": "flaxdiff", - "config": config, + "config": CONFIG, "name": experiment_name, }, ) # %% -trainer.summary() - -# %% -data = get_dataset_grain(config['dataset']['name'], batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE) - -# %% -# jax.profiler.start_server(6009) -final_state = trainer.fit(data, 1000, epochs=3) - -# %% -# jax.profiler.start_server(6009) -final_state = trainer.fit(data, 1000, epochs=1) - -# %% -# jax.profiler.start_server(6009) -final_state = trainer.fit(data, 1000, epochs=1) - -# %% -data = get_dataset("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE) -final_state = trainer.fit(data, batches, epochs=4000) +print(trainer.summary()) # %% -from flaxdiff.utils import clip_images - -def clip_images(images, clip_min=-1, clip_max=1): - return jnp.clip(images, clip_min, clip_max) - -class DiffusionSampler(): - model:nn.Module - noise_schedule:NoiseScheduler - params:dict - model_output_transform:DiffusionPredictionTransform - - def __init__(self, model:nn.Module, params:dict, - noise_schedule:NoiseScheduler, - model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(), - guidance_scale:float = 0.0, - null_labels_seq:jax.Array=None - ): - self.model = model - self.noise_schedule = noise_schedule - self.params = params - self.model_output_transform = model_output_transform - self.guidance_scale = guidance_scale - if self.guidance_scale > 0: - # Classifier free guidance - assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance" - print("Using classifier-free guidance") - @jax.jit - def sample_model(x_t, t, *additional_inputs): - # Concatenate unconditional and conditional inputs - x_t_cat = jnp.concatenate([x_t] * 2, axis=0) - t_cat = jnp.concatenate([t] * 2, axis=0) - rates_cat = self.noise_schedule.get_rates(t_cat) - c_in_cat = self.model_output_transform.get_input_scale(rates_cat) - - text_labels_seq, = additional_inputs - text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0) - model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq) - # Split model output into unconditional and conditional parts - model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0) - model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond) - - x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule) - return x_0, eps, model_output - - self.sample_model = sample_model - else: - # Unconditional sampling - @jax.jit - def sample_model(x_t, t, *additional_inputs): - rates = self.noise_schedule.get_rates(t) - c_in = self.model_output_transform.get_input_scale(rates) - model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs) - x_0, eps = 
self.model_output_transform(x_t, model_output, t, self.noise_schedule) - return x_0, eps, model_output - - self.sample_model = sample_model - - # Used to sample from the diffusion model - def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]: - # First clip the noisy images - step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32) - current_step = step_ones * current_step - next_step = step_ones * next_step - pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs) - # plotImages(pred_images) - pred_images = clip_images(pred_images) - new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images, - pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state, - model_conditioning_inputs=model_conditioning_inputs - ) - return new_samples, state - - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - # estimate the q(x_{t-1} | x_t, x_0). - # pred_images is x_0, noisy_images is x_t, steps is t - return NotImplementedError - - def scale_steps(self, steps): - scale_factor = self.noise_schedule.max_timesteps / 1000 - return steps * scale_factor - - def get_steps(self, start_step, end_step, diffusion_steps): - step_range = start_step - end_step - if diffusion_steps is None or diffusion_steps == 0: - diffusion_steps = start_step - end_step - diffusion_steps = min(diffusion_steps, step_range) - steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1] - return steps - - def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step): - start_step = self.scale_steps(start_step) - alpha_n, sigma_n = self.noise_schedule.get_rates(start_step) - variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2) - return jax.random.normal(rngs, (num_images, IMAGE_SIZE, IMAGE_SIZE, 3)) * variance - - def generate_images(self, - num_images=16, - diffusion_steps=1000, - start_step:int = None, - end_step:int = 0, - steps_override=None, - priors=None, - rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)), - model_conditioning_inputs:tuple=() - ) -> jnp.ndarray: - if priors is None: - rngstate, newrngs = rngstate.get_random_key() - samples = self.get_initial_samples(num_images, newrngs, start_step) - else: - print("Using priors") - samples = priors - - # @jax.jit - def sample_step(state:RandomMarkovState, samples, current_step, next_step): - samples, state = self.sample_step(current_samples=samples, - current_step=current_step, - model_conditioning_inputs=model_conditioning_inputs, - state=state, next_step=next_step) - return samples, state - - if start_step is None: - start_step = self.noise_schedule.max_timesteps - - if steps_override is not None: - steps = steps_override - else: - steps = self.get_steps(start_step, end_step, diffusion_steps) - - # print("Sampling steps", steps) - for i in tqdm.tqdm(range(0, len(steps))): - current_step = self.scale_steps(steps[i]) - next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0) - if i != len(steps) - 1: - # print("normal step") - samples, rngstate = sample_step(rngstate, samples, current_step, next_step) - else: - # print("last step") - step_ones = jnp.ones((num_images, ), dtype=jnp.int32) - samples, _, _ = 
self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs) - samples = clip_images(samples) - return samples - -class DDPMSampler(DiffusionSampler): - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step) - variance = self.noise_schedule.get_posterior_variance(steps=current_step) - - state, rng = state.get_random_key() - # Now sample from the posterior - noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32) - - return mean + noise * variance, state - - def generate_images(self, num_images=16, diffusion_steps=1000, start_step: int = None, *args, **kwargs): - return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs) - -class SimpleDDPMSampler(DiffusionSampler): - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - state, rng = state.get_random_key() - noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32) - - # Compute noise rates and signal rates only once - current_signal_rate, current_noise_rate = self.noise_schedule.get_rates(current_step) - next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step) - - pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate) - - noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2) - signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2) - gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared)) - - next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma - return next_samples, state - -class DDIMSampler(DiffusionSampler): - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step) - return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state - -class EulerSampler(DiffusionSampler): - # Basically a DDIM Sampler but parameterized as an ODE - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - current_alpha, current_sigma = self.noise_schedule.get_rates(current_step) - next_alpha, next_sigma = self.noise_schedule.get_rates(next_step) - - dt = next_sigma - current_sigma - - x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (dt) - dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma - next_samples = current_samples + dx * dt - return next_samples, state - -class SimplifiedEulerSampler(DiffusionSampler): - """ - This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t - """ - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - _, current_sigma = 
self.noise_schedule.get_rates(current_step) - _, next_sigma = self.noise_schedule.get_rates(next_step) - - dt = next_sigma - current_sigma - - dx = (current_samples - reconstructed_samples) / current_sigma - next_samples = current_samples + dx * dt - return next_samples, state - -class HeunSampler(DiffusionSampler): - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - # Get the noise and signal rates for the current and next steps - current_alpha, current_sigma = self.noise_schedule.get_rates(current_step) - next_alpha, next_sigma = self.noise_schedule.get_rates(next_step) - - dt = next_sigma - current_sigma - x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / dt - - dx_0 = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma - next_samples_0 = current_samples + dx_0 * dt - - # Recompute x_0 and eps at the first estimate to refine the derivative - estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs) - - # Estimate the refined derivative using the midpoint (Heun's method) - dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma - # Compute the final next samples by averaging the initial and refined derivatives - final_next_samples = current_samples + 0.5 * (dx_0 + dx_1) * dt - - return final_next_samples, state - -class RK4Sampler(DiffusionSampler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler" - @jax.jit - def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]: - t = self.noise_schedule.get_timesteps(sigma) - x_0, eps, _ = self.sample_model(x_t, t, *model_conditioning_inputs) - return eps, state - - self.get_derivative = get_derivative - - def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]: - step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32) - current_step = step_ones * current_step - next_step = step_ones * next_step - _, current_sigma = self.noise_schedule.get_rates(current_step) - _, next_sigma = self.noise_schedule.get_rates(next_step) - - dt = next_sigma - current_sigma - - k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs) - k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs) - k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs) - k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs) - - next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6) - return next_samples, state - -class MultiStepDPM(DiffusionSampler): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.history = [] - - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - # Get the noise and signal rates for the current and next steps - current_alpha, current_sigma = 
self.noise_schedule.get_rates(current_step) - next_alpha, next_sigma = self.noise_schedule.get_rates(next_step) - - dt = next_sigma - current_sigma - - def first_order(current_noise, current_sigma): - dx = current_noise - return dx - - def second_order(current_noise, current_sigma, last_noise, last_sigma): - dx_2 = (current_noise - last_noise) / (current_sigma - last_sigma) - return dx_2 - - def third_order(current_noise, current_sigma, last_noise, last_sigma, second_last_noise, second_last_sigma): - dx_2 = second_order(current_noise, current_sigma, last_noise, last_sigma) - dx_2_last = second_order(last_noise, last_sigma, second_last_noise, second_last_sigma) - - dx_3 = (dx_2 - dx_2_last) / (0.5 * ((current_sigma + last_sigma) - (last_sigma + second_last_sigma))) - - return dx_3 - - if len(self.history) == 0: - # First order only - dx_1 = first_order(pred_noise, current_sigma) - next_samples = current_samples + dx_1 * dt - elif len(self.history) == 1: - # First + Second order - dx_1 = first_order(pred_noise, current_sigma) - last_step = self.history[-1] - dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma']) - next_samples = current_samples + dx_1 * dt + 0.5 * dx_2 * dt**2 - else: - # First + Second + Third order - last_step = self.history[-1] - second_last_step = self.history[-2] - - dx_1 = first_order(pred_noise, current_sigma) - dx_2 = second_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma']) - dx_3 = third_order(pred_noise, current_sigma, last_step['eps'], last_step['sigma'], second_last_step['eps'], second_last_step['sigma']) - next_samples = current_samples + (dx_1 * dt) + (0.5 * dx_2 * dt**2) + ((1/6) * dx_3 * dt**3) - - self.history.append({ - "eps": pred_noise, - "sigma" : current_sigma, - }) - return next_samples, state - -class EulerAncestralSampler(DiffusionSampler): - def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs, - pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]: - current_alpha, current_sigma = self.noise_schedule.get_rates(current_step) - next_alpha, next_sigma = self.noise_schedule.get_rates(next_step) - - sigma_up = (next_sigma**2 * (current_sigma**2 - next_sigma**2) / current_sigma**2) ** 0.5 - sigma_down = (next_sigma**2 - sigma_up**2) ** 0.5 - - dt = sigma_down - current_sigma - - x_0_coeff = (current_alpha * next_sigma - next_alpha * current_sigma) / (next_sigma - current_sigma) - dx = (current_samples - x_0_coeff * reconstructed_samples) / current_sigma - - state, subkey = state.get_random_key() - dW = jax.random.normal(subkey, current_samples.shape) * sigma_up - - next_samples = current_samples + dx * dt + dW - return next_samples, state - -# %% -images = next(iter(data)) -plotImages(images, dpi=300) -print(images.shape) -noise_schedule = karas_ve_schedule -predictor = trainer.model_output_transform - -rng = jax.random.PRNGKey(4) -noise = jax.random.normal(rng, shape=images.shape, dtype=jnp.float32) -noise_level = 0.9999 -noise_levels = jnp.ones((images.shape[0],), dtype=jnp.int32) * noise_level - -rates = noise_schedule.get_rates(noise_levels) -noisy_images, c_in, expected_output = predictor.forward_diffusion(images, noise, rates=rates) -plotImages(noisy_images) -print(jnp.mean(noisy_images), jnp.std(noisy_images)) -regenerated_images = noise_schedule.remove_all_noise(noisy_images, noise, noise_levels) -plotImages(regenerated_images) - -sampler = EulerSampler(trainer.model, trainer.state.ema_params, 
karas_ve_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=16, diffusion_steps=20, start_step=int(noise_level*1000), end_step=0, priors=None) -plotImages(samples, dpi=300) - -# %% -textEncoderModel, textTokenizer = defaultTextEncodeModel() - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a photo of a rose' - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.get_state().ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a photo of a rose' - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a photo of a rose' - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.best_state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a water lily', - 'a photo of a marigold', - 'a water lily', - 'a water lily', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a water lily', - 'a photo of a marigold', - 'a water lily', - 'a water lily', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, 
null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a water lily', - 'a photo of a marigold', - 'a water lily', - 'a water lily', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=2, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a water lily', - 'a photo of a marigold', - 'a water lily', - 'a water lily', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=4, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=500, fig_size=(4, 5)) - -# %% -prompts = [ - 'water tulip', - 'a green water rose', - 'a green water rose', - 'a green water rose', - 'a water red rose', - 'a marigold and rose hybrid', - 'a marigold and rose hybrid', - 'a marigold and rose hybrid', - 'a water lily and a marigold', - 'a water lily and a marigold', - 'a water lily and a marigold', - 'a water lily and a marigold', - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=3, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - -# %% -prompts = [ - 'water tulip', - 'a water lily', - 'a water lily', - 'a photo of a rose', - 'a photo of a rose', - 'a water lily', - 'a water lily', - 'a photo of a marigold', - 'a photo of a marigold', - 'a photo of a marigold', - 'a water lily', - 'a photo of a sunflower', - 'a photo of a lotus', - "columbine", - "columbine", - "an orchid", - "an orchid", - "an orchid", - 'a water lily', - 'a water lily', - 'a water lily', - "columbine", - "columbine", - 'a photo of a sunflower', - 'a photo of a sunflower', - 'a photo of a sunflower', - 'a photo of a lotus', - 'a photo of a lotus', - 'a photo of a marigold', - 'a photo of a marigold', - 'a photo of a rose', - 'a photo of a rose', - 'a photo of a rose', - "orange dahlia", - "orange dahlia", - "a lenten 
rose", - "a lenten rose", - 'a water lily', - 'a water lily', - 'a water lily', - 'a water lily', - "an orchid", - "an orchid", - "an orchid", - 'hard-leaved pocket orchid', - "bird of paradise", - "bird of paradise", - "a photo of a lovely rose", - "a photo of a lovely rose", - "a photo of a globe-flower", - "a photo of a globe-flower", - "a photo of a lovely rose", - "a photo of a lovely rose", - "a photo of a ruby-lipped cattleya", - "a photo of a ruby-lipped cattleya", - "a photo of a lovely rose", - 'a water lily', - 'a osteospermum', - 'a osteospermum', - 'a water lily', - 'a water lily', - 'a water lily', - "a red rose", - "a red rose", - ] -pooled_labels, labels_seq = encodePrompts(prompts, textEncoderModel, textTokenizer) - -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform, guidance_scale=4, null_labels_seq=data['null_labels_full']) -samples = sampler.generate_images(num_images=len(prompts), diffusion_steps=200, start_step=1000, end_step=0, priors=None, model_conditioning_inputs=(labels_seq,)) -plotImages(samples, dpi=300) - -# %% -dataToLabelGenMap["oxford_flowers102"]() - -# %% - - -# %% -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=64, diffusion_steps=200, start_step=1000, end_step=0, priors=None) -plotImages(samples, dpi=300) - -# %% -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=64, diffusion_steps=200, start_step=1000, end_step=0, priors=None) -plotImages(samples, dpi=300) - -# %% -sampler = RK4Sampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=64, diffusion_steps=6, start_step=1000, end_step=0, priors=None) -plotImages(samples, dpi=300) - -# %% -sampler = EulerAncestralSampler(trainer.model, trainer.state.ema_params, karas_ve_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=64, diffusion_steps=300, start_step=1000, end_step=0, priors=None) -plotImages(samples, dpi=300) - -# %% -sampler = DDPMSampler(trainer.model, trainer.state.params, trainer.noise_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=16, start_step=1000, priors=None) -plotImages(samples, dpi=300) - -# %% -sampler = DDPMSampler(trainer.model, trainer.best_state.params, trainer.noise_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=16, start_step=1000, priors=None) -plotImages(samples, dpi=300) - -# %% -sampler = DDPMSampler(trainer.model, trainer.best_state.params, trainer.noise_schedule, model_output_transform=trainer.model_output_transform) -samples = sampler.generate_images(num_images=64, start_step=1000, priors=None) -plotImages(samples) - -# %% [markdown] -# - - +data = get_dataset_grain(CONFIG['dataset']['name'], + batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE, + grain_worker_count=GRAIN_WORKER_COUNT, grain_read_thread_count=GRAIN_READ_THREAD_COUNT, + grain_read_buffer_size=GRAIN_READ_BUFFER_SIZE, grain_worker_buffer_size=GRAIN_WORKER_BUFFER_SIZE, + ) + +batches = batches if args.steps_per_epoch is None else 
args.steps_per_epoch
+print(f"Training on {CONFIG['dataset']['name']} dataset with {batches} steps per epoch")
+jax.profiler.start_server(6009)
+final_state = trainer.fit(data, batches, epochs=args.epochs)
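
With the notebook-only cells removed and the hyperparameters exposed through argparse, the script can now be launched directly from a shell on the TPU VM. Below is a minimal launch sketch, not part of the patch itself: the flag names come from the argparse block above, while the values are only illustrative. Two caveats from the parser as written: the boolean flags (--flash_attention, --use_projection, --use_self_and_cross) are declared with type=bool, so argparse converts any non-empty string, even "False", to True, and they should simply be omitted to keep the default of False; and --noise_schedule is parsed but the trainer above is still constructed with edm_schedule.

# Illustrative launch; assumes dependencies were installed via setup_tpu.sh
python3 training_tpu.py \
    --dataset cc12m \
    --BATCH_SIZE 64 \
    --IMAGE_SIZE 128 \
    --epochs 10 \
    --steps_per_epoch 1000 \
    --learning_rate 2e-4 \
    --emb_features 256 \
    --attention_heads 8 \
    --num_res_blocks 2 \
    --dtype bfloat16 \
    --precision high \
    --GRAIN_WORKER_COUNT 32 \
    --GRAIN_READ_THREAD_COUNT 64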