From e82743e44ed9a0a02e2c18bb34685bb4abfd72d8 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 6 Jan 2025 10:11:57 +0000 Subject: [PATCH] UNet2DModel mid_block_type --- src/diffusers/models/unets/unet_2d.py | 35 +++++++++++++---------- tests/models/unets/test_models_unet_2d.py | 29 +++++++++++++++++++ 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index bec62ce5cf45..090357237f46 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block types. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): - Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`. up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): @@ -103,6 +103,7 @@ def __init__( freq_shift: int = 0, flip_sin_to_cos: bool = True, down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + mid_block_type: Optional[str] = "UNetMidBlock2D", up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), layers_per_block: int = 2, @@ -194,19 +195,22 @@ def __init__( self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], - resnet_groups=norm_num_groups, - attn_groups=attn_norm_num_groups, - add_attention=add_attention, - ) + if mid_block_type is None: + self.mid_block = None + else: + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=attn_norm_num_groups, + add_attention=add_attention, + ) # up reversed_block_out_channels = list(reversed(block_out_channels)) @@ -322,7 +326,8 @@ def forward( down_block_res_samples += res_samples # 4. mid - sample = self.mid_block(sample, emb) + if self.mid_block is not None: + sample = self.mid_block(sample, emb) # 5. up skip_sample = None diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index ddf5f53511f7..a39b36ee20cc 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -105,6 +105,35 @@ def test_mid_block_attn_groups(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_mid_block_none(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common() + mid_none_init_dict["mid_block_type"] = None + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + mid_none_model = self.model_class(**mid_none_init_dict) + mid_none_model.to(torch_device) + mid_none_model.eval() + + self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.") + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + with torch.no_grad(): + mid_none_output = mid_none_model(**mid_none_inputs_dict) + + if isinstance(mid_none_output, dict): + mid_none_output = mid_none_output.to_tuple()[0] + + self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.") + def test_gradient_checkpointing_is_applied(self): expected_set = { "AttnUpBlock2D",