IPAdapters support for Tgate #20

Open · wants to merge 2 commits into base: main
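For reference, a minimal sketch of how the patched pipelines might be driven once IP-Adapter support lands. The `TgateSDXLLoader` import, the `pipe.tgate(...)` call signature, and the checkpoint names follow the public T-GATE and diffusers examples; treat them as assumptions rather than something this PR specifies.

```python
# Hedged usage sketch: a T-GATE pipeline with an IP-Adapter attached (assumed API).
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import load_image
from tgate import TgateSDXLLoader  # assumed import path, per the T-GATE README

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Standard diffusers IP-Adapter loading; after this PR, the resulting
# IPAdapterAttnProcessor2_0 processors are the ones T-GATE re-registers.
pipe.load_ip_adapter(
    "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin"
)
ip_image = load_image("https://example.com/reference.png")  # placeholder URL

pipe = TgateSDXLLoader(pipe)
image = pipe.tgate(
    "a photo of a cat wearing sunglasses",
    ip_adapter_image=ip_image,
    gate_step=10,
    num_inference_steps=25,
).images[0]
image.save("tgate_ip_adapter.png")
```

The ordering presumably matters only insofar as `load_ip_adapter` must have installed the `IPAdapterAttnProcessor2_0` processors before `tgate()` runs, since that is the class name `register_forward` now matches.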
4 changes: 2 additions & 2 deletions tgate/PixArt_Alpha.py
@@ -254,7 +254,7 @@ def tgate(
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
register_forward(self.transformer,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
@@ -304,7 +304,7 @@ def tgate(
)
keep_shape = keep_shape if not lcm else lcm
register_forward(self.transformer,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward=ca_kwards,
sa_kward=sa_kwards,
keep_shape=keep_shape
4 changes: 2 additions & 2 deletions tgate/PixArt_Sigma.py
@@ -242,7 +242,7 @@ def tgate(
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
register_forward(self.transformer,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
@@ -292,7 +292,7 @@ def tgate(
)
keep_shape = keep_shape if not lcm else lcm
register_forward(self.transformer,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward=ca_kwards,
sa_kward=sa_kwards,
keep_shape=keep_shape
4 changes: 2 additions & 2 deletions tgate/SD.py
@@ -276,7 +276,7 @@ def tgate(
self._num_timesteps = len(timesteps)

register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
@@ -311,7 +311,7 @@ def tgate(
warm_up=warm_up
)
register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward=ca_kwards,
sa_kward=sa_kwards,
keep_shape=keep_shape
4 changes: 2 additions & 2 deletions tgate/SDXL.py
@@ -417,7 +417,7 @@ def tgate(
self._num_timesteps = len(timesteps)

register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
@@ -460,7 +460,7 @@ def tgate(
)
keep_shape = keep_shape if not lcm else lcm
register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward=ca_kwards,
sa_kward=sa_kwards,
keep_shape=keep_shape
4 changes: 2 additions & 2 deletions tgate/SDXL_DeepCache.py
@@ -416,7 +416,7 @@ def tgate(
self._num_timesteps = len(timesteps)

register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
@@ -466,7 +466,7 @@ def tgate(
'reuse': False,
}
register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward=ca_kwards,
sa_kward=sa_kwards,
keep_shape=keep_shape
4 changes: 2 additions & 2 deletions tgate/SD_DeepCache.py
@@ -276,7 +276,7 @@ def tgate(
self._num_timesteps = len(timesteps)

register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
@@ -315,7 +315,7 @@ def tgate(
warm_up=warm_up
)
register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward=ca_kwards,
sa_kward=sa_kwards,
keep_shape=keep_shape
2 changes: 1 addition & 1 deletion tgate/SVD.py
@@ -244,7 +244,7 @@ def tgate(
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

register_forward(self.unet,
- 'Attention',
+ 'IPAdapterAttnProcessor2_0',
ca_kward = {
'cache': False,
'reuse': False,
158 changes: 155 additions & 3 deletions tgate/tgate_utils.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import (
BACKENDS_MAPPING,
deprecate,
@@ -79,7 +80,12 @@ def forward(
if not hasattr(self,'cache'):
self.cache = None
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
- unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+ quiet_attn_parameters = {"ip_adapter_masks"}
+ unused_kwargs = [
+     k
+     for k, _ in cross_attention_kwargs.items()
+     if k not in attn_parameters and k not in quiet_attn_parameters
+ ]

if len(unused_kwargs) > 0:
logger.warning(
@@ -136,12 +142,13 @@ def tgate_processor(
sa_cache = False,
ca_reuse = False,
sa_reuse = False,
ip_adapter_masks: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:

r"""
- A customized forward function of the `AttnProcessor2_0` class.
+ A customized forward function of the `IPAdapterAttnProcessor2_0` class.

Args:
hidden_states (`torch.Tensor`):
@@ -150,6 +157,8 @@ def tgate_processor(
The hidden states of the encoder.
attention_mask (`torch.Tensor`, *optional*):
The attention mask to use. If `None`, no mask is applied.
ip_adapter_masks (`torch.Tensor`, *optional*):
The IP adapter masks to use.
**cross_attention_kwargs:
Additional keyword arguments to pass along to the cross attention.

@@ -158,14 +167,36 @@
"""

if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
raise ImportError("IPAdapterAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

if len(args) > 0 or kwargs.get("scale", None) is not None:
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
deprecate("scale", "1.0.0", deprecation_message)

residual = hidden_states

# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`. This is"
" deprecated and will be removed in a future release. Please make sure"
" to update your script to pass `encoder_hidden_states` as a tuple to"
" suppress this warning."
)
deprecate(
"encoder_hidden_states not a tuple",
"1.0.0",
deprecation_message,
standard_warn=False,
)
end_pos = encoder_hidden_states.shape[1] - attn.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)
cross_attn = encoder_hidden_states is not None
self_attn = encoder_hidden_states is None

@@ -225,6 +256,127 @@ def tgate_processor(
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor
# of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(attn.scale) == len(ip_hidden_states)):
raise ValueError(
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must"
f" match length of self.scale array ({len(attn.scale)}) and number"
f" of ip_hidden_states ({len(ip_hidden_states)})"
)
else:
for index, (mask, _scale, ip_state) in enumerate(
zip(ip_adapter_masks, attn.scale, ip_hidden_states)
):
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a"
" tensor with shape [1, num_images_for_ip_adapter, height,"
" width]. Please use `IPAdapterMaskProcessor` to preprocess"
" your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match number"
f" of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(_scale, list) and len(_scale) != mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(_scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(attn.scale)

# for ip-adapter
for current_ip_hidden_states, _scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, attn.scale, attn.to_k_ip, attn.to_v_ip, ip_adapter_masks
):
skip = False
if isinstance(_scale, list):
if all(s == 0 for s in _scale):
skip = True
elif _scale == 0:
skip = True
if not skip:
if mask is not None:
if not isinstance(_scale, list):
_scale = [_scale] * mask.shape[1]

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = ip_key.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
ip_value = ip_value.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
_current_ip_hidden_states = F.scaled_dot_product_attention(
query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
)

_current_ip_hidden_states = _current_ip_hidden_states.transpose(
1, 2
).reshape(batch_size, -1, attn.heads * head_dim)
_current_ip_hidden_states = _current_ip_hidden_states.to(
query.dtype
)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(
dtype=query.dtype, device=query.device
)
hidden_states = hidden_states + _scale[i] * (
_current_ip_hidden_states * mask_downsample
)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = ip_key.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
ip_value = ip_value.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
current_ip_hidden_states = F.scaled_dot_product_attention(
query,
ip_key,
ip_value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
)

current_ip_hidden_states = current_ip_hidden_states.transpose(
1, 2
).reshape(batch_size, -1, attn.heads * head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)

hidden_states = hidden_states + _scale * current_ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
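The `ip_adapter_masks` handling added to `tgate_processor` above mirrors diffusers' own `IPAdapterAttnProcessor2_0`, so masks would presumably still enter through `cross_attention_kwargs` after preprocessing with `IPAdapterMaskProcessor` (which is also why `"ip_adapter_masks"` is added to `quiet_attn_parameters` in the warning filter). A small sketch of that workflow, continuing from the example above and assuming `tgate()` forwards `cross_attention_kwargs` unchanged:

```python
# Hedged sketch of masked IP-Adapter attention with T-GATE; `pipe` and
# `ip_image` come from the previous sketch, and the mask URL is a placeholder.
from diffusers.image_processor import IPAdapterMaskProcessor
from diffusers.utils import load_image

mask_image = load_image("https://example.com/subject_mask.png")

# Preprocess the mask to the output resolution; the resulting tensor is
# accepted by the backward-compatibility branch of the new validation code.
mask_processor = IPAdapterMaskProcessor()
ip_adapter_masks = mask_processor.preprocess([mask_image], height=1024, width=1024)

image = pipe.tgate(
    "a photo of a cat wearing sunglasses",
    ip_adapter_image=ip_image,
    cross_attention_kwargs={"ip_adapter_masks": ip_adapter_masks},
    gate_step=10,
    num_inference_steps=25,
).images[0]
```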