From 36e65f0097c778d9a2716be8a79ed88159fdebfb Mon Sep 17 00:00:00 2001
From: Qing
Date: Sat, 23 Nov 2024 18:58:34 +0800
Subject: [PATCH] fix deprecation warning

---
 .../model/brushnet/brushnet_unet_forward.py   | 163 ++++++++++++------
 iopaint/model/power_paint/v2/BrushNet_CA.py   |   2 +-
 .../model/power_paint/v2/unet_2d_condition.py |   2 +-
 3 files changed, 113 insertions(+), 54 deletions(-)

diff --git a/iopaint/model/brushnet/brushnet_unet_forward.py b/iopaint/model/brushnet/brushnet_unet_forward.py
index 04e8f0a4a..0372f0c97 100644
--- a/iopaint/model/brushnet/brushnet_unet_forward.py
+++ b/iopaint/model/brushnet/brushnet_unet_forward.py
@@ -1,28 +1,33 @@
 from typing import Union, Optional, Dict, Any, Tuple
 
 import torch
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
-from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    unscale_lora_layers,
+    deprecate,
+    scale_lora_layers,
+)
 
 
 def brushnet_unet_forward(
-        self,
-        sample: torch.FloatTensor,
-        timestep: Union[torch.Tensor, float, int],
-        encoder_hidden_states: torch.Tensor,
-        class_labels: Optional[torch.Tensor] = None,
-        timestep_cond: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
-        mid_block_additional_residual: Optional[torch.Tensor] = None,
-        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
-        encoder_attention_mask: Optional[torch.Tensor] = None,
-        return_dict: bool = True,
-        down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
-        mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
-        up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+    self,
+    sample: torch.FloatTensor,
+    timestep: Union[torch.Tensor, float, int],
+    encoder_hidden_states: torch.Tensor,
+    class_labels: Optional[torch.Tensor] = None,
+    timestep_cond: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+    mid_block_additional_residual: Optional[torch.Tensor] = None,
+    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    return_dict: bool = True,
+    down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+    mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
+    up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
 ) -> Union[UNet2DConditionOutput, Tuple]:
     r"""
     The [`UNet2DConditionModel`] forward method.
@@ -82,7 +87,7 @@ def brushnet_unet_forward(
     # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
     # However, the upsampling interpolation output size can be forced to fit any upsampling size
     # on the fly if necessary.
-    default_overall_up_factor = 2 ** self.num_upsamplers
+    default_overall_up_factor = 2**self.num_upsamplers
 
     # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
     forward_upsample_size = False
@@ -112,7 +117,9 @@ def brushnet_unet_forward(
 
     # convert encoder_attention_mask to a bias the same way we do for attention_mask
     if encoder_attention_mask is not None:
-        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+        encoder_attention_mask = (
+            1 - encoder_attention_mask.to(sample.dtype)
+        ) * -10000.0
         encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
 
     # 0. center input if necessary
@@ -132,7 +139,9 @@ def brushnet_unet_forward(
             emb = emb + class_emb
 
     aug_emb = self.get_aug_embed(
-        emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        emb=emb,
+        encoder_hidden_states=encoder_hidden_states,
+        added_cond_kwargs=added_cond_kwargs,
     )
     if self.config.addition_embed_type == "image_hint":
         aug_emb, hint = aug_emb
@@ -151,25 +160,43 @@ def brushnet_unet_forward(
     sample = self.conv_in(sample)
 
     # 2.5 GLIGEN position net
-    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+    if (
+        cross_attention_kwargs is not None
+        and cross_attention_kwargs.get("gligen", None) is not None
+    ):
         cross_attention_kwargs = cross_attention_kwargs.copy()
         gligen_args = cross_attention_kwargs.pop("gligen")
         cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
 
     # 3. down
-    lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+    lora_scale = (
+        cross_attention_kwargs.get("scale", 1.0)
+        if cross_attention_kwargs is not None
+        else 1.0
+    )
     if USE_PEFT_BACKEND:
         # weight the lora layers by setting `lora_scale` for each PEFT layer
         scale_lora_layers(self, lora_scale)
 
-    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+    is_controlnet = (
+        mid_block_additional_residual is not None
+        and down_block_additional_residuals is not None
+    )
     # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
     is_adapter = down_intrablock_additional_residuals is not None
     # maintain backward compatibility for legacy usage, where
     #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
     #       but can only use one or the other
-    is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
-    if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+    is_brushnet = (
+        down_block_add_samples is not None
+        and mid_block_add_sample is not None
+        and up_block_add_samples is not None
+    )
+    if (
+        not is_adapter
+        and mid_block_additional_residual is None
+        and down_block_additional_residuals is not None
+    ):
         deprecate(
             "T2I should not use down_block_additional_residuals",
             "1.3.0",
@@ -187,16 +214,25 @@ def brushnet_unet_forward(
         sample = sample + down_block_add_samples.pop(0)
 
     for downsample_block in self.down_blocks:
-        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+        if (
+            hasattr(downsample_block, "has_cross_attention")
+            and downsample_block.has_cross_attention
+        ):
             # For t2i-adapter CrossAttnDownBlock2D
             additional_residuals = {}
             if is_adapter and len(down_intrablock_additional_residuals) > 0:
-                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+                additional_residuals["additional_residuals"] = (
+                    down_intrablock_additional_residuals.pop(0)
+                )
 
             if is_brushnet and len(down_block_add_samples) > 0:
-                additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
-                                                                  for _ in range(
-                        len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
+                additional_residuals["down_block_add_samples"] = [
+                    down_block_add_samples.pop(0)
+                    for _ in range(
+                        len(downsample_block.resnets)
+                        + (downsample_block.downsamplers != None)
+                    )
+                ]
 
             sample, res_samples = downsample_block(
                 hidden_states=sample,
@@ -210,12 +246,17 @@ def brushnet_unet_forward(
         else:
             additional_residuals = {}
             if is_brushnet and len(down_block_add_samples) > 0:
-                additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
-                                                                  for _ in range(
-                        len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
+                additional_residuals["down_block_add_samples"] = [
+                    down_block_add_samples.pop(0)
+                    for _ in range(
+                        len(downsample_block.resnets)
+                        + (downsample_block.downsamplers != None)
+                    )
+                ]
 
-            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale,
-                                                   **additional_residuals)
+            sample, res_samples = downsample_block(
+                hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals
+            )
 
             if is_adapter and len(down_intrablock_additional_residuals) > 0:
                 sample += down_intrablock_additional_residuals.pop(0)
@@ -225,16 +266,23 @@ def brushnet_unet_forward(
         new_down_block_res_samples = ()
 
         for down_block_res_sample, down_block_additional_residual in zip(
-                down_block_res_samples, down_block_additional_residuals
+            down_block_res_samples, down_block_additional_residuals
         ):
-            down_block_res_sample = down_block_res_sample + down_block_additional_residual
-            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+            down_block_res_sample = (
+                down_block_res_sample + down_block_additional_residual
+            )
+            new_down_block_res_samples = new_down_block_res_samples + (
+                down_block_res_sample,
+            )
 
         down_block_res_samples = new_down_block_res_samples
 
     # 4. mid
     if self.mid_block is not None:
-        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+        if (
+            hasattr(self.mid_block, "has_cross_attention")
+            and self.mid_block.has_cross_attention
+        ):
             sample = self.mid_block(
                 sample,
                 emb,
@@ -248,9 +296,9 @@ def brushnet_unet_forward(
 
         # To support T2I-Adapter-XL
         if (
-                is_adapter
-                and len(down_intrablock_additional_residuals) > 0
-                and sample.shape == down_intrablock_additional_residuals[0].shape
+            is_adapter
+            and len(down_intrablock_additional_residuals) > 0
+            and sample.shape == down_intrablock_additional_residuals[0].shape
         ):
             sample += down_intrablock_additional_residuals.pop(0)
 
@@ -264,7 +312,7 @@ def brushnet_unet_forward(
     for i, upsample_block in enumerate(self.up_blocks):
         is_final_block = i == len(self.up_blocks) - 1
 
-        res_samples = down_block_res_samples[-len(upsample_block.resnets):]
+        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
         down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
 
         # if we have not reached the final block and need to forward the
@@ -272,12 +320,19 @@ def brushnet_unet_forward(
         if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]
 
-        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+        if (
+            hasattr(upsample_block, "has_cross_attention")
+            and upsample_block.has_cross_attention
+        ):
             additional_residuals = {}
             if is_brushnet and len(up_block_add_samples) > 0:
-                additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
-                                                                for _ in range(
-                        len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
+                additional_residuals["up_block_add_samples"] = [
+                    up_block_add_samples.pop(0)
+                    for _ in range(
+                        len(upsample_block.resnets)
+                        + (upsample_block.upsamplers != None)
+                    )
+                ]
 
             sample = upsample_block(
                 hidden_states=sample,
@@ -293,9 +348,13 @@ def brushnet_unet_forward(
         else:
             additional_residuals = {}
             if is_brushnet and len(up_block_add_samples) > 0:
-                additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
-                                                                for _ in range(
-                        len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
+                additional_residuals["up_block_add_samples"] = [
+                    up_block_add_samples.pop(0)
+                    for _ in range(
+                        len(upsample_block.resnets)
+                        + (upsample_block.upsamplers != None)
+                    )
+                ]
 
             sample = upsample_block(
                 hidden_states=sample,
diff --git a/iopaint/model/power_paint/v2/BrushNet_CA.py b/iopaint/model/power_paint/v2/BrushNet_CA.py
index b892c846c..e69ee1481 100644
--- a/iopaint/model/power_paint/v2/BrushNet_CA.py
+++ b/iopaint/model/power_paint/v2/BrushNet_CA.py
@@ -3,7 +3,7 @@
 import torch
 
 from diffusers import UNet2DConditionModel
-from diffusers.models.unet_2d_blocks import (
+from diffusers.models.unets.unet_2d_blocks import (
     get_down_block,
     get_mid_block,
     get_up_block,
diff --git a/iopaint/model/power_paint/v2/unet_2d_condition.py b/iopaint/model/power_paint/v2/unet_2d_condition.py
index 80741de96..7a3165eed 100644
--- a/iopaint/model/power_paint/v2/unet_2d_condition.py
+++ b/iopaint/model/power_paint/v2/unet_2d_condition.py
@@ -15,7 +15,7 @@
 import torch
 import torch.utils.checkpoint
 
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
 from diffusers.utils import (
     USE_PEFT_BACKEND,
     deprecate,
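
Note: the warnings this patch addresses come from diffusers relocating its UNet modules into the diffusers.models.unets subpackage; the old paths were kept for a while only as deprecated aliases. Pinning the new paths, as the patch does, is the simplest fix when the project already requires a recent diffusers release. If the same files ever needed to import cleanly across both layouts, a guarded fallback is one option. The snippet below is a minimal sketch of that idea, not something this patch does, and the version boundary named in the comment is an assumption:

    # Sketch only (not part of the patch): tolerate both diffusers layouts.
    # Assumes the models/unets reorganization landed around diffusers 0.25.
    try:
        # newer diffusers: UNet modules live under diffusers.models.unets
        from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
    except ImportError:
        # older diffusers: the module sat directly under diffusers.models
        from diffusers.models.unet_2d_condition import UNet2DConditionOutput

The same try/except shape would apply to the unet_2d_blocks imports touched in BrushNet_CA.py.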