Commit

Merge branch 'main' of github.com:AshishKumar4/FlaxDiff
AshishKumar4 committed Aug 11, 2024
2 parents 9aae792 + 9445d03 commit 3cacef5
Showing 7 changed files with 311 additions and 138 deletions.
341 changes: 243 additions & 98 deletions evaluate.ipynb

Large diffs are not rendered by default.

39 changes: 24 additions & 15 deletions flaxdiff/models/attention.py
@@ -22,7 +22,8 @@ class EfficientAttention(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
kernel_init: Callable = lambda : kernel_init(1.0)
kernel_init: Callable = kernel_init(1.0)
force_fp32_for_softmax: bool = True

def setup(self):
inner_dim = self.dim_head * self.heads
@@ -32,15 +33,15 @@ def setup(self):
self.heads * self.dim_head,
precision=self.precision,
use_bias=self.use_bias,
kernel_init=self.kernel_init(),
kernel_init=self.kernel_init,
dtype=self.dtype
)
self.query = dense(name="to_q")
self.key = dense(name="to_k")
self.value = dense(name="to_v")

self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
# self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)

def _reshape_tensor_to_head_dim(self, tensor):
@@ -113,7 +114,8 @@ class NormalAttention(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
kernel_init: Callable = lambda : kernel_init(1.0)
kernel_init: Callable = kernel_init(1.0)
force_fp32_for_softmax: bool = True

def setup(self):
inner_dim = self.dim_head * self.heads
@@ -123,7 +125,7 @@ def setup(self):
axis=-1,
precision=self.precision,
use_bias=self.use_bias,
kernel_init=self.kernel_init(),
kernel_init=self.kernel_init,
dtype=self.dtype
)
self.query = dense(name="to_q")
@@ -137,7 +139,7 @@ def setup(self):
use_bias=self.use_bias,
dtype=self.dtype,
name="to_out_0",
kernel_init=self.kernel_init()
kernel_init=self.kernel_init
# kernel_init=jax.nn.initializers.xavier_uniform()
)

@@ -157,7 +159,7 @@ def __call__(self, x, context=None):

hidden_states = nn.dot_product_attention(
query, key, value, dtype=self.dtype, broadcast_dropout=False,
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=True,
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
deterministic=True
)
proj = self.proj_attn(hidden_states)
@@ -233,10 +235,11 @@ class BasicTransformerBlock(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
kernel_init: Callable = lambda : kernel_init(1.0)
kernel_init: Callable = kernel_init(1.0)
use_flash_attention:bool = False
use_cross_only:bool = False
only_pure_attention:bool = False
force_fp32_for_softmax: bool = True

def setup(self):
if self.use_flash_attention:
@@ -252,7 +255,8 @@ def setup(self):
precision=self.precision,
use_bias=self.use_bias,
dtype=self.dtype,
kernel_init=self.kernel_init
kernel_init=self.kernel_init,
force_fp32_for_softmax=self.force_fp32_for_softmax
)
self.attention2 = attenBlock(
query_dim=self.query_dim,
@@ -262,7 +266,8 @@ def setup(self):
precision=self.precision,
use_bias=self.use_bias,
dtype=self.dtype,
kernel_init=self.kernel_init
kernel_init=self.kernel_init,
force_fp32_for_softmax=self.force_fp32_for_softmax
)

self.ff = FlaxFeedForward(dim=self.query_dim)
@@ -296,6 +301,8 @@ class TransformerBlock(nn.Module):
use_flash_attention:bool = False
use_self_and_cross:bool = True
only_pure_attention:bool = False
force_fp32_for_softmax: bool = True
kernel_init: Callable = kernel_init(1.0)

@nn.compact
def __call__(self, x, context=None):
@@ -306,12 +313,12 @@ def __call__(self, x, context=None):
if self.use_linear_attention:
projected_x = nn.Dense(features=inner_dim,
use_bias=False, precision=self.precision,
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init,
dtype=self.dtype, name=f'project_in')(normed_x)
else:
projected_x = nn.Conv(
features=inner_dim, kernel_size=(1, 1),
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init,
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
precision=self.precision, name=f'project_in_conv',
)(normed_x)
@@ -331,19 +338,21 @@ def __call__(self, x, context=None):
dtype=self.dtype,
use_flash_attention=self.use_flash_attention,
use_cross_only=(not self.use_self_and_cross),
only_pure_attention=self.only_pure_attention
only_pure_attention=self.only_pure_attention,
force_fp32_for_softmax=self.force_fp32_for_softmax,
kernel_init=self.kernel_init
)(projected_x, context)

if self.use_projection == True:
if self.use_linear_attention:
projected_x = nn.Dense(features=C, precision=self.precision,
dtype=self.dtype, use_bias=False,
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init,
name=f'project_out')(projected_x)
else:
projected_x = nn.Conv(
features=C, kernel_size=(1, 1),
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init,
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
precision=self.precision, name=f'project_out_conv',
)(projected_x)
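
The net effect in this file: the attention blocks now take a concrete initializer for kernel_init instead of a zero-argument factory, and the softmax dtype is switchable via force_fp32_for_softmax. A minimal usage sketch under assumptions: only the field names visible in the diff are confirmed; the import path of the kernel_init helper and the input shape are guesses.

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import NormalAttention
from flaxdiff.models.common import kernel_init  # assumed location of the kernel_init helper

# Pass the initializer itself, not `lambda: kernel_init(1.0)` as before this commit.
attn = NormalAttention(
    query_dim=64,
    heads=4,
    dim_head=16,
    dtype=jnp.bfloat16,
    kernel_init=kernel_init(1.0),
    force_fp32_for_softmax=True,   # keep the attention softmax in fp32 despite bf16 activations
)
x = jnp.ones((1, 256, 64))         # (batch, tokens, query_dim); shape chosen for illustration
params = attn.init(jax.random.PRNGKey(0), x)
y = attn.apply(params, x)
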
8 changes: 5 additions & 3 deletions flaxdiff/models/common.py
@@ -267,15 +267,17 @@ class ResidualBlock(nn.Module):
kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
named_norms:bool=False

def setup(self):
if self.norm_groups > 0:
norm = partial(nn.GroupNorm, self.norm_groups)
self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
else:
norm = partial(nn.RMSNorm, 1e-5)

self.norm1 = norm()
self.norm2 = norm()
self.norm1 = norm()
self.norm2 = norm()

@nn.compact
def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
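
The new named_norms flag exists so that checkpoints saved when the norms were created unconditionally (and therefore auto-named GroupNorm_0/GroupNorm_1) still load. A hedged call-site sketch: the keyword names come from this diff and the call sites in simple_unet.py, while the positional conv type string, the feature width, and the kernel_init import path are illustrative assumptions.

from flaxdiff.models.common import ResidualBlock, kernel_init  # kernel_init location assumed

block = ResidualBlock(
    "conv",                         # conv type, passed positionally as in simple_unet.py (value assumed)
    features=64,
    kernel_init=kernel_init(1.0),
    kernel_size=(3, 3),
    strides=(1, 1),
    norm_groups=8,
    named_norms=True,               # emit "GroupNorm_0"/"GroupNorm_1" so older checkpoints still load
)
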
44 changes: 28 additions & 16 deletions flaxdiff/models/simple_unet.py
@@ -19,15 +19,16 @@ class Unet(nn.Module):
norm_groups:int=8
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms
kernel_init: Callable = partial(kernel_init, dtype=jnp.float32)

def setup(self):
if self.norm_groups > 0:
norm = partial(nn.GroupNorm, self.norm_groups)
self.conv_out_norm = norm(name="GroupNorm_0") if self.named_norms else norm()
else:
norm = partial(nn.RMSNorm, 1e-5)

# self.last_up_norm = norm()
self.conv_out_norm = norm()
self.conv_out_norm = norm()

@nn.compact
def __call__(self, x, temb, textcontext):
@@ -49,7 +50,7 @@ def __call__(self, x, temb, textcontext):
features=self.feature_depths[0],
kernel_size=(3, 3),
strides=(1, 1),
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
dtype=self.dtype,
precision=self.precision
)(x)
@@ -64,13 +65,14 @@ def __call__(self, x, temb, textcontext):
down_conv_type,
name=f"down_{i}_residual_{j}",
features=dim_in,
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
kernel_size=(3, 3),
strides=(1, 1),
activation=self.activation,
norm_groups=self.norm_groups,
dtype=self.dtype,
precision=self.precision
precision=self.precision,
named_norms=self.named_norms
)(x, temb)
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -80,6 +82,8 @@ def __call__(self, x, temb, textcontext):
use_self_and_cross=attention_config.get("use_self_and_cross", True),
precision=attention_config.get("precision", self.precision),
only_pure_attention=attention_config.get("only_pure_attention", True),
force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
kernel_init=self.kernel_init(1.0),
name=f"down_{i}_attention_{j}")(x, textcontext)
# print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
downs.append(x)
@@ -102,13 +106,14 @@ def __call__(self, x, temb, textcontext):
middle_conv_type,
name=f"middle_res1_{j}",
features=middle_dim_out,
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
kernel_size=(3, 3),
strides=(1, 1),
activation=self.activation,
norm_groups=self.norm_groups,
dtype=self.dtype,
precision=self.precision
precision=self.precision,
named_norms=self.named_norms
)(x, temb)
if middle_attention is not None and j == self.num_middle_res_blocks - 1: # Apply attention only on the last block
x = TransformerBlock(heads=middle_attention['heads'], dtype=middle_attention.get('dtype', jnp.float32),
@@ -119,18 +124,21 @@ def __call__(self, x, temb, textcontext):
use_self_and_cross=False,
precision=middle_attention.get("precision", self.precision),
only_pure_attention=middle_attention.get("only_pure_attention", True),
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
kernel_init=self.kernel_init(1.0),
name=f"middle_attention_{j}")(x, textcontext)
x = ResidualBlock(
middle_conv_type,
name=f"middle_res2_{j}",
features=middle_dim_out,
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
kernel_size=(3, 3),
strides=(1, 1),
activation=self.activation,
norm_groups=self.norm_groups,
dtype=self.dtype,
precision=self.precision
precision=self.precision,
named_norms=self.named_norms
)(x, temb)

# Upscaling Blocks
@@ -145,13 +153,14 @@ def __call__(self, x, temb, textcontext):
up_conv_type,# if j == 0 else "separable",
name=f"up_{i}_residual_{j}",
features=dim_out,
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
kernel_size=kernel_size,
strides=(1, 1),
activation=self.activation,
norm_groups=self.norm_groups,
dtype=self.dtype,
precision=self.precision
precision=self.precision,
named_norms=self.named_norms
)(x, temb)
if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
@@ -161,6 +170,8 @@ def __call__(self, x, temb, textcontext):
use_self_and_cross=attention_config.get("use_self_and_cross", True),
precision=attention_config.get("precision", self.precision),
only_pure_attention=attention_config.get("only_pure_attention", True),
force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
kernel_init=self.kernel_init(1.0),
name=f"up_{i}_attention_{j}")(x, textcontext)
# print("Upscaling ", i, x.shape)
if i != len(feature_depths) - 1:
@@ -179,7 +190,7 @@ def __call__(self, x, temb, textcontext):
features=self.feature_depths[0],
kernel_size=(3, 3),
strides=(1, 1),
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
dtype=self.dtype,
precision=self.precision
)(x)
@@ -190,13 +201,14 @@ def __call__(self, x, temb, textcontext):
conv_type,
name="final_residual",
features=self.feature_depths[0],
kernel_init=kernel_init(1.0),
kernel_init=self.kernel_init(1.0),
kernel_size=(3,3),
strides=(1, 1),
activation=self.activation,
norm_groups=self.norm_groups,
dtype=self.dtype,
precision=self.precision
precision=self.precision,
named_norms=self.named_norms
)(x, temb)

x = self.conv_out_norm(x)
@@ -208,7 +220,7 @@ def __call__(self, x, temb, textcontext):
kernel_size=(3, 3),
strides=(1, 1),
# activation=jax.nn.mish
kernel_init=kernel_init(0.0),
kernel_init=self.kernel_init(0.0),
dtype=self.dtype,
precision=self.precision
)(x)
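
At the UNet level the same two knobs become constructor fields and are forwarded into every residual and transformer block, with kernel_init held as a factory (a partial over dtype) that is called as self.kernel_init(scale) at each use, and force_fp32_for_softmax read from each attention config. A construction sketch, assuming the fields not shown in this diff (for example the per-level attention configs) have workable defaults; the depth values and the kernel_init import path are placeholders.

from functools import partial
import jax.numpy as jnp
from flaxdiff.models.simple_unet import Unet
from flaxdiff.models.common import kernel_init  # assumed location of the kernel_init helper

unet = Unet(
    feature_depths=[64, 128, 256],                        # placeholder widths
    num_res_blocks=2,
    norm_groups=8,
    dtype=jnp.bfloat16,
    named_norms=True,                                      # older checkpoints have named norms
    kernel_init=partial(kernel_init, dtype=jnp.float32),   # called as self.kernel_init(scale) per block
)
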
6 changes: 4 additions & 2 deletions flaxdiff/trainer/diffusion_trainer.py
@@ -16,6 +16,7 @@
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics

from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
from flax.training.dynamic_scale import DynamicScale

class TrainState(SimpleTrainState):
rngs: jax.random.PRNGKey
@@ -83,7 +84,8 @@ def generate_states(
new_state = existing_state

if param_transforms is not None:
params = param_transforms(params)
new_state['params'] = param_transforms(new_state['params'])
new_state['ema_params'] = param_transforms(new_state['ema_params'])

state = TrainState.create(
apply_fn=model.apply,
@@ -92,7 +94,7 @@
tx=optimizer,
rngs=rngs,
metrics=Metrics.empty(),
dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
dynamic_scale = DynamicScale() if use_dynamic_scale else None
)

if existing_best_state is not None:
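
generate_states now applies param_transforms to both existing_state['params'] and existing_state['ema_params'] rather than to a stale local params variable. A hedged sketch of the kind of callable that hook expects; the float32 cast is purely illustrative.

import jax
import jax.numpy as jnp

def param_transforms(params):
    # Applied by generate_states to existing_state['params'] and existing_state['ema_params'];
    # this example simply casts every leaf back to float32 before training resumes.
    return jax.tree_util.tree_map(lambda p: p.astype(jnp.float32), params)
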
9 changes: 6 additions & 3 deletions flaxdiff/trainer/simple_trainer.py
@@ -22,7 +22,7 @@
from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
from termcolor import colored
from typing import Dict, Callable, Sequence, Any, Union, Tuple

from flax.training.dynamic_scale import DynamicScale
from flaxdiff.utils import RandomMarkovState

PROCESS_COLOR_MAP = {
@@ -68,7 +68,7 @@ class Metrics(metrics.Collection):
# Define the TrainState
class SimpleTrainState(train_state.TrainState):
metrics: Metrics
dynamic_scale: flax.training.dynamic_scale.DynamicScale
dynamic_scale: DynamicScale

class SimpleTrainer:
state: SimpleTrainState
@@ -177,13 +177,16 @@ def generate_states(
params = model.init(subkey, **input_vars)
else:
params = existing_state['params']

if param_transforms is not None:
params = param_transforms(params)

state = SimpleTrainState.create(
apply_fn=model.apply,
params=params,
tx=optimizer,
metrics=Metrics.empty(),
dynamic_scale = flax.training.dynamic_scale.DynamicScale() if use_dynamic_scale else None
dynamic_scale = DynamicScale() if use_dynamic_scale else None
)
if existing_best_state is not None:
best_state = state.replace(
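
Both trainers now import DynamicScale from flax.training.dynamic_scale directly instead of reaching through the flax.training namespace. For context, a minimal sketch of the standard Flax dynamic loss-scaling loop that such a state field supports; this is the generic Flax pattern, not code from this repository, and the toy loss and step size are illustrative.

import jax
import jax.numpy as jnp
from flax.training.dynamic_scale import DynamicScale

def loss_fn(w):
    # toy half-precision loss so that loss scaling actually has something to do
    return jnp.sum((jnp.asarray(w, jnp.float16) * 2.0) ** 2)

w = jnp.ones((4,), jnp.float32)
dyn_scale = DynamicScale()
for _ in range(5):
    # value_and_grad returns (new_scale, grads_are_finite, loss, grads)
    dyn_scale, is_finite, loss, grads = dyn_scale.value_and_grad(loss_fn)(w)
    # skip the update when overflow made the scaled gradients non-finite
    w = jnp.where(is_finite, w - 0.01 * grads, w)
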
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
setup(
name='flaxdiff',
packages=find_packages(),
version='0.1.10',
version='0.1.12',
description='A versatile and easy to understand Diffusion library',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
