Skip to content

Commit

Permalink
Fix IP Adapter Support for SAG Pipeline (huggingface#7260)
Browse files Browse the repository at this point in the history
* fix ip adapter support

* Update sag pipelines tests, adjust sag pipeline to pass tests

---------

Co-authored-by: YiYi Xu <[email protected]>
  • Loading branch information
stephen-iezzi and yiyixuxu authored Mar 30, 2024
1 parent f0c8156 commit ca61287
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,40 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state

return image_embeds, uncond_image_embeds

def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

image_embeds = []
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)

if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = single_image_embeds.to(device)

image_embeds.append(single_image_embeds)
else:
image_embeds = ip_adapter_image_embeds
return image_embeds

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
Expand Down Expand Up @@ -535,6 +569,7 @@ def __call__(
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
Expand Down Expand Up @@ -583,6 +618,9 @@ def __call__(
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Expand Down Expand Up @@ -636,13 +674,24 @@ def __call__(
# `sag_scale = 0` means no self-attention guidance
do_self_attention_guidance = sag_scale > 0.0

if ip_adapter_image is not None:
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
image_embeds, negative_image_embeds = self.encode_image(
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
do_classifier_free_guidance,
)

if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
image_embeds = []
negative_image_embeds = []
for tmp_image_embeds in ip_adapter_image_embeds:
single_negative_image_embeds, single_image_embeds = tmp_image_embeds.chunk(2)
image_embeds.append(single_image_embeds)
negative_image_embeds.append(single_negative_image_embeds)
else:
image_embeds = ip_adapter_image_embeds

# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
Expand Down Expand Up @@ -687,8 +736,18 @@ def __call__(
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
added_uncond_kwargs = {"image_embeds": negative_image_embeds} if ip_adapter_image is not None else None
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)

if do_classifier_free_guidance:
added_uncond_kwargs = (
{"image_embeds": negative_image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)

# 7. Denoising loop
store_processor = CrossAttnStoreProcessor()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,15 @@
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin


enable_full_determinism()


class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
class StableDiffusionSAGPipelineFastTests(
IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
):
pipeline_class = StableDiffusionSAGPipeline
params = TEXT_TO_IMAGE_PARAMS
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
Expand Down

0 comments on commit ca61287

Please sign in to comment.