From fffc32e159506e09a20fc1fa1b1e4b6932fdc3f7 Mon Sep 17 00:00:00 2001
From: algodd
Date: Wed, 4 Sep 2024 14:34:47 +0000
Subject: [PATCH] Adds IPAdapter support by switching attention to IPAdapter attention

---
 tgate/PixArt_Alpha.py   |   4 +-
 tgate/PixArt_Sigma.py   |   4 +-
 tgate/SD.py             |   4 +-
 tgate/SDXL.py           |   4 +-
 tgate/SDXL_DeepCache.py |   4 +-
 tgate/SD_DeepCache.py   |   4 +-
 tgate/SVD.py            |   2 +-
 tgate/tgate_utils.py    | 158 +++++++++++++++++++++++++++++++++++++++-
 8 files changed, 168 insertions(+), 16 deletions(-)

diff --git a/tgate/PixArt_Alpha.py b/tgate/PixArt_Alpha.py
index 6dafe87..e15380b 100644
--- a/tgate/PixArt_Alpha.py
+++ b/tgate/PixArt_Alpha.py
@@ -254,7 +254,7 @@ def tgate(
     # 7. Denoising loop
     num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
     register_forward(self.transformer,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
@@ -304,7 +304,7 @@ def tgate(
     )
     keep_shape = keep_shape if not lcm else lcm
     register_forward(self.transformer,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward=ca_kwards,
         sa_kward=sa_kwards,
         keep_shape=keep_shape
diff --git a/tgate/PixArt_Sigma.py b/tgate/PixArt_Sigma.py
index aed28f4..d040f18 100644
--- a/tgate/PixArt_Sigma.py
+++ b/tgate/PixArt_Sigma.py
@@ -242,7 +242,7 @@ def tgate(
     # 7. Denoising loop
     num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
     register_forward(self.transformer,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
@@ -292,7 +292,7 @@ def tgate(
     )
     keep_shape = keep_shape if not lcm else lcm
     register_forward(self.transformer,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward=ca_kwards,
         sa_kward=sa_kwards,
         keep_shape=keep_shape
diff --git a/tgate/SD.py b/tgate/SD.py
index 53472f9..4ec9d44 100644
--- a/tgate/SD.py
+++ b/tgate/SD.py
@@ -276,7 +276,7 @@ def tgate(
     self._num_timesteps = len(timesteps)
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
@@ -311,7 +311,7 @@ def tgate(
         warm_up=warm_up
     )
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward=ca_kwards,
         sa_kward=sa_kwards,
         keep_shape=keep_shape
diff --git a/tgate/SDXL.py b/tgate/SDXL.py
index 32b6c05..33b5953 100644
--- a/tgate/SDXL.py
+++ b/tgate/SDXL.py
@@ -417,7 +417,7 @@ def tgate(
     self._num_timesteps = len(timesteps)
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
@@ -460,7 +460,7 @@ def tgate(
     )
     keep_shape = keep_shape if not lcm else lcm
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward=ca_kwards,
         sa_kward=sa_kwards,
         keep_shape=keep_shape
diff --git a/tgate/SDXL_DeepCache.py b/tgate/SDXL_DeepCache.py
index f519f09..c2074e9 100644
--- a/tgate/SDXL_DeepCache.py
+++ b/tgate/SDXL_DeepCache.py
@@ -416,7 +416,7 @@ def tgate(
     self._num_timesteps = len(timesteps)
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
@@ -466,7 +466,7 @@ def tgate(
             'reuse': False,
         }
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward=ca_kwards,
         sa_kward=sa_kwards,
         keep_shape=keep_shape
diff --git a/tgate/SD_DeepCache.py b/tgate/SD_DeepCache.py
index ea54db0..f10149f 100644
--- a/tgate/SD_DeepCache.py
+++ b/tgate/SD_DeepCache.py
@@ -276,7 +276,7 @@ def tgate(
     self._num_timesteps = len(timesteps)
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
@@ -315,7 +315,7 @@ def tgate(
         warm_up=warm_up
     )
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward=ca_kwards,
         sa_kward=sa_kwards,
         keep_shape=keep_shape
diff --git a/tgate/SVD.py b/tgate/SVD.py
index 3850910..bb05285 100644
--- a/tgate/SVD.py
+++ b/tgate/SVD.py
@@ -244,7 +244,7 @@ def tgate(
     num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
     register_forward(self.unet,
-        'Attention',
+        'IPAdapterAttnProcessor2_0',
         ca_kward = {
             'cache': False,
             'reuse': False,
diff --git a/tgate/tgate_utils.py b/tgate/tgate_utils.py
index 516d93c..94bf57c 100644
--- a/tgate/tgate_utils.py
+++ b/tgate/tgate_utils.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn.functional as F
+from diffusers.image_processor import IPAdapterMaskProcessor
 from diffusers.utils import (
     BACKENDS_MAPPING,
     deprecate,
@@ -79,7 +80,12 @@ def forward(
     if not hasattr(self,'cache'):
         self.cache = None
     attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-    unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+    quiet_attn_parameters = {"ip_adapter_masks"}
+    unused_kwargs = [
+        k
+        for k, _ in cross_attention_kwargs.items()
+        if k not in attn_parameters and k not in quiet_attn_parameters
+    ]
     if len(unused_kwargs) > 0:
         logger.warning(
@@ -136,12 +142,13 @@ def tgate_processor(
     sa_cache = False,
     ca_reuse = False,
     sa_reuse = False,
+    ip_adapter_masks: Optional[torch.Tensor] = None,
     *args,
     **kwargs,
 ) -> torch.FloatTensor:
     r"""
-    A customized forward function of the `AttnProcessor2_0` class.
+    A customized forward function of the `IPAdapterAttnProcessor2_0` class.
     Args:
         hidden_states (`torch.Tensor`):
             The hidden states of the query.
         encoder_hidden_states (`torch.Tensor`, *optional*):
             The hidden states of the encoder.
         attention_mask (`torch.Tensor`, *optional*):
             The attention mask to use. If `None`, no mask is applied.
+        ip_adapter_masks (`torch.Tensor`, *optional*):
+            The IP adapter masks to use.
         **cross_attention_kwargs:
             Additional keyword arguments to pass along to the cross attention.
     Returns:
         `torch.Tensor`: The output of the attention layer.
     """
     if not hasattr(F, "scaled_dot_product_attention"):
-        raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        raise ImportError("IPAdapterAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
     if len(args) > 0 or kwargs.get("scale", None) is not None:
         deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
         deprecate("scale", "1.0.0", deprecation_message)
     residual = hidden_states
+    # separate ip_hidden_states from encoder_hidden_states
+    if encoder_hidden_states is not None:
+        if isinstance(encoder_hidden_states, tuple):
+            encoder_hidden_states, ip_hidden_states = encoder_hidden_states
+        else:
+            deprecation_message = (
+                "You have passed a tensor as `encoder_hidden_states`. This is"
+                " deprecated and will be removed in a future release. Please make sure"
+                " to update your script to pass `encoder_hidden_states` as a tuple to"
+                " suppress this warning."
+            )
+            deprecate(
+                "encoder_hidden_states not a tuple",
+                "1.0.0",
+                deprecation_message,
+                standard_warn=False,
+            )
+            end_pos = encoder_hidden_states.shape[1] - attn.num_tokens[0]
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                [encoder_hidden_states[:, end_pos:, :]],
+            )
     cross_attn = encoder_hidden_states is not None
     self_attn = encoder_hidden_states is None
@@ -225,6 +256,127 @@ def tgate_processor(
     hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
     hidden_states = hidden_states.to(query.dtype)
+    if ip_adapter_masks is not None:
+        if not isinstance(ip_adapter_masks, List):
+            # for backward compatibility, we accept `ip_adapter_mask` as a tensor
+            # of shape [num_ip_adapter, 1, height, width]
+            ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
+        if not (len(ip_adapter_masks) == len(attn.scale) == len(ip_hidden_states)):
+            raise ValueError(
+                f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must"
+                f" match length of self.scale array ({len(attn.scale)}) and number"
+                f" of ip_hidden_states ({len(ip_hidden_states)})"
+            )
+        else:
+            for index, (mask, _scale, ip_state) in enumerate(
+                zip(ip_adapter_masks, attn.scale, ip_hidden_states)
+            ):
+                if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
+                    raise ValueError(
+                        "Each element of the ip_adapter_masks array should be a"
+                        " tensor with shape [1, num_images_for_ip_adapter, height,"
+                        " width]. Please use `IPAdapterMaskProcessor` to preprocess"
+                        " your mask"
+                    )
+                if mask.shape[1] != ip_state.shape[1]:
+                    raise ValueError(
+                        f"Number of masks ({mask.shape[1]}) does not match number"
+                        f" of ip images ({ip_state.shape[1]}) at index {index}"
+                    )
+                if isinstance(_scale, list) and len(_scale) != mask.shape[1]:
+                    raise ValueError(
+                        f"Number of masks ({mask.shape[1]}) does not match "
+                        f"number of scales ({len(_scale)}) at index {index}"
+                    )
+    else:
+        ip_adapter_masks = [None] * len(attn.scale)
+
+    # for ip-adapter
+    for current_ip_hidden_states, _scale, to_k_ip, to_v_ip, mask in zip(
+        ip_hidden_states, attn.scale, attn.to_k_ip, attn.to_v_ip, ip_adapter_masks
+    ):
+        skip = False
+        if isinstance(_scale, list):
+            if all(s == 0 for s in _scale):
+                skip = True
+        elif _scale == 0:
+            skip = True
+        if not skip:
+            if mask is not None:
+                if not isinstance(_scale, list):
+                    _scale = [_scale] * mask.shape[1]
+
+                current_num_images = mask.shape[1]
+                for i in range(current_num_images):
+                    ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
+                    ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
+                    ip_key = ip_key.view(
+                        batch_size, -1, attn.heads, head_dim
+                    ).transpose(1, 2)
+                    ip_value = ip_value.view(
+                        batch_size, -1, attn.heads, head_dim
+                    ).transpose(1, 2)
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    _current_ip_hidden_states = F.scaled_dot_product_attention(
+                        query,
+                        ip_key,
+                        ip_value,
+                        attn_mask=None,
+                        dropout_p=0.0,
+                        is_causal=False,
+                    )
+                    _current_ip_hidden_states = _current_ip_hidden_states.transpose(
+                        1, 2
+                    ).reshape(batch_size, -1, attn.heads * head_dim)
+                    _current_ip_hidden_states = _current_ip_hidden_states.to(
+                        query.dtype
+                    )
+                    mask_downsample = IPAdapterMaskProcessor.downsample(
+                        mask[:, i, :, :],
+                        batch_size,
+                        _current_ip_hidden_states.shape[1],
+                        _current_ip_hidden_states.shape[2],
+                    )
+                    mask_downsample = mask_downsample.to(
+                        dtype=query.dtype, device=query.device
+                    )
+                    hidden_states = hidden_states + _scale[i] * (
+                        _current_ip_hidden_states * mask_downsample
+                    )
+            else:
+                ip_key = to_k_ip(current_ip_hidden_states)
+                ip_value = to_v_ip(current_ip_hidden_states)
+                ip_key = ip_key.view(
+                    batch_size, -1, attn.heads, head_dim
+                ).transpose(1, 2)
+                ip_value = ip_value.view(
+                    batch_size, -1, attn.heads, head_dim
+                ).transpose(1, 2)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                current_ip_hidden_states = F.scaled_dot_product_attention(
+                    query,
+                    ip_key,
+                    ip_value,
+                    attn_mask=None,
+                    dropout_p=0.0,
+                    is_causal=False,
+                )
+                current_ip_hidden_states = current_ip_hidden_states.transpose(
+                    1, 2
+                ).reshape(batch_size, -1, attn.heads * head_dim)
+                current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
+
+                hidden_states = hidden_states + _scale * current_ip_hidden_states
     # linear proj
     hidden_states = attn.to_out[0](hidden_states)
     # dropout
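
Usage note (not part of the patch): a minimal sketch of how this change might be exercised end to end. It assumes the TgateSDLoader wrapper and .tgate() entry point described in the T-GATE README, the standard diffusers IP-Adapter loading API, and that .tgate() forwards ip_adapter_image like the regular pipeline call; the model IDs, weight name, gate_step value, and reference image path are placeholders.

import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image
from tgate import TgateSDLoader  # assumed import path, per the T-GATE README

# Build a Stable Diffusion pipeline and attach an IP-Adapter so the UNet's
# attention modules carry IPAdapterAttnProcessor2_0 -- the processor name this
# patch now registers the tgate forward hook against.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder model ID
).to("cuda")
pipe.load_ip_adapter(
    "h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin"  # placeholder weights
)
pipe.set_ip_adapter_scale(0.6)

# Wrap the pipeline with T-GATE and run the gated sampler.
pipe = TgateSDLoader(pipe, gate_step=10, num_inference_steps=25)  # assumed wrapper signature
image = pipe.tgate(
    "a corgi wearing sunglasses, studio lighting",
    ip_adapter_image=load_image("style_reference.png"),  # placeholder reference image
    gate_step=10,
    num_inference_steps=25,
).images[0]
image.save("tgate_ip_adapter.png")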