UNet2DModel mid_block_type
hlky committed Jan 6, 2025
1 parent b572635 commit e82743e
Showing 2 changed files with 49 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/diffusers/models/unets/unet_2d.py
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
             Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
             Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ def __init__(
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        mid_block_type: Optional[str] = "UNetMidBlock2D",
         up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
         block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
         layers_per_block: int = 2,
@@ -194,19 +195,22 @@ def __init__(
             self.down_blocks.append(down_block)
 
         # mid
-        self.mid_block = UNetMidBlock2D(
-            in_channels=block_out_channels[-1],
-            temb_channels=time_embed_dim,
-            dropout=dropout,
-            resnet_eps=norm_eps,
-            resnet_act_fn=act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
-            resnet_groups=norm_num_groups,
-            attn_groups=attn_norm_num_groups,
-            add_attention=add_attention,
-        )
+        if mid_block_type is None:
+            self.mid_block = None
+        else:
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+                resnet_groups=norm_num_groups,
+                attn_groups=attn_norm_num_groups,
+                add_attention=add_attention,
+            )
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
@@ -322,7 +326,8 @@ def forward(
             down_block_res_samples += res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, emb)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, emb)
 
         # 5. up
         skip_sample = None
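Taken together, these changes make the mid block optional: passing mid_block_type=None to UNet2DModel skips construction of UNetMidBlock2D, and the forward pass bypasses the mid stage when self.mid_block is None. A minimal usage sketch (not part of this commit; the tiny block sizes are illustrative, and it assumes a diffusers build that includes this change):

import torch
from diffusers import UNet2DModel

# Deliberately small configuration so the sketch runs quickly on CPU;
# real models use much larger block_out_channels.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    mid_block_type=None,  # new option: build the UNet without a mid block
)
assert model.mid_block is None

sample = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    out = model(sample, timestep=1).sample
print(out.shape)  # torch.Size([1, 3, 32, 32])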
29 changes: 29 additions & 0 deletions tests/models/unets/test_models_unet_2d.py
@@ -105,6 +105,35 @@ def test_mid_block_attn_groups(self):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
+    def test_mid_block_none(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        mid_none_init_dict["mid_block_type"] = None
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        mid_none_model = self.model_class(**mid_none_init_dict)
+        mid_none_model.to(torch_device)
+        mid_none_model.eval()
+
+        self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        with torch.no_grad():
+            mid_none_output = mid_none_model(**mid_none_inputs_dict)
+
+            if isinstance(mid_none_output, dict):
+                mid_none_output = mid_none_output.to_tuple()[0]
+
+        self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
+
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "AttnUpBlock2D",
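The added test builds two otherwise identical models and checks both that mid_block is absent when mid_block_type=None and that the resulting output differs from the default configuration. Assuming the repository's standard pytest workflow, the new test can be run in isolation with something like: pytest tests/models/unets/test_models_unet_2d.py -k test_mid_block_none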
