Commit

fix: attempted to fix unet bug
AshishKumar4 committed Oct 2, 2024
1 parent 2f06494 commit c955648
Showing 4 changed files with 14 additions and 13 deletions.
1 change: 1 addition & 0 deletions flaxdiff/models/attention.py
@@ -11,6 +11,7 @@
 import functools
 import math
 from .common import kernel_init
+import jax.experimental.pallas.ops.tpu.flash_attention

 class EfficientAttention(nn.Module):
     """
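The only change to attention.py is the new import of JAX's Pallas flash-attention kernel for TPU; this commit adds the import but does not yet call it. As a rough sketch of how that kernel is typically invoked (not code from this repository; the exact signature and supported keywords vary across JAX versions):

# Sketch only: typical use of the Pallas TPU flash-attention kernel.
# Shapes follow the (batch, heads, seq_len, head_dim) convention; keyword
# names are an assumption based on recent JAX releases.
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu import flash_attention as fa

batch, heads, seq_len, head_dim = 1, 8, 1024, 64
q = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
k = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)
v = jnp.zeros((batch, heads, seq_len, head_dim), dtype=jnp.bfloat16)

out = fa.flash_attention(q, k, v, causal=False)  # fused attention, TPU backend only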
22 changes: 11 additions & 11 deletions flaxdiff/models/simple_unet.py
@@ -50,7 +50,7 @@ def __call__(self, x, temb, textcontext):
 features=self.feature_depths[0],
 kernel_size=(3, 3),
 strides=(1, 1),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
@@ -65,7 +65,7 @@ def __call__(self, x, temb, textcontext):
 down_conv_type,
 name=f"down_{i}_residual_{j}",
 features=dim_in,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3, 3),
 strides=(1, 1),
 activation=self.activation,
@@ -85,7 +85,7 @@ def __call__(self, x, temb, textcontext):
 force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
 norm_inputs=attention_config.get("norm_inputs", True),
 explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 name=f"down_{i}_attention_{j}")(x, textcontext)
 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
 downs.append(x)
@@ -108,7 +108,7 @@ def __call__(self, x, temb, textcontext):
 middle_conv_type,
 name=f"middle_res1_{j}",
 features=middle_dim_out,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3, 3),
 strides=(1, 1),
 activation=self.activation,
@@ -129,13 +129,13 @@ def __call__(self, x, temb, textcontext):
 force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
 norm_inputs=middle_attention.get("norm_inputs", True),
 explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 name=f"middle_attention_{j}")(x, textcontext)
 x = ResidualBlock(
 middle_conv_type,
 name=f"middle_res2_{j}",
 features=middle_dim_out,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3, 3),
 strides=(1, 1),
 activation=self.activation,
@@ -157,7 +157,7 @@ def __call__(self, x, temb, textcontext):
 up_conv_type,# if j == 0 else "separable",
 name=f"up_{i}_residual_{j}",
 features=dim_out,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=kernel_size,
 strides=(1, 1),
 activation=self.activation,
@@ -177,7 +177,7 @@ def __call__(self, x, temb, textcontext):
 force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
 norm_inputs=attention_config.get("norm_inputs", True),
 explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 name=f"up_{i}_attention_{j}")(x, textcontext)
 # print("Upscaling ", i, x.shape)
 if i != len(feature_depths) - 1:
@@ -196,7 +196,7 @@ def __call__(self, x, temb, textcontext):
 features=self.feature_depths[0],
 kernel_size=(3, 3),
 strides=(1, 1),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
@@ -207,7 +207,7 @@ def __call__(self, x, temb, textcontext):
 conv_type,
 name="final_residual",
 features=self.feature_depths[0],
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3,3),
 strides=(1, 1),
 activation=self.activation,
@@ -226,7 +226,7 @@ def __call__(self, x, temb, textcontext):
 kernel_size=(3, 3),
 strides=(1, 1),
 # activation=jax.nn.mish
-kernel_init=self.kernel_init(0.0),
+kernel_init=self.kernel_init(scale=0.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
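Every hunk in simple_unet.py makes the same change: the scale passed to self.kernel_init moves from a positional argument to an explicit keyword. The commit message does not spell out the bug, but a plausible reading is that kernel_init is pre-bound with functools.partial(kernel_init, scale=1.0) (as simple_vit.py below shows), and calling such a partial with a positional value collides with the already-bound scale keyword, while passing scale= simply overrides it. A minimal illustration with a stand-in initializer (the real one lives in flaxdiff.models.common and may differ):

from functools import partial

def kernel_init(scale, dtype=None):      # stand-in signature for illustration
    return f"init(scale={scale}, dtype={dtype})"

init = partial(kernel_init, scale=1.0)   # mirrors the default in simple_vit.py

print(init(scale=0.5))  # OK: the explicit keyword overrides the bound value
# init(0.5)             # TypeError: kernel_init() got multiple values for argument 'scale'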
2 changes: 1 addition & 1 deletion flaxdiff/models/simple_vit.py
@@ -70,7 +70,7 @@ class UViT(nn.Module):
 kernel_init: Callable = partial(kernel_init, scale=1.0)
 add_residualblock_output: bool = False
 norm_inputs: bool = False
-explicitly_add_residual: bool = False
+explicitly_add_residual: bool = True

 def setup(self):
 if self.norm_groups > 0:
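The behavioural change in simple_vit.py is the default of explicitly_add_residual flipping from False to True. The diff does not show how the flag is consumed inside flaxdiff's blocks; as a hypothetical illustration of what such a flag usually controls, namely whether the block adds the skip connection itself:

# Hypothetical illustration only; the real flaxdiff block differs.
import flax.linen as nn

class TinyBlock(nn.Module):
    features: int
    explicitly_add_residual: bool = True

    @nn.compact
    def __call__(self, x):
        # Assumes features == x.shape[-1] so the residual add is shape-compatible.
        h = nn.Dense(self.features)(nn.LayerNorm()(x))
        return x + h if self.explicitly_add_residual else h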
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@
 setup(
 name='flaxdiff',
 packages=find_packages(),
-version='0.1.35.5',
+version='0.1.35.6',
 description='A versatile and easy to understand Diffusion library',
 long_description=open('README.md').read(),
 long_description_content_type='text/markdown',
