Skip to content

Commit

Permalink
Merge branch 'main' into 4bit-lora-loading
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul authored Jan 15, 2025
2 parents c92758f + f9e957f commit a3f533b
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,21 @@
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from transformers import (
BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
Expand Down Expand Up @@ -162,7 +164,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
Expand Down Expand Up @@ -194,10 +196,14 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
tokenizer_3 (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
image_encoder (`PreTrainedModel`, *optional*):
Pre-trained Vision Model for IP Adapter.
feature_extractor (`BaseImageProcessor`, *optional*):
Image processor for IP Adapter.
"""

model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
_optional_components = []
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

def __init__(
Expand All @@ -211,6 +217,8 @@ def __init__(
tokenizer_2: CLIPTokenizer,
text_encoder_3: T5EncoderModel,
tokenizer_3: T5TokenizerFast,
image_encoder: PreTrainedModel = None,
feature_extractor: BaseImageProcessor = None,
):
super().__init__()

Expand All @@ -224,6 +232,8 @@ def __init__(
tokenizer_3=tokenizer_3,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
Expand Down Expand Up @@ -818,6 +828,10 @@ def clip_skip(self):
def do_classifier_free_guidance(self):
return self._guidance_scale > 1

@property
def joint_attention_kwargs(self):
    """Return the value currently stored in `self._joint_attention_kwargs`."""
    return self._joint_attention_kwargs

@property
def num_timesteps(self):
    """Return the value currently stored in `self._num_timesteps`."""
    return self._num_timesteps
Expand All @@ -826,6 +840,84 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
    """Encode an image into features with the pipeline's image encoder.

    Args:
        image (`PipelineImageInput`):
            Image to encode. Anything that is not already a `torch.Tensor` is first run through
            `self.feature_extractor` to obtain pixel values.
        device (`torch.device`):
            Device the pixel values are moved to before encoding.

    Returns:
        `torch.Tensor`: The second-to-last entry of the image encoder's `hidden_states`.
    """
    pixel_values = image
    if not isinstance(pixel_values, torch.Tensor):
        # Non-tensor inputs are converted to a "pt" tensor batch by the feature extractor.
        pixel_values = self.feature_extractor(pixel_values, return_tensors="pt").pixel_values

    # Match the pipeline's dtype and the requested device before the forward pass.
    pixel_values = pixel_values.to(device=device, dtype=self.dtype)

    encoder_output = self.image_encoder(pixel_values, output_hidden_states=True)
    # The penultimate hidden state is used as the image feature, not the final one.
    return encoder_output.hidden_states[-2]

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self,
    ip_adapter_image: Optional[PipelineImageInput] = None,
    ip_adapter_image_embeds: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
) -> torch.Tensor:
    """Prepares image embeddings for use in the IP-Adapter.

    Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. When both are given,
    `ip_adapter_image_embeds` takes precedence and `ip_adapter_image` is ignored.

    Args:
        ip_adapter_image (`PipelineImageInput`, *optional*):
            The input image to extract features from for IP-Adapter.
        ip_adapter_image_embeds (`torch.Tensor`, *optional*):
            Precomputed image embeddings. With classifier-free guidance enabled, the tensor is split in two
            along dim 0 and the first half is treated as the negative embedding.
        device: (`torch.device`, *optional*):
            Torch device. Falls back to `self._execution_device` when not given.
        num_images_per_prompt (`int`, defaults to 1):
            Number of images that should be generated per prompt.
        do_classifier_free_guidance (`bool`, defaults to True):
            Whether to use classifier free guidance or not.

    Returns:
        `torch.Tensor`: The image embeddings repeated `num_images_per_prompt` times along dim 0. With
        classifier-free guidance, the repeated negative embeddings are concatenated in front of them.

    Raises:
        ValueError: If neither `ip_adapter_image` nor `ip_adapter_image_embeds` is provided.
    """
    device = device or self._execution_device

    if ip_adapter_image_embeds is not None:
        if do_classifier_free_guidance:
            # Precomputed embeddings carry [negative, positive] stacked along dim 0.
            single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
        else:
            single_image_embeds = ip_adapter_image_embeds
    elif ip_adapter_image is not None:
        single_image_embeds = self.encode_image(ip_adapter_image, device)
        if do_classifier_free_guidance:
            # No negative image is supplied on this path, so an all-zero embedding stands in for it.
            single_negative_image_embeds = torch.zeros_like(single_image_embeds)
    else:
        # The previous message named `ip_adapter_image_embeds` twice, hiding which argument was missing.
        raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")

    image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)

    if do_classifier_free_guidance:
        negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
        # Negative embeddings go first.
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)

    return image_embeds.to(device=device)

# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
def enable_sequential_cpu_offload(self, *args, **kwargs):
    """Warn about a possible `image_encoder` offload failure, then defer to the parent implementation.

    The warning is logged only when an image encoder is set and `"image_encoder"` is not listed in
    `self._exclude_from_cpu_offload`. All positional and keyword arguments are forwarded unchanged.
    """
    offload_warning = (
        "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
        "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
        "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
    )

    if self.image_encoder is not None:
        # The exclusion list is only consulted when there is an image encoder to offload.
        if "image_encoder" not in self._exclude_from_cpu_offload:
            logger.warning(offload_warning)

    super().enable_sequential_cpu_offload(*args, **kwargs)

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand Down Expand Up @@ -853,8 +945,11 @@ def __call__(
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
Expand Down Expand Up @@ -890,9 +985,9 @@ def __call__(
mask_image_latent (`torch.Tensor`, `List[torch.Tensor]`):
`Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask
latents tensor will be generated by `mask_image`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
padding_mask_crop (`int`, *optional*, defaults to `None`):
The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
Expand Down Expand Up @@ -953,12 +1048,22 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
`True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
Expand Down Expand Up @@ -1006,6 +1111,7 @@ def __call__(

self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False

# 2. Define call parameters
Expand Down Expand Up @@ -1160,7 +1266,22 @@ def __call__(
f"The transformer {self.transformer.__class__} should have 16 input channels or 33 input channels, not {self.transformer.config.in_channels}."
)

# 7. Denoising loop
# 7. Prepare image embeddings
if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)

if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
else:
self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)

# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
Expand All @@ -1181,6 +1302,7 @@ def __call__(
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.8]

@property
def dummy_input(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.6, 0.6]

@property
def dummy_input(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def get_dummy_components(self):
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}

def get_dummy_inputs(self, device, seed=0):
Expand Down

0 comments on commit a3f533b

Please sign in to comment.