[refactor embeddings] gligen + ip-adapter (#6244)
* refactor ip-adapter-imageproj, gligen

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
yiyixuxu authored Dec 28, 2023
1 parent 1ac07d8 commit 4c483de
Showing 6 changed files with 32 additions and 27 deletions.
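In short, the commit renames three internal embedding classes and turns the FourierEmbedder module into a free function. For quick reference, the old-to-new mapping collected from the hunks below (plain Python, for illustration only):

OLD_TO_NEW = {
    "MLPProjection": "IPAdapterFullImageProjection",
    "Resampler": "IPAdapterPlusImageProjection",
    "PositionNet": "GLIGENTextBoundingboxProjection",
    # FourierEmbedder (an nn.Module) is replaced by the free function
    # get_fourier_embeds_from_boundingbox(embed_dim, box)
}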
8 changes: 4 additions & 4 deletions src/diffusers/loaders/unet.py
@@ -24,7 +24,7 @@
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import nn

-from ..models.embeddings import ImageProjection, MLPProjection, Resampler
+from ..models.embeddings import ImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -712,7 +712,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
             cross_attention_dim = state_dict["proj.3.weight"].shape[0]

-            image_projection = MLPProjection(
+            image_projection = IPAdapterFullImageProjection(
                 cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
             )
@@ -730,7 +730,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict):
             hidden_dims = state_dict["latents"].shape[2]
             heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64

-            image_projection = Resampler(
+            image_projection = IPAdapterPlusImageProjection(
                 embed_dims=embed_dims,
                 output_dims=output_dims,
                 hidden_dims=hidden_dims,
@@ -780,7 +780,7 @@ def _load_ip_adapter_weights(self, state_dict):
             num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]

         # Set encoder_hid_proj after loading ip_adapter weights,
-        # because `Resampler` also has `attn_processors`.
+        # because `IPAdapterPlusImageProjection` also has `attn_processors`.
         self.encoder_hid_proj = None

         # set ip-adapter cross-attention processors & load state_dict
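For orientation, a minimal sketch of the dispatch this converter performs, assuming the state-dict keys visible above are the discriminators; the embed_dims/output_dims lines fall outside the hunk, so those key names are assumptions:

import torch
from diffusers.models.embeddings import IPAdapterFullImageProjection, IPAdapterPlusImageProjection

def pick_image_projection(state_dict):
    if "proj.0.weight" in state_dict:
        # IP-Adapter "full": a small MLP maps CLIP image embeddings to the cross-attention dim
        clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
        cross_attention_dim = state_dict["proj.3.weight"].shape[0]
        return IPAdapterFullImageProjection(
            cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
        )
    # IP-Adapter Plus: a perceiver-style resampler with learned latent queries
    embed_dims = state_dict["proj_in.weight"].shape[1]    # assumed key, not shown in the hunk
    output_dims = state_dict["proj_out.weight"].shape[0]  # assumed key, not shown in the hunk
    hidden_dims = state_dict["latents"].shape[2]
    heads = state_dict["layers.0.0.to_q.weight"].shape[0] // 64
    return IPAdapterPlusImageProjection(
        embed_dims=embed_dims, output_dims=output_dims, hidden_dims=hidden_dims, heads=heads
    )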
37 changes: 21 additions & 16 deletions src/diffusers/models/embeddings.py
@@ -462,7 +462,7 @@ def forward(self, image_embeds: torch.FloatTensor):
         return image_embeds


-class MLPProjection(nn.Module):
+class IPAdapterFullImageProjection(nn.Module):
     def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
         super().__init__()
         from .attention import FeedForward
@@ -621,29 +621,34 @@ def shape(x):
         return a[:, 0, :]  # cls_token


-class FourierEmbedder(nn.Module):
-    def __init__(self, num_freqs=64, temperature=100):
-        super().__init__()
+def get_fourier_embeds_from_boundingbox(embed_dim, box):
+    """
+    Args:
+        embed_dim: int
+        box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline
+    Returns:
+        [B x N x embed_dim] tensor of positional embeddings
+    """

-        self.num_freqs = num_freqs
-        self.temperature = temperature
+    batch_size, num_boxes = box.shape[:2]

-        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
-        freq_bands = freq_bands[None, None, None]
-        self.register_buffer("freq_bands", freq_bands, persistent=False)
+    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
+    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
+    emb = emb * box.unsqueeze(-1)

-    def __call__(self, x):
-        x = self.freq_bands * x.unsqueeze(-1)
-        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
+    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)
+
+    return emb


-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
         self.out_dim = out_dim

-        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+        self.fourier_embedder_dim = fourier_freqs
         self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy

         if isinstance(out_dim, tuple):
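A quick shape check of the new free function, as a usage sketch; note the returned width is embed_dim * 2 * 4 rather than the embed_dim the docstring above states:

import torch
from diffusers.models.embeddings import get_fourier_embeds_from_boundingbox

boxes = torch.rand(2, 3, 4)  # B=2 samples, N=3 boxes, xyxy coordinates in [0, 1]
emb = get_fourier_embeds_from_boundingbox(embed_dim=8, box=boxes)
print(emb.shape)  # torch.Size([2, 3, 64]) -- 8 frequencies * 2 (sin/cos) * 4 (xyxy)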
@@ -692,7 +697,7 @@ def forward(
         masks = masks.unsqueeze(-1)

         # embedding position (it may includes padding as placeholder)
-        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # B*N*4 -> B*N*C

         # learnable null embedding
         xyxy_null = self.null_position_feature.view(1, 1, -1)
@@ -787,7 +792,7 @@ def forward(self, caption):
         return hidden_states


-class Resampler(nn.Module):
+class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
     Args:
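A small usage sketch for the renamed resampler; the constructor defaults and the forward contract (one image-embedding tensor in, num_queries tokens out) are assumptions based on the arguments visible in this diff and in the test file below:

import torch
from diffusers.models.embeddings import IPAdapterPlusImageProjection

proj = IPAdapterPlusImageProjection(
    embed_dims=32, output_dims=32, hidden_dims=64, heads=2, dim_head=16, num_queries=4
)
image_embeds = torch.randn(1, 257, 32)  # e.g. per-patch CLIP hidden states (assumed shape)
tokens = proj(image_embeds)
print(tokens.shape)  # expected: torch.Size([1, 4, 32]) -- num_queries tokens of width output_dims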
4 changes: 2 additions & 2 deletions src/diffusers/models/unet_2d_condition.py
@@ -32,10 +32,10 @@
 )
 from .embeddings import (
     GaussianFourierProjection,
+    GLIGENTextBoundingboxProjection,
     ImageHintTimeEmbedding,
     ImageProjection,
     ImageTimeEmbedding,
-    PositionNet,
     TextImageProjection,
     TextImageTimeEmbedding,
     TextTimeEmbedding,
@@ -615,7 +615,7 @@ def __init__(
                 positive_len = cross_attention_dim[0]

             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
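For completeness, a hedged sketch of using the renamed projection directly; the forward argument order (boxes, masks, positive_embeddings) is an assumption inferred from the forward hunk in embeddings.py above:

import torch
from diffusers.models.embeddings import GLIGENTextBoundingboxProjection

position_net = GLIGENTextBoundingboxProjection(
    positive_len=768, out_dim=768, feature_type="text-only", fourier_freqs=8
)
boxes = torch.rand(1, 30, 4)                   # B x N x xyxy
masks = torch.ones(1, 30)                      # 1 = real box, 0 = padded slot
positive_embeddings = torch.randn(1, 30, 768)  # per-box text features (positive_len wide)
objs = position_net(boxes, masks, positive_embeddings)
print(objs.shape)  # expected: torch.Size([1, 30, 768])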
@@ -187,7 +187,7 @@ def __call__(self, x):
         return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


-class PositionNet(nn.Module):
+class GLIGENTextBoundingboxProjection(nn.Module):
     def __init__(self, positive_len, out_dim, feature_type, fourier_freqs=8):
         super().__init__()
         self.positive_len = positive_len
@@ -820,7 +820,7 @@ def __init__(
                 positive_len = cross_attention_dim[0]

             feature_type = "text-only" if attention_type == "gated" else "text-image"
-            self.position_net = PositionNet(
+            self.position_net = GLIGENTextBoundingboxProjection(
                 positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
             )
@@ -730,7 +730,7 @@ def __call__(
                 )
                 gligen_phrases = gligen_phrases[:max_objs]
                 gligen_boxes = gligen_boxes[:max_objs]
-            # prepare batched input to the PositionNet (boxes, phrases, mask)
+            # prepare batched input to the GLIGENTextBoundingboxProjection (boxes, phrases, mask)
             # Get tokens for phrases from pre-trained CLIPTokenizer
             tokenizer_inputs = self.tokenizer(gligen_phrases, padding=True, return_tensors="pt").to(device)
             # For the token, we use the same pre-trained text encoder
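Concretely, a rough sketch of the padding pattern that "batched input" refers to; only the truncation and tokenization lines are visible in the hunk, so the zero-padding scheme and the max_objs value are assumptions:

import torch

max_objs = 30  # assumed; the value is set outside the visible hunk
gligen_boxes = [[0.1, 0.2, 0.5, 0.6], [0.4, 0.4, 0.9, 0.9]]  # xyxy, normalized to [0, 1]
n_objs = min(len(gligen_boxes), max_objs)

# Pad boxes to a fixed number of slots; masked-out slots fall back to the
# learned null embedding inside GLIGENTextBoundingboxProjection.
boxes = torch.zeros(max_objs, 4)
boxes[:n_objs] = torch.tensor(gligen_boxes[:n_objs])
masks = torch.zeros(max_objs)
masks[:n_objs] = 1.0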
4 changes: 2 additions & 2 deletions tests/models/test_models_unet_2d_condition.py
@@ -26,7 +26,7 @@

 from diffusers import UNet2DConditionModel
 from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, IPAdapterAttnProcessor
-from diffusers.models.embeddings import ImageProjection, Resampler
+from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
@@ -133,7 +133,7 @@ def create_ip_adapter_plus_state_dict(model):

     # "image_proj" (ImageProjection layer weights)
     cross_attention_dim = model.config["cross_attention_dim"]
-    image_projection = Resampler(
+    image_projection = IPAdapterPlusImageProjection(
         embed_dims=cross_attention_dim, output_dims=cross_attention_dim, dim_head=32, heads=2, num_queries=4
     )
