feat: fixed transformerblock and added backward compatibility options
AshishKumar4 committed Sep 19, 2024
1 parent 3a37ea6 commit 2f06494
Showing 5 changed files with 152 additions and 33 deletions.
147 changes: 123 additions & 24 deletions evaluate.ipynb

Large diffs are not rendered by default.

18 changes: 12 additions & 6 deletions flaxdiff/models/attention.py
@@ -303,27 +303,30 @@ class TransformerBlock(nn.Module):
     only_pure_attention:bool = False
     force_fp32_for_softmax: bool = True
     kernel_init: Callable = kernel_init(1.0)
+    norm_inputs: bool = True
+    explicitly_add_residual: bool = True
 
     @nn.compact
     def __call__(self, x, context=None):
         inner_dim = self.heads * self.dim_head
         C = x.shape[-1]
-        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+        if self.norm_inputs:
+            x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
         if self.use_projection == True:
             if self.use_linear_attention:
                 projected_x = nn.Dense(features=inner_dim,
                                        use_bias=False, precision=self.precision,
                                        kernel_init=self.kernel_init,
-                                       dtype=self.dtype, name=f'project_in')(normed_x)
+                                       dtype=self.dtype, name=f'project_in')(x)
             else:
                 projected_x = nn.Conv(
                     features=inner_dim, kernel_size=(1, 1),
                     kernel_init=self.kernel_init,
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_in_conv',
-                )(normed_x)
+                )(x)
         else:
-            projected_x = normed_x
+            projected_x = x
             inner_dim = C
 
         context = projected_x if context is None else context
@@ -356,6 +359,9 @@ def __call__(self, x, context=None):
                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                     precision=self.precision, name=f'project_out_conv',
                 )(projected_x)
 
-        out = x + projected_x
+        if self.only_pure_attention or self.explicitly_add_residual:
+            projected_x = x + projected_x
+
+        out = projected_x
         return out
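The two new fields gate the input normalization and the outer residual: `norm_inputs` toggles the RMSNorm applied to the block input (overwriting `x`, so the later residual uses the normalized input), and `explicitly_add_residual` toggles adding that input back onto the projected output. Below is a minimal standalone sketch of this gating pattern, not the flaxdiff implementation itself; the Dense layer stands in for the projection and attention stack, and all shapes are illustrative.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class GatedBlock(nn.Module):
    """Toy block mirroring the norm_inputs / explicitly_add_residual gating."""
    features: int = 64
    norm_inputs: bool = True
    explicitly_add_residual: bool = True

    @nn.compact
    def __call__(self, x):
        if self.norm_inputs:
            # Overwrites x, so the residual below uses the normalized input.
            x = nn.RMSNorm(epsilon=1e-5)(x)
        h = nn.Dense(self.features)(x)  # stand-in for projection + attention
        if self.explicitly_add_residual:
            h = x + h
        return h


x = jnp.ones((1, 8, 64))
block = GatedBlock(features=64)
params = block.init(jax.random.PRNGKey(0), x)
print(block.apply(params, x).shape)  # (1, 8, 64)
```

The real block also forces the residual whenever `only_pure_attention` is set; the sketch omits that branch for brevity.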
6 changes: 6 additions & 0 deletions flaxdiff/models/simple_unet.py
@@ -83,6 +83,8 @@ def __call__(self, x, temb, textcontext):
                     precision=attention_config.get("precision", self.precision),
                     only_pure_attention=attention_config.get("only_pure_attention", True),
                     force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
+                    norm_inputs=attention_config.get("norm_inputs", True),
+                    explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                     kernel_init=self.kernel_init(1.0),
                     name=f"down_{i}_attention_{j}")(x, textcontext)
             # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
@@ -125,6 +127,8 @@ def __call__(self, x, temb, textcontext):
                 precision=middle_attention.get("precision", self.precision),
                 only_pure_attention=middle_attention.get("only_pure_attention", True),
                 force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                norm_inputs=middle_attention.get("norm_inputs", True),
+                explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
                 kernel_init=self.kernel_init(1.0),
                 name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
@@ -171,6 +175,8 @@ def __call__(self, x, temb, textcontext):
                     precision=attention_config.get("precision", self.precision),
                     only_pure_attention=attention_config.get("only_pure_attention", True),
                     force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
+                    norm_inputs=attention_config.get("norm_inputs", True),
+                    explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
                     kernel_init=self.kernel_init(1.0),
                     name=f"up_{i}_attention_{j}")(x, textcontext)
         # print("Upscaling ", i, x.shape)
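Both new keys are read with fallbacks of True, so per-level attention config dicts that predate this commit keep their previous behaviour. A hedged sketch of one such entry, showing only the boolean keys visible in this diff (how the dict reaches the model is not shown here):

```python
# Illustrative attention config entry; keys mirror the .get(...) calls above
# and the values shown are the UNet's own fallback defaults.
attention_config = {
    "only_pure_attention": True,
    "force_fp32_for_softmax": False,
    # New backward-compatibility options added in this commit:
    "norm_inputs": True,
    "explicitly_add_residual": True,
}
```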
12 changes: 10 additions & 2 deletions flaxdiff/models/simple_vit.py
@@ -69,6 +69,8 @@ class UViT(nn.Module):
     precision: PrecisionLike = None
     kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
+    norm_inputs: bool = False
+    explicitly_add_residual: bool = False
 
     def setup(self):
         if self.norm_groups > 0:
@@ -110,16 +112,20 @@ def __call__(self, x, temb, textcontext=None):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)
 
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
+                             norm_inputs=self.norm_inputs,
+                             explicitly_add_residual=self.explicitly_add_residual,
                              kernel_init=self.kernel_init())(x)
 
         # # Out blocks
@@ -131,6 +137,8 @@ def __call__(self, x, temb, textcontext=None):
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
+                                 norm_inputs=self.norm_inputs,
+                                 explicitly_add_residual=self.explicitly_add_residual,
                                  kernel_init=self.kernel_init())(x)
 
             # print(f'Shape of x after transformer blocks: {x.shape}')
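In UViT both new fields default to False, so its transformer blocks now skip the extra input RMSNorm and outer residual; note also that `use_self_and_cross` is hard-coded to False for the first-half and middle blocks, which is not gated by any flag. A hedged sketch of opting back into the old per-block normalization and residual (import path taken from the file above; all other constructor arguments are assumed to have usable defaults):

```python
# Sketch only: re-enable the per-block input RMSNorm and outer residual.
from flaxdiff.models.simple_vit import UViT

model = UViT(
    norm_inputs=True,              # apply RMSNorm to each TransformerBlock input
    explicitly_add_residual=True,  # add the block input back onto its output
)
```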
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.35.4',
+    version='0.1.35.5',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
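The only change to setup.py is the version bump. A quick way to confirm which release is installed (assuming the package was installed under the name declared above):

```python
# Prints the installed flaxdiff version; 0.1.35.5 corresponds to this commit.
from importlib.metadata import version
print(version("flaxdiff"))
```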
