-### I2VGen-XL
-
-[I2VGen-XL](../api/pipelines/i2vgenxl) is a diffusion model that can generate higher resolution videos than SVD and it is also capable of accepting text prompts in addition to images. The model is trained with two hierarchical encoders (detail and global encoder) to better capture low and high-level details in images. These learned details are used to train a video diffusion model which refines the video resolution and details in the generated video.
+
+
-You can use I2VGen-XL by loading the [`I2VGenXLPipeline`], and passing a text and image prompt to generate a video.
+[LTX-Video (LTXV)](https://huggingface.co/Lightricks/LTX-Video) is a diffusion transformer (DiT) with a focus on speed. It generates 768x512 resolution videos at 24 frames per second (fps), enabling near real-time generation of high-quality videos. LTXV is relatively lightweight compared to other modern video generation models, making it possible to run on consumer GPUs.
```py
import torch
-from diffusers import I2VGenXLPipeline
-from diffusers.utils import export_to_gif, load_image
-
-pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
-pipeline.enable_model_cpu_offload()
-
-image_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/i2vgen_xl_images/img_0009.png"
-image = load_image(image_url).convert("RGB")
+from diffusers import LTXPipeline
+from diffusers.utils import export_to_video
-prompt = "Papers were floating in the air on a table in the library"
-negative_prompt = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"
-generator = torch.manual_seed(8888)
+pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")
-frames = pipeline(
+prompt = "A man walks towards a window, looks out, and then turns around. He has short, dark hair, dark skin, and is wearing a brown coat over a red and gray scarf. He walks from left to right towards a window, his gaze fixed on something outside. The camera follows him from behind at a medium distance. The room is brightly lit, with white walls and a large window covered by a white curtain. As he approaches the window, he turns his head slightly to the left, then back to the right. He then turns his entire body to the right, facing the window. The camera remains stationary as he stands in front of the window. The scene is captured in real-life footage."
+video = pipe(
prompt=prompt,
- image=image,
+ width=704,
+ height=480,
+ num_frames=161,
num_inference_steps=50,
- negative_prompt=negative_prompt,
- guidance_scale=9.0,
- generator=generator
).frames[0]
-export_to_gif(frames, "i2v.gif")
+export_to_video(video, "output.mp4", fps=24)
+```
+> [!TIP]
+> Mochi-1 is a 10B parameter model and requires a lot of memory. Refer to the Mochi [Quantization](../api/pipelines/mochi#quantization) guide to learn how to quantize the model. CogVideoX and LTX-Video are more lightweight options that can still generate high-quality videos.
+
+[Mochi-1](https://huggingface.co/genmo/mochi-1-preview) introduces the Asymmetric Diffusion Transformer (AsymmDiT) and Asymmetric Variational Autoencoder (AsymmVAE) to reduce memory requirements. AsymmVAE causally compresses videos 128x to improve memory efficiency, and AsymmDiT jointly attends to the compressed video tokens and user text tokens. This model is noted for generating videos with high-quality motion dynamics and strong prompt adherence.
+
+```py
+import torch
+from diffusers import MochiPipeline
+from diffusers.utils import export_to_video
+
+pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
+
+# reduce memory requirements
+pipe.enable_model_cpu_offload()
+pipe.enable_vae_tiling()
+
+prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
+video = pipe(prompt, num_frames=84).frames[0]
+export_to_video(video, "output.mp4", fps=30)
+```
+[StableVideoDiffusion (SVD)](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) is based on the Stable Diffusion 2.1 model and it is trained on images, then low-resolution videos, and finally a smaller dataset of high-resolution videos. This model generates a short 2-4 second video from an initial image.
+
+```py
+import torch
+from diffusers import StableVideoDiffusionPipeline
+from diffusers.utils import load_image, export_to_video
+
+pipeline = StableVideoDiffusionPipeline.from_pretrained(
+ "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
+)
+
+# reduce memory requirements
+pipeline.enable_model_cpu_offload()
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png")
+image = image.resize((1024, 576))
+
+generator = torch.manual_seed(42)
+frames = pipeline(image, decode_chunk_size=8, generator=generator).frames[0]
+export_to_video(frames, "generated.mp4", fps=7)
```
initial image
generated video
-### AnimateDiff
-[AnimateDiff](../api/pipelines/animatediff) is an adapter model that inserts a motion module into a pretrained diffusion model to animate an image. The adapter is trained on video clips to learn motion which is used to condition the generation process to create a video. It is faster and easier to only train the adapter and it can be loaded into most diffusion models, effectively turning them into "video models".
+[AnimateDiff](https://huggingface.co/guoyww/animatediff) is an adapter model that inserts a motion module into a pretrained diffusion model to animate an image. The adapter is trained on video clips to learn motion which is used to condition the generation process to create a video. It is faster and easier to only train the adapter and it can be loaded into most diffusion models, effectively turning them into "video models".
-Start by loading a [`MotionAdapter`].
+Load a [`MotionAdapter`] and pass it to the [`AnimateDiffPipeline`].
```py
import torch
@@ -166,11 +206,6 @@ from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
-```
-
-Then load a finetuned Stable Diffusion model with the [`AnimateDiffPipeline`].
-
-```py
pipeline = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
scheduler = DDIMScheduler.from_pretrained(
"emilianJR/epiCRealism",
@@ -181,13 +216,11 @@ scheduler = DDIMScheduler.from_pretrained(
steps_offset=1,
)
pipeline.scheduler = scheduler
+
+# reduce memory requirements
pipeline.enable_vae_slicing()
pipeline.enable_model_cpu_offload()
-```
-Create a prompt and generate the video.
-
-```py
output = pipeline(
prompt="A space rocket with trails of smoke behind it launching into space from the desert, 4k, high resolution",
negative_prompt="bad quality, worse quality, low resolution",
@@ -201,38 +234,11 @@ export_to_gif(frames, "animation.gif")
```
-
+
-### ModelscopeT2V
-
-[ModelscopeT2V](../api/pipelines/text_to_video) adds spatial and temporal convolutions and attention to a UNet, and it is trained on image-text and video-text datasets to enhance what it learns during training. The model takes a prompt, encodes it and creates text embeddings which are denoised by the UNet, and then decoded by a VQGAN into a video.
-
-
-
-ModelScopeT2V generates watermarked videos due to the datasets it was trained on. To use a watermark-free model, try the [cerspense/zeroscope_v2_76w](https://huggingface.co/cerspense/zeroscope_v2_576w) model with the [`TextToVideoSDPipeline`] first, and then upscale it's output with the [cerspense/zeroscope_v2_XL](https://huggingface.co/cerspense/zeroscope_v2_XL) checkpoint using the [`VideoToVideoSDPipeline`].
-
-
-
-Load a ModelScopeT2V checkpoint into the [`DiffusionPipeline`] along with a prompt to generate a video.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import export_to_video
-
-pipeline = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
-pipeline.enable_model_cpu_offload()
-pipeline.enable_vae_slicing()
-
-prompt = "Confident teddy bear surfer rides the wave in the tropics"
-video_frames = pipeline(prompt).frames[0]
-export_to_video(video_frames, "modelscopet2v.mp4", fps=10)
-```
## Configure model parameters
@@ -548,3 +554,9 @@ If memory is not an issue and you want to optimize for speed, try wrapping the U
+ pipeline.to("cuda")
+ pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
+
+## Quantization
+
+Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have a varying impact on video quality depending on the video model.
+
+Refer to the [Quantization](../../quantization/overview) overview to learn more about the supported quantization backends (bitsandbytes, torchao, gguf) and how to choose a backend that supports your use case.
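+
+For example, here is a minimal sketch (assuming the bitsandbytes backend is installed) that loads the Mochi-1 transformer in 4-bit NF4 precision before assembling the pipeline. The exact configuration values are illustrative, not prescriptive.
+
+```py
+import torch
+from diffusers import BitsAndBytesConfig, MochiPipeline, MochiTransformer3DModel
+from diffusers.utils import export_to_video
+
+# assumption: the bitsandbytes backend is installed (`pip install bitsandbytes`)
+quant_config = BitsAndBytesConfig(
+    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
+)
+
+# quantize only the transformer, which holds most of the model's parameters
+transformer = MochiTransformer3DModel.from_pretrained(
+    "genmo/mochi-1-preview", subfolder="transformer", quantization_config=quant_config, torch_dtype=torch.bfloat16
+)
+pipe = MochiPipeline.from_pretrained(
+    "genmo/mochi-1-preview", transformer=transformer, torch_dtype=torch.bfloat16
+)
+
+# reduce memory requirements further by offloading idle components to the CPU
+pipe.enable_model_cpu_offload()
+
+video = pipe("Close-up of a chameleon's eye, with its scaly skin changing color.", num_frames=84).frames[0]
+export_to_video(video, "output.mp4", fps=30)
+```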
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 542b8505874f..923683ae7c38 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -160,7 +160,7 @@ def save_model_card(
from diffusers import AutoPipelineForText2Image
import torch
{diffusers_imports_pivotal}
-pipeline = AutoPipelineForText2Image.from_pretrained('runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
+pipeline = AutoPipelineForText2Image.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
{diffusers_example_pivotal}
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
diff --git a/examples/community/README.md b/examples/community/README.md
index 611a278af88e..c7c40c46ef2d 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -33,12 +33,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
-| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
+| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
-| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) |
-| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
+| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) |
+| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
-| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
+| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
@@ -50,7 +50,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon)
| Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) |
| Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | [Andrew Zhu](https://xhinker.medium.com/) |
-| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) |
+| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py
index a9de26b29a89..b4f6b6ef668f 100644
--- a/examples/community/adaptive_mask_inpainting.py
+++ b/examples/community/adaptive_mask_inpainting.py
@@ -372,7 +372,7 @@ def __init__(
self.register_adaptive_mask_model()
self.register_adaptive_mask_settings()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -386,7 +386,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -450,7 +450,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index 46d12ba1f2aa..23423594c54b 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -89,7 +89,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -103,7 +103,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -162,7 +162,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py
index ac977f79abec..a7bc892ddf93 100644
--- a/examples/community/edict_pipeline.py
+++ b/examples/community/edict_pipeline.py
@@ -35,7 +35,7 @@ def __init__(
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py
index ab191ecf0d81..2784e2f238f6 100644
--- a/examples/community/fresco_v2v.py
+++ b/examples/community/fresco_v2v.py
@@ -1342,7 +1342,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/examples/community/gluegen.py b/examples/community/gluegen.py
index 91026c5d966f..54cc562d5583 100644
--- a/examples/community/gluegen.py
+++ b/examples/community/gluegen.py
@@ -221,7 +221,7 @@ def __init__(
language_adapter=language_adapter,
tensor_norm=tensor_norm,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 4dfb7a39155f..292c9aa2bc47 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -95,7 +95,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/instaflow_one_step.py b/examples/community/instaflow_one_step.py
index 3fef02287186..2af24ab8b703 100644
--- a/examples/community/instaflow_one_step.py
+++ b/examples/community/instaflow_one_step.py
@@ -109,7 +109,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -123,7 +123,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -182,7 +182,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index 52b2707f33f7..99614635ee13 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -86,7 +86,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py
index c7dc775eeee3..8b6d147724bd 100644
--- a/examples/community/ip_adapter_face_id.py
+++ b/examples/community/ip_adapter_face_id.py
@@ -191,7 +191,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -205,7 +205,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -265,7 +265,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/kohya_hires_fix.py b/examples/community/kohya_hires_fix.py
index 0e36f32b19a3..ddbb28896e13 100644
--- a/examples/community/kohya_hires_fix.py
+++ b/examples/community/kohya_hires_fix.py
@@ -463,6 +463,6 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/latent_consistency_img2img.py b/examples/community/latent_consistency_img2img.py
index 5fe53ab6b830..6c532c7f76c1 100644
--- a/examples/community/latent_consistency_img2img.py
+++ b/examples/community/latent_consistency_img2img.py
@@ -69,7 +69,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
diff --git a/examples/community/latent_consistency_interpolate.py b/examples/community/latent_consistency_interpolate.py
index 84adc125b191..34cdb0fec73b 100644
--- a/examples/community/latent_consistency_interpolate.py
+++ b/examples/community/latent_consistency_interpolate.py
@@ -273,7 +273,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/latent_consistency_txt2img.py b/examples/community/latent_consistency_txt2img.py
index 9f25a6db2722..7b60f5bb875c 100755
--- a/examples/community/latent_consistency_txt2img.py
+++ b/examples/community/latent_consistency_txt2img.py
@@ -67,7 +67,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index 49c074911354..07fbc15350a9 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -336,7 +336,7 @@ def __init__(
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -350,7 +350,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -410,7 +410,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index ec27acdce331..73ea8fffd2e4 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -496,7 +496,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -510,7 +510,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -568,7 +568,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index 13d1e2a1156a..d23eca6059b4 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -673,7 +673,7 @@ def __init__(
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -827,7 +827,9 @@ def encode_prompt(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -879,7 +881,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 0c85ad118752..0cd85ced59a1 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -3766,7 +3766,7 @@ def __init__(
else:
raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -3780,7 +3780,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- # if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ # if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
# deprecation_message = (
# f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
# " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py
index dc335e0b585e..5dcc75c9e20b 100644
--- a/examples/community/multilingual_stable_diffusion.py
+++ b/examples/community/multilingual_stable_diffusion.py
@@ -98,7 +98,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py
index bedf002d024c..9f99ad248be2 100644
--- a/examples/community/pipeline_animatediff_controlnet.py
+++ b/examples/community/pipeline_animatediff_controlnet.py
@@ -188,7 +188,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/examples/community/pipeline_animatediff_img2video.py b/examples/community/pipeline_animatediff_img2video.py
index 0a578d4b8ef6..f7f0cf31c5dd 100644
--- a/examples/community/pipeline_animatediff_img2video.py
+++ b/examples/community/pipeline_animatediff_img2video.py
@@ -308,7 +308,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
diff --git a/examples/community/pipeline_animatediff_ipex.py b/examples/community/pipeline_animatediff_ipex.py
index dc65e76bc43b..06508f217c4c 100644
--- a/examples/community/pipeline_animatediff_ipex.py
+++ b/examples/community/pipeline_animatediff_ipex.py
@@ -162,7 +162,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py
index f83d1b401420..b21902e9798f 100644
--- a/examples/community/pipeline_demofusion_sdxl.py
+++ b/examples/community/pipeline_demofusion_sdxl.py
@@ -166,7 +166,7 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -290,7 +290,9 @@ def encode_prompt(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -342,7 +344,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_fabric.py b/examples/community/pipeline_fabric.py
index 02fdcd04c103..75d724bd7304 100644
--- a/examples/community/pipeline_fabric.py
+++ b/examples/community/pipeline_fabric.py
@@ -179,7 +179,7 @@ def __init__(
tokenizer=tokenizer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 84d5027c3eab..a66e2b1c7c8a 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -221,13 +221,12 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py
index 8bed15afecd8..42fed90762da 100644
--- a/examples/community/pipeline_flux_rf_inversion.py
+++ b/examples/community/pipeline_flux_rf_inversion.py
@@ -219,9 +219,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -419,7 +417,7 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
)
image = image.to(dtype)
- x0 = self.vae.encode(image.to(self.device)).latent_dist.sample()
+ x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample()
x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor
x0 = x0.to(dtype)
return x0, resized
diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py
index 04af685ec872..0b27fd2bcddf 100644
--- a/examples/community/pipeline_flux_with_cfg.py
+++ b/examples/community/pipeline_flux_with_cfg.py
@@ -190,9 +190,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py
index 8cf2830f25ab..a294ff782450 100644
--- a/examples/community/pipeline_hunyuandit_differential_img2img.py
+++ b/examples/community/pipeline_hunyuandit_differential_img2img.py
@@ -327,9 +327,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
diff --git a/examples/community/pipeline_kolors_differential_img2img.py b/examples/community/pipeline_kolors_differential_img2img.py
index e5570248d22b..7734ef8f164a 100644
--- a/examples/community/pipeline_kolors_differential_img2img.py
+++ b/examples/community/pipeline_kolors_differential_img2img.py
@@ -209,9 +209,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 508e84177928..172241c817fd 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -131,7 +131,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -145,7 +145,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -205,7 +205,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index 8328bc2caed9..50e0ca0f9f24 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -488,7 +488,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -628,7 +628,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -688,7 +690,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
index 8cee5ecbc141..50952304fc1e 100644
--- a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
@@ -207,7 +207,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
)
diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py
index 6490c1400138..6d36a9a8a389 100644
--- a/examples/community/pipeline_stable_diffusion_boxdiff.py
+++ b/examples/community/pipeline_stable_diffusion_boxdiff.py
@@ -417,7 +417,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -431,7 +431,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -491,7 +491,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py
index cea2c9735747..9dda2b5a0a1e 100644
--- a/examples/community/pipeline_stable_diffusion_pag.py
+++ b/examples/community/pipeline_stable_diffusion_pag.py
@@ -384,7 +384,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -398,7 +398,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -458,7 +458,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
index 1ac651a1fe60..8a709ab46757 100644
--- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
+++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -151,7 +151,7 @@ def __init__(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
# self.register_to_config(requires_safety_checker=requires_safety_checker)
self.register_to_config(max_noise_level=max_noise_level)
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
index ae495979f366..d80cb209ec0a 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -226,7 +226,7 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -359,7 +359,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -419,7 +421,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index 94ca71cf7b1b..d8c52a78b104 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -374,7 +374,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -507,7 +507,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -567,7 +569,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
index 584820e86254..e74ea263017f 100644
--- a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py
@@ -258,7 +258,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -394,7 +394,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -454,7 +456,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_stable_diffusion_xl_ipex.py b/examples/community/pipeline_stable_diffusion_xl_ipex.py
index 022dfb1abf82..bc430955282e 100644
--- a/examples/community/pipeline_stable_diffusion_xl_ipex.py
+++ b/examples/community/pipeline_stable_diffusion_xl_ipex.py
@@ -253,7 +253,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -390,7 +390,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -450,7 +452,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 95bb37ce02b7..9c1f2362b1c8 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -108,7 +108,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -122,7 +122,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -181,7 +181,7 @@ def __init__(
feature_extractor=feature_extractor,
cc_projection=cc_projection,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# self.model_mode = None
diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py
index d9c616ab5ebc..706b22bbb88d 100644
--- a/examples/community/rerender_a_video.py
+++ b/examples/community/rerender_a_video.py
@@ -30,10 +30,17 @@
from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import BaseOutput, deprecate, logging
+from diffusers.utils import BaseOutput, deprecate, is_torch_xla_available, logging
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -345,7 +352,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -775,7 +782,7 @@ def __call__(
self.attn_state.reset()
# 4.1 prepare frames
- image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
first_image = image[0] # C, H, W
# 4.2 Prepare controlnet_conditioning_image
@@ -919,8 +926,8 @@ def __call__(
prev_image = frames[idx - 1]
control_image = control_frames[idx]
# 5.1 prepare frames
- image = self.image_processor.preprocess(image).to(dtype=torch.float32)
- prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
+ image = self.image_processor.preprocess(image).to(dtype=self.dtype)
+ prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
self.flow_model, first_image, image[0], first_result, False, self.device
@@ -1100,6 +1107,9 @@ def denoising_loop(latents, mask=None, xtrg=None, noise_rescale=None):
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
return latents
if mask_start_t <= mask_end_t:
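
For context, the XLA-related additions above follow the usual `torch_xla` pattern: guard the import behind `is_torch_xla_available()` and call `xm.mark_step()` once per denoising iteration so the lazily built graph is executed. A minimal sketch, where the loop and `step_fn` are illustrative rather than the pipeline's actual denoising loop:

```py
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def run_denoising_loop(step_fn, latents, timesteps):
    # Illustrative loop body: on XLA devices, flush the pending graph after
    # every step so work is executed incrementally rather than all at once.
    for t in timesteps:
        latents = step_fn(latents, t)
        if XLA_AVAILABLE:
            xm.mark_step()
    return latents
```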
diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py
index c7c88d6fdcc7..6aa4067d695d 100644
--- a/examples/community/stable_diffusion_controlnet_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_img2img.py
@@ -179,7 +179,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index b473ffe79933..2d19e26b4220 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -278,7 +278,7 @@ def __init__(
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
index 8928f34239e3..4363a2294b63 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -263,7 +263,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def _encode_prompt(
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 123892f6229a..3cae3e6df4f3 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -105,7 +105,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -119,7 +119,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -178,7 +178,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1):
diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py
index 95b4b03e4de1..77e5011d2a70 100644
--- a/examples/community/stable_diffusion_mega.py
+++ b/examples/community/stable_diffusion_mega.py
@@ -66,7 +66,7 @@ def __init__(
requires_safety_checker: bool = True,
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index efb0fa89dbfc..b54ebf27f715 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -132,7 +132,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -146,7 +146,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -219,7 +219,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py
index 980e9a155997..115a6b005565 100644
--- a/examples/community/stable_diffusion_repaint.py
+++ b/examples/community/stable_diffusion_repaint.py
@@ -187,7 +187,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -201,7 +201,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -274,7 +274,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py
index 91540d1f4159..453e2d8d679c 100755
--- a/examples/community/stable_diffusion_tensorrt_img2img.py
+++ b/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -710,7 +710,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -724,7 +724,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -806,7 +806,7 @@ def __init__(
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py
index b6f6711a53e7..8d0c7bedc904 100755
--- a/examples/community/stable_diffusion_tensorrt_inpaint.py
+++ b/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -714,7 +714,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -728,7 +728,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -810,7 +810,7 @@ def __init__(
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py
index f8761053ed1a..f94f114663bc 100755
--- a/examples/community/stable_diffusion_tensorrt_txt2img.py
+++ b/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -626,7 +626,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -640,7 +640,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -722,7 +722,7 @@ def __init__(
self.engine = {} # loaded in build_engines()
self.vae.forward = self.vae.decode
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py
index c4378ab96f28..d73082b6cf38 100644
--- a/examples/community/text_inpainting.py
+++ b/examples/community/text_inpainting.py
@@ -71,7 +71,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -85,7 +85,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
diff --git a/examples/community/wildcard_stable_diffusion.py b/examples/community/wildcard_stable_diffusion.py
index c866ce2ae904..3c42c54f71f8 100644
--- a/examples/community/wildcard_stable_diffusion.py
+++ b/examples/community/wildcard_stable_diffusion.py
@@ -120,7 +120,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 78eae4499ad2..097eaed8b504 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -29,7 +29,7 @@
import torch
import torch.utils.checkpoint
import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
text_encoder_two_lora_layers_to_save = None
for model in models:
- if isinstance(model, type(unwrap_model(transformer))):
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ model = unwrap_model(model)
+ if args.upcast_before_saving:
+ model = model.to(torch.float32)
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
- elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
+ elif args.train_text_encoder and isinstance(
+ unwrap_model(model), type(unwrap_model(text_encoder_one))
+ ): # or text_encoder_two
# both text encoders are of the same class, so we check hidden size to distinguish between the two
- hidden_size = unwrap_model(model).config.hidden_size
+ model = unwrap_model(model)
+ hidden_size = model.config.hidden_size
if hidden_size == 768:
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
elif hidden_size == 1280:
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
StableDiffusion3Pipeline.save_lora_weights(
output_dir,
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
text_encoder_one_ = None
text_encoder_two_ = None
- while len(models) > 0:
- model = models.pop()
+ if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+ while len(models) > 0:
+ model = models.pop()
- if isinstance(model, type(unwrap_model(transformer))):
- transformer_ = model
- elif isinstance(model, type(unwrap_model(text_encoder_one))):
- text_encoder_one_ = model
- elif isinstance(model, type(unwrap_model(text_encoder_two))):
- text_encoder_two_ = model
- else:
- raise ValueError(f"unexpected save model: {model.__class__}")
+ if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+ transformer_ = unwrap_model(model)
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = unwrap_model(model)
+ elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+ text_encoder_two_ = unwrap_model(model)
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ else:
+ transformer_ = SD3Transformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer"
+ )
+ transformer_.add_adapter(transformer_lora_config)
+ if args.train_text_encoder:
+ text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder"
+ )
+ text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+ )
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
progress_bar.update(1)
global_step += 1
- if accelerator.is_main_process:
+ if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
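
A recurring change in the hooks above is comparing *unwrapped* models instead of the objects handed to the hook. A tiny sketch of the idea; the helper below is illustrative, the script uses its own `unwrap_model`:

```py
def is_same_base_model(candidate, reference, unwrap_model):
    # Under DeepSpeed or torch.compile the hook may receive a wrapped module,
    # so unwrap both objects before comparing their classes.
    return isinstance(unwrap_model(candidate), type(unwrap_model(reference)))
```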
diff --git a/examples/flux-control/README.md b/examples/flux-control/README.md
index 26ad9d06a2af..14afa499db0d 100644
--- a/examples/flux-control/README.md
+++ b/examples/flux-control/README.md
@@ -121,7 +121,7 @@ prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
- condition_image=image,
+ control_image=image,
num_inference_steps=50,
joint_attention_kwargs={"scale": 0.9},
guidance_scale=25.,
@@ -190,7 +190,7 @@ prompt = "A couple, 4k photo, highly detailed"
gen_images = pipe(
prompt=prompt,
- condition_image=image,
+ control_image=image,
num_inference_steps=50,
guidance_scale=25.,
).images[0]
@@ -200,5 +200,5 @@ gen_images.save("output.png")
## Things to note
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
-* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
+* The scripts are not memory-optimized, but if `--offload` is specified, we offload the VAE and the text encoders to CPU when they are not in use.
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
\ No newline at end of file
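
The `--offload` note above describes a simple CPU-offloading pattern. A minimal sketch of the idea follows; the helper name and flag handling are illustrative, and the training scripts wire this up differently:

```py
import torch


def maybe_offload(modules, offload: bool, device: str = "cuda"):
    # When --offload is requested, park the given modules (e.g. the VAE and text
    # encoders) on the CPU while they are not needed, freeing accelerator memory
    # for the transformer; otherwise keep them on the training device.
    target = "cpu" if offload else device
    for module in modules:
        module.to(target)
    if offload and torch.cuda.is_available():
        torch.cuda.empty_cache()
```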
diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
index 35f9a5f80342..7d0e28069054 100644
--- a/examples/flux-control/train_control_flux.py
+++ b/examples/flux-control/train_control_flux.py
@@ -122,7 +122,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
for _ in range(args.num_validation_images):
with autocast_ctx:
- # need to fix in pipeline_flux_controlnet
image = pipeline(
prompt=validation_prompt,
control_image=validation_image,
@@ -159,7 +158,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+ formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -188,7 +187,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
-# control-lora-{repo_id}
+# flux-control-{repo_id}
These are Control weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -434,7 +433,7 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
- help="The column of the dataset containing the controlnet conditioning image.",
+ help="The column of the dataset containing the control conditioning image.",
)
parser.add_argument(
"--caption_column",
@@ -442,6 +441,7 @@ def parse_args(input_args=None):
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
+ parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -468,7 +468,7 @@ def parse_args(input_args=None):
default=None,
nargs="+",
help=(
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -505,7 +505,11 @@ def parse_args(input_args=None):
default=None,
help="Path to the jsonl file containing the training data.",
)
-
+ parser.add_argument(
+ "--only_target_transformer_blocks",
+ action="store_true",
+ help="If we should only target the transformer blocks to train along with the input layer (`x_embedder`).",
+ )
parser.add_argument(
"--guidance_scale",
type=float,
@@ -581,7 +585,7 @@ def parse_args(input_args=None):
if args.resolution % 8 != 0:
raise ValueError(
- "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
)
return args
@@ -665,7 +669,12 @@ def preprocess_train(examples):
conditioning_images = [image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
- examples["captions"] = list(examples[args.caption_column])
+
+ is_caption_list = isinstance(examples[args.caption_column][0], list)
+ if is_caption_list:
+ examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+ else:
+ examples["captions"] = list(examples[args.caption_column])
return examples
@@ -765,7 +774,8 @@ def main(args):
subfolder="scheduler",
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
- flux_transformer.requires_grad_(True)
+ if not args.only_target_transformer_blocks:
+ flux_transformer.requires_grad_(True)
vae.requires_grad_(False)
# cast down and move to the CPU
@@ -797,6 +807,12 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
+ if args.only_target_transformer_blocks:
+ flux_transformer.x_embedder.requires_grad_(True)
+ for name, module in flux_transformer.named_modules():
+ if "transformer_blocks" in name:
+ module.requires_grad_(True)
+
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
@@ -974,6 +990,32 @@ def load_model_hook(models, input_dir):
else:
initial_global_step = 0
+ if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+ logger.info("Logging some dataset samples.")
+ formatted_images = []
+ formatted_control_images = []
+ all_prompts = []
+ for i, batch in enumerate(train_dataloader):
+ images = (batch["pixel_values"] + 1) / 2
+ control_images = (batch["conditioning_pixel_values"] + 1) / 2
+ prompts = batch["captions"]
+
+ if len(formatted_images) > 10:
+ break
+
+ for img, control_img, prompt in zip(images, control_images, prompts):
+ formatted_images.append(img)
+ formatted_control_images.append(control_img)
+ all_prompts.append(prompt)
+
+ logged_artifacts = []
+ for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+ logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+ logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+ wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+ wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index b176a685c963..44c684395849 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -132,7 +132,6 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
for _ in range(args.num_validation_images):
with autocast_ctx:
- # need to fix in pipeline_flux_controlnet
image = pipeline(
prompt=validation_prompt,
control_image=validation_image,
@@ -169,7 +168,7 @@ def log_validation(flux_transformer, args, accelerator, weight_dtype, step, is_f
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
- formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+ formatted_images.append(wandb.Image(validation_image, caption="Conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
@@ -198,7 +197,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
-# controlnet-lora-{repo_id}
+# control-lora-{repo_id}
These are Control LoRA weights trained on {base_model} with new type of conditioning.
{img_str}
@@ -256,7 +255,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
- default="controlnet-lora",
+ default="control-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
@@ -466,7 +465,7 @@ def parse_args(input_args=None):
"--conditioning_image_column",
type=str,
default="conditioning_image",
- help="The column of the dataset containing the controlnet conditioning image.",
+ help="The column of the dataset containing the control conditioning image.",
)
parser.add_argument(
"--caption_column",
@@ -474,6 +473,7 @@ def parse_args(input_args=None):
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
+ parser.add_argument("--log_dataset_samples", action="store_true", help="Whether to log somple dataset samples.")
parser.add_argument(
"--max_train_samples",
type=int,
@@ -500,7 +500,7 @@ def parse_args(input_args=None):
default=None,
nargs="+",
help=(
- "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ "A set of paths to the control conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
@@ -613,7 +613,7 @@ def parse_args(input_args=None):
if args.resolution % 8 != 0:
raise ValueError(
- "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the Flux transformer."
)
return args
@@ -697,7 +697,12 @@ def preprocess_train(examples):
conditioning_images = [image_transforms(image) for image in conditioning_images]
examples["pixel_values"] = images
examples["conditioning_pixel_values"] = conditioning_images
- examples["captions"] = list(examples[args.caption_column])
+
+ is_caption_list = isinstance(examples[args.caption_column][0], list)
+ if is_caption_list:
+ examples["captions"] = [max(example, key=len) for example in examples[args.caption_column]]
+ else:
+ examples["captions"] = list(examples[args.caption_column])
return examples
@@ -923,11 +928,28 @@ def load_model_hook(models, input_dir):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
-
else:
transformer_ = FluxTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer"
).to(accelerator.device, weight_dtype)
+
+ # Handle input dimension doubling before adding adapter
+ with torch.no_grad():
+ initial_input_channels = transformer_.config.in_channels
+ new_linear = torch.nn.Linear(
+ transformer_.x_embedder.in_features * 2,
+ transformer_.x_embedder.out_features,
+ bias=transformer_.x_embedder.bias is not None,
+ dtype=transformer_.dtype,
+ device=transformer_.device,
+ )
+ new_linear.weight.zero_()
+ new_linear.weight[:, :initial_input_channels].copy_(transformer_.x_embedder.weight)
+ if transformer_.x_embedder.bias is not None:
+ new_linear.bias.copy_(transformer_.x_embedder.bias)
+ transformer_.x_embedder = new_linear
+ transformer_.register_to_config(in_channels=initial_input_channels * 2)
+
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = FluxControlPipeline.lora_state_dict(input_dir)
@@ -1115,6 +1137,32 @@ def load_model_hook(models, input_dir):
else:
initial_global_step = 0
+ if accelerator.is_main_process and args.report_to == "wandb" and args.log_dataset_samples:
+ logger.info("Logging some dataset samples.")
+ formatted_images = []
+ formatted_control_images = []
+ all_prompts = []
+ for i, batch in enumerate(train_dataloader):
+ images = (batch["pixel_values"] + 1) / 2
+ control_images = (batch["conditioning_pixel_values"] + 1) / 2
+ prompts = batch["captions"]
+
+ if len(formatted_images) > 10:
+ break
+
+ for img, control_img, prompt in zip(images, control_images, prompts):
+ formatted_images.append(img)
+ formatted_control_images.append(control_img)
+ all_prompts.append(prompt)
+
+ logged_artifacts = []
+ for img, control_img, prompt in zip(formatted_images, formatted_control_images, all_prompts):
+ logged_artifacts.append(wandb.Image(control_img, caption="Conditioning"))
+ logged_artifacts.append(wandb.Image(img, caption=prompt))
+
+ wandb_tracker = [tracker for tracker in accelerator.trackers if tracker.name == "wandb"]
+ wandb_tracker[0].log({"dataset_samples": logged_artifacts})
+
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
diff --git a/examples/research_projects/instructpix2pix_lora/README.md b/examples/research_projects/instructpix2pix_lora/README.md
index cfcd98926c07..25f7931b47d4 100644
--- a/examples/research_projects/instructpix2pix_lora/README.md
+++ b/examples/research_projects/instructpix2pix_lora/README.md
@@ -2,6 +2,34 @@
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support for adding LoRA layers to the UNet model.
+## Running locally with PyTorch
+### Installing the dependencies
+
+Before running the scripts, make sure to install the library's training dependencies:
+
+**Important**
+
+To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
+```bash
+git clone https://github.com/huggingface/diffusers
+cd diffusers
+pip install .
+```
+
+Then cd in the example folder and run
+```bash
+pip install -r requirements.txt
+```
+
+And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
+
+```bash
+accelerate config
+```
+
+Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.
+
+
## Training script example
```bash
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
export DATASET_ID="instruction-tuning-sd/cartoonization"
export OUTPUT_DIR="instructPix2Pix-cartoonization"
-accelerate launch finetune_instruct_pix2pix.py \
+accelerate launch train_instruct_pix2pix_lora.py \
--pretrained_model_name_or_path=$MODEL_ID \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
--rank=4 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
- --push_to_hub
+ --push_to_hub \
+ --original_image_column="original_image" \
+ --edited_image_column="cartoonized_image" \
+ --edit_prompt_column="edit_prompt"
```
## Inference
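
As a rough illustration of how the resulting LoRA could be used at inference time, a sketch is shown below. The checkpoint directory and image URL are placeholders, the LoRA loading call assumes the standard `load_lora_weights` loader, and the README's actual inference section may differ:

```py
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

# Placeholder paths/URLs for illustration only.
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("instructPix2Pix-cartoonization")  # $OUTPUT_DIR from the command above

image = load_image("https://example.com/input.png")  # placeholder input image
edited = pipe(
    "Cartoonize the following image",
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7,
).images[0]
edited.save("edited.png")
```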
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index 997d448fa281..fcb927c680a0 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
+"""
+ Script to fine-tune Stable Diffusion for LoRA InstructPix2Pix.
+ Base code adapted from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
+"""
import argparse
import logging
@@ -30,6 +33,7 @@
import PIL
import requests
import torch
+import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
@@ -39,21 +43,28 @@
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
-from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, deprecate, is_wandb_available
+from diffusers.training_utils import EMAModel, cast_training_params
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.26.0.dev0")
+check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -63,6 +74,92 @@
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
+def save_model_card(
+ repo_id: str,
+ images: list = None,
+ base_model: str = None,
+ dataset_name: str = None,
+ repo_folder: str = None,
+):
+ img_str = ""
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ img_str += f"![img_{i}](./image_{i}.png)\n"
+
+ model_description = f"""
+# LoRA text2image fine-tuning - {repo_id}
+These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
+{img_str}
+"""
+
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="creativeml-openrail-m",
+ base_model=base_model,
+ model_description=model_description,
+ inference=True,
+ )
+
+ tags = [
+ "stable-diffusion",
+ "stable-diffusion-diffusers",
+ "text-to-image",
+ "instruct-pix2pix",
+ "diffusers",
+ "diffusers-training",
+ "lora",
+ ]
+ model_card = populate_model_card(model_card, tags=tags)
+
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device)
+ pipeline.set_progress_bar_config(disable=True)
+
+ # run inference
+ original_image = download_image(args.val_image_url)
+ edited_images = []
+ if torch.backends.mps.is_available():
+ autocast_ctx = nullcontext()
+ else:
+ autocast_ctx = torch.autocast(accelerator.device.type)
+
+ with autocast_ctx:
+ for _ in range(args.num_validation_images):
+ edited_images.append(
+ pipeline(
+ args.validation_prompt,
+ image=original_image,
+ num_inference_steps=20,
+ image_guidance_scale=1.5,
+ guidance_scale=7,
+ generator=generator,
+ ).images[0]
+ )
+
+ for tracker in accelerator.trackers:
+ if tracker.name == "wandb":
+ wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+ for edited_image in edited_images:
+ wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+ tracker.log({"validation": wandb_table})
+
+ return edited_images
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -417,11 +514,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
- if args.report_to == "wandb":
- if not is_wandb_available():
- raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
- import wandb
-
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -467,49 +559,58 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
+ # InstructPix2Pix uses an additional image for conditioning. To accommodate that,
+ # it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
+ # then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
+ # from the pre-trained checkpoints. For the extra channels added to the first layer, they are
+ # initialized to zero.
+ logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
+ in_channels = 8
+ out_channels = unet.conv_in.out_channels
+ unet.register_to_config(in_channels=in_channels)
+
+ with torch.no_grad():
+ new_conv_in = nn.Conv2d(
+ in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
+ )
+ new_conv_in.weight.zero_()
+ new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
+ unet.conv_in = new_conv_in
+
# Freeze vae, text_encoder and unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
- unet_lora_parameters = []
- for attn_processor_name, attn_processor in unet.attn_processors.items():
- # Parse the attention module.
- attn_module = unet
- for n in attn_processor_name.split(".")[:-1]:
- attn_module = getattr(attn_module, n)
-
- # Set the `lora_layer` attribute of the attention-related matrices.
- attn_module.to_q.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
- )
- )
- attn_module.to_k.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
- )
- )
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
- attn_module.to_v.set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
- )
- )
- attn_module.to_out[0].set_lora_layer(
- LoRALinearLayer(
- in_features=attn_module.to_out[0].in_features,
- out_features=attn_module.to_out[0].out_features,
- rank=args.rank,
- )
- )
+ # Freeze the unet parameters before adding adapters
+ unet.requires_grad_(False)
- # Accumulate the LoRA params to optimize.
- unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
- unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+ unet_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.rank,
+ init_lora_weights="gaussian",
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+ )
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ unet.to(accelerator.device, dtype=weight_dtype)
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+ # Add adapter and make sure the trainable params are in float32.
+ unet.add_adapter(unet_lora_config)
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
# Create EMA for the unet.
if args.use_ema:
@@ -528,6 +629,13 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
+ trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -540,7 +648,8 @@ def save_model_hook(models, weights, output_dir):
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
- weights.pop()
+ if weights:
+ weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
@@ -589,9 +698,9 @@ def load_model_hook(models, input_dir):
else:
optimizer_cls = torch.optim.AdamW
- # train on only unet_lora_parameters
+ # train on only lora_layers
optimizer = optimizer_cls(
- unet_lora_parameters,
+ trainable_params,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -730,22 +839,27 @@ def collate_fn(examples):
)
# Scheduler and math around the number of training steps.
- overrode_max_train_steps = False
- num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
- args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
- overrode_max_train_steps = True
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
- num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
- num_training_steps=args.max_train_steps * accelerator.num_processes,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
)
# Prepare everything with our `accelerator`.
- unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
)
if args.use_ema:
@@ -765,8 +879,14 @@ def collate_fn(examples):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
- if overrode_max_train_steps:
+ if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -885,7 +1005,7 @@ def collate_fn(examples):
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
- model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
+ model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
@@ -895,7 +1015,7 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
- accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
+ accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -903,7 +1023,7 @@ def collate_fn(examples):
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
- ema_unet.step(unet_lora_parameters)
+ ema_unet.step(trainable_params)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -933,6 +1053,16 @@ def collate_fn(examples):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(
+ get_peft_model_state_dict(unwrapped_unet)
+ )
+
+ StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+ save_directory=save_path,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -959,45 +1089,22 @@ def collate_fn(examples):
# The models need unwrapping for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
- unet=accelerator.unwrap_model(unet),
- text_encoder=accelerator.unwrap_model(text_encoder),
- vae=accelerator.unwrap_model(vae),
+ unet=unwrap_model(unet),
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
- pipeline = pipeline.to(accelerator.device)
- pipeline.set_progress_bar_config(disable=True)
# run inference
- original_image = download_image(args.val_image_url)
- edited_images = []
- if torch.backends.mps.is_available():
- autocast_ctx = nullcontext()
- else:
- autocast_ctx = torch.autocast(accelerator.device.type)
-
- with autocast_ctx:
- for _ in range(args.num_validation_images):
- edited_images.append(
- pipeline(
- args.validation_prompt,
- image=original_image,
- num_inference_steps=20,
- image_guidance_scale=1.5,
- guidance_scale=7,
- generator=generator,
- ).images[0]
- )
+ log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
- for tracker in accelerator.trackers:
- if tracker.name == "wandb":
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
- for edited_image in edited_images:
- wandb_table.add_data(
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
- )
- tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1008,22 +1115,47 @@ def collate_fn(examples):
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
- unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
+ # store only LORA layers
+ unet = unet.to(torch.float32)
+
+ unwrapped_unet = unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
+ StableDiffusionInstructPix2PixPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ unet_lora_layers=unet_lora_state_dict,
+ safe_serialization=True,
+ )
+
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
- text_encoder=accelerator.unwrap_model(text_encoder),
- vae=accelerator.unwrap_model(vae),
- unet=unet,
+ text_encoder=unwrap_model(text_encoder),
+ vae=unwrap_model(vae),
+ unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
- # store only LORA layers
- unet.save_attn_procs(args.output_dir)
+ pipeline.load_lora_weights(args.output_dir)
+
+ images = None
+ if (args.val_image_url is not None) and (args.validation_prompt is not None):
+ images = log_validation(
+ pipeline,
+ args,
+ accelerator,
+ generator,
+ )
if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ dataset_name=args.dataset_name,
+ repo_folder=args.output_dir,
+ )
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
@@ -1031,31 +1163,6 @@ def collate_fn(examples):
ignore_patterns=["step_*", "epoch_*"],
)
- if args.validation_prompt is not None:
- edited_images = []
- pipeline = pipeline.to(accelerator.device)
- with torch.autocast(str(accelerator.device).replace(":0", "")):
- for _ in range(args.num_validation_images):
- edited_images.append(
- pipeline(
- args.validation_prompt,
- image=original_image,
- num_inference_steps=20,
- image_guidance_scale=1.5,
- guidance_scale=7,
- generator=generator,
- ).images[0]
- )
-
- for tracker in accelerator.trackers:
- if tracker.name == "wandb":
- wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
- for edited_image in edited_images:
- wandb_table.add_data(
- wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
- )
- tracker.log({"test": wandb_table})
-
accelerator.end_training()
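
For reference, a minimal inference sketch using the LoRA weights the updated script now saves via `StableDiffusionInstructPix2PixPipeline.save_lora_weights` (the output directory, input image, and edit instruction are placeholders):

```py
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
# Directory passed as --output_dir during training (placeholder path).
pipeline.load_lora_weights("path/to/instruct-pix2pix-lora")

image = load_image("path/to/input.png")  # placeholder input image
edited = pipeline(
    "make the sky snowy",  # placeholder edit instruction
    image=image,
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=7,
).images[0]
edited.save("edited.png")
```
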
diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
index aace66f9c18e..d7f882974a22 100644
--- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
+++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
@@ -310,7 +310,7 @@ def __init__(
controlnet=controlnet,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index cb4260d4653f..19c1f30d82da 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -233,7 +233,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py
index f8093a3f217d..e84568786f50 100644
--- a/examples/research_projects/rdm/pipeline_rdm.py
+++ b/examples/research_projects/rdm/pipeline_rdm.py
@@ -78,7 +78,7 @@ def __init__(
feature_extractor=feature_extractor,
)
# Copy from statement here and all the methods we take from stable_diffusion_pipeline
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.retriever = retriever
diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt
index 8fbaf908a2c8..96f504ece1f3 100644
--- a/examples/research_projects/realfill/requirements.txt
+++ b/examples/research_projects/realfill/requirements.txt
@@ -6,4 +6,4 @@ torch==2.2.0
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
-Jinja2==3.1.4
+Jinja2==3.1.5
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 7e1eee2e6367..1ddbf93e4b78 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -919,7 +919,7 @@ def preprocess_train(examples):
# fingerprint used by the cache for the other processes to load the result
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
new_fingerprint = Hasher.hash(args)
- new_fingerprint_for_vae = Hasher.hash(vae_path)
+ new_fingerprint_for_vae = Hasher.hash((vae_path, args))
train_dataset_with_embeddings = train_dataset.map(
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
)
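
A rough illustration of why the VAE embedding cache fingerprint now hashes the training args as well (the path and args below are made up; `Hasher` comes from `datasets.fingerprint`, as in the script):

```py
from datasets.fingerprint import Hasher

vae_path = "madebyollin/sdxl-vae-fp16-fix"  # placeholder VAE path
args_run_1 = {"resolution": 512, "center_crop": False}
args_run_2 = {"resolution": 1024, "center_crop": True}

# Old behavior: the fingerprint depended only on the VAE path, so two runs with different
# preprocessing args could silently reuse the same cached VAE embeddings.
print(Hasher.hash(vae_path))

# New behavior: the args are folded into the fingerprint, so changing them invalidates the cache.
print(Hasher.hash((vae_path, args_run_1)) == Hasher.hash((vae_path, args_run_2)))  # False
```
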
diff --git a/scripts/convert_blipdiffusion_to_diffusers.py b/scripts/convert_blipdiffusion_to_diffusers.py
index 03cf67e5476b..2c286ea0fdc7 100644
--- a/scripts/convert_blipdiffusion_to_diffusers.py
+++ b/scripts/convert_blipdiffusion_to_diffusers.py
@@ -303,10 +303,11 @@ def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
- text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
- vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
-
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ text_encoder = ContextCLIPTextModel.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder"
+ )
+ vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
+ unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
@@ -316,7 +317,7 @@ def save_blip_diffusion_model(model, args):
set_alpha_to_one=False,
skip_prk_steps=True,
)
- tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
+ tokenizer = CLIPTokenizer.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py
index 2f1732817be3..99a9ff322251 100644
--- a/scripts/convert_sana_to_diffusers.py
+++ b/scripts/convert_sana_to_diffusers.py
@@ -25,6 +25,7 @@
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
+ "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
@@ -89,7 +90,10 @@ def main(args):
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
# scheduler
- flow_shift = 3.0
+ if args.image_size == 4096:
+ flow_shift = 6.0
+ else:
+ flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
@@ -99,7 +103,7 @@ def main(args):
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
- interpolation_scale = {512: None, 1024: None, 2048: 1.0}
+ interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
for depth in range(layer_num):
# Transformer blocks.
@@ -272,9 +276,9 @@ def main(args):
"--image_size",
default=1024,
type=int,
- choices=[512, 1024, 2048],
+ choices=[512, 1024, 2048, 4096],
required=False,
- help="Image size of pretrained model, 512, 1024 or 2048.",
+ help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 07c2c2272422..e064aeba43b6 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -973,3 +973,178 @@ def swap_scale_shift(weight):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
+
+
+def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
+
+ def remap_norm_scale_shift_(key, state_dict):
+ weight = state_dict.pop(key)
+ shift, scale = weight.chunk(2, dim=0)
+ new_weight = torch.cat([scale, shift], dim=0)
+ state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
+
+ def remap_txt_in_(key, state_dict):
+ def rename_key(key):
+ new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
+ new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
+ new_key = new_key.replace("txt_in", "context_embedder")
+ new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
+ new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
+ new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
+ new_key = new_key.replace("mlp", "ff")
+ return new_key
+
+ if "self_attn_qkv" in key:
+ weight = state_dict.pop(key)
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
+ state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
+ else:
+ state_dict[rename_key(key)] = state_dict.pop(key)
+
+ def remap_img_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ if "lora_A" in key:
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = weight
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = weight
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = weight
+ else:
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
+ state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
+ state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
+
+ def remap_txt_attn_qkv_(key, state_dict):
+ weight = state_dict.pop(key)
+ if "lora_A" in key:
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = weight
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = weight
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = weight
+ else:
+ to_q, to_k, to_v = weight.chunk(3, dim=0)
+ state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
+ state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
+ state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
+
+ def remap_single_transformer_blocks_(key, state_dict):
+ hidden_size = 3072
+
+ if "linear1.lora_A.weight" in key or "linear1.lora_B.weight" in key:
+ linear1_weight = state_dict.pop(key)
+ if "lora_A" in key:
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_A.weight"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_A.weight"] = linear1_weight
+ state_dict[f"{new_key}.attn.to_k.lora_A.weight"] = linear1_weight
+ state_dict[f"{new_key}.attn.to_v.lora_A.weight"] = linear1_weight
+ state_dict[f"{new_key}.proj_mlp.lora_A.weight"] = linear1_weight
+ else:
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
+ q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_B.weight"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_B.weight"] = q
+ state_dict[f"{new_key}.attn.to_k.lora_B.weight"] = k
+ state_dict[f"{new_key}.attn.to_v.lora_B.weight"] = v
+ state_dict[f"{new_key}.proj_mlp.lora_B.weight"] = mlp
+
+ elif "linear1.lora_A.bias" in key or "linear1.lora_B.bias" in key:
+ linear1_bias = state_dict.pop(key)
+ if "lora_A" in key:
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_A.bias"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_A.bias"] = linear1_bias
+ state_dict[f"{new_key}.attn.to_k.lora_A.bias"] = linear1_bias
+ state_dict[f"{new_key}.attn.to_v.lora_A.bias"] = linear1_bias
+ state_dict[f"{new_key}.proj_mlp.lora_A.bias"] = linear1_bias
+ else:
+ split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
+ new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(
+ ".linear1.lora_B.bias"
+ )
+ state_dict[f"{new_key}.attn.to_q.lora_B.bias"] = q_bias
+ state_dict[f"{new_key}.attn.to_k.lora_B.bias"] = k_bias
+ state_dict[f"{new_key}.attn.to_v.lora_B.bias"] = v_bias
+ state_dict[f"{new_key}.proj_mlp.lora_B.bias"] = mlp_bias
+
+ else:
+ new_key = key.replace("single_blocks", "single_transformer_blocks")
+ new_key = new_key.replace("linear2", "proj_out")
+ new_key = new_key.replace("q_norm", "attn.norm_q")
+ new_key = new_key.replace("k_norm", "attn.norm_k")
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT = {
+ "img_in": "x_embedder",
+ "time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
+ "time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
+ "guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
+ "guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
+ "vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
+ "vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
+ "double_blocks": "transformer_blocks",
+ "img_attn_q_norm": "attn.norm_q",
+ "img_attn_k_norm": "attn.norm_k",
+ "img_attn_proj": "attn.to_out.0",
+ "txt_attn_q_norm": "attn.norm_added_q",
+ "txt_attn_k_norm": "attn.norm_added_k",
+ "txt_attn_proj": "attn.to_add_out",
+ "img_mod.linear": "norm1.linear",
+ "img_norm1": "norm1.norm",
+ "img_norm2": "norm2",
+ "img_mlp": "ff",
+ "txt_mod.linear": "norm1_context.linear",
+ "txt_norm1": "norm1.norm",
+ "txt_norm2": "norm2_context",
+ "txt_mlp": "ff_context",
+ "self_attn_proj": "attn.to_out.0",
+ "modulation.linear": "norm.linear",
+ "pre_norm": "norm.norm",
+ "final_layer.norm_final": "norm_out.norm",
+ "final_layer.linear": "proj_out",
+ "fc1": "net.0.proj",
+ "fc2": "net.2",
+ "input_embedder": "proj_in",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP = {
+ "txt_in": remap_txt_in_,
+ "img_attn_qkv": remap_img_attn_qkv_,
+ "txt_attn_qkv": remap_txt_attn_qkv_,
+ "single_blocks": remap_single_transformer_blocks_,
+ "final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
+ }
+
+ # Some folks attempt to make their state dict compatible with diffusers by adding a "transformer." prefix to all keys
+ # and using their own custom code. To make sure both the "original" and "attempted diffusers" LoRAs work as expected,
+ # we normalize both to the same initial format by stripping off the "transformer." prefix.
+ for key in list(converted_state_dict.keys()):
+ if key.startswith("transformer."):
+ converted_state_dict[key[len("transformer.") :]] = converted_state_dict.pop(key)
+ if key.startswith("diffusion_model."):
+ converted_state_dict[key[len("diffusion_model.") :]] = converted_state_dict.pop(key)
+
+ # Rename and remap the state dict keys
+ for key in list(converted_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ for key in list(converted_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ # Add back the "transformer." prefix
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
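
A quick sanity check of the new converter with a tiny made-up rank-4 LoRA state dict in the original HunyuanVideo key format; only the image-attention QKV remapping path is exercised here:

```py
import torch
from diffusers.loaders.lora_conversion_utils import _convert_hunyuan_video_lora_to_diffusers

state_dict = {
    "double_blocks.0.img_attn_qkv.lora_A.weight": torch.zeros(4, 3072),
    "double_blocks.0.img_attn_qkv.lora_B.weight": torch.zeros(3 * 3072, 4),
}
converted = _convert_hunyuan_video_lora_to_diffusers(state_dict)
# lora_A is shared across q/k/v, while lora_B is chunked into per-projection weights:
# transformer.transformer_blocks.0.attn.to_q.lora_A.weight, ..., .attn.to_v.lora_B.weight
print(sorted(converted.keys()))
```

In practice this conversion runs automatically inside `HunyuanVideoLoraLoaderMixin.lora_state_dict`, which detects the original format from the `img_attn_qkv` keys (see the `lora_pipeline.py` hunk below).
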
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 351295e938ff..b5fda3c88635 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -36,6 +36,7 @@
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
+ _convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
@@ -2277,8 +2278,24 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], *
super().unfuse_lora(components=components)
- # We override this here account for `_transformer_norm_layers`.
- def unload_lora_weights(self):
+ # We override this here to account for `_transformer_norm_layers` and `_overwritten_params`.
+ def unload_lora_weights(self, reset_to_overwritten_params=False):
+ """
+ Unloads the LoRA parameters.
+
+ Args:
+ reset_to_overwritten_params (`bool`, defaults to `False`): Whether to reset the LoRA-loaded modules
+ to their original params. Refer to the [Flux
+ documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
+ >>> pipeline.unload_lora_weights()
+ >>> ...
+ ```
+ """
super().unload_lora_weights()
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -2286,7 +2303,7 @@ def unload_lora_weights(self):
transformer.load_state_dict(transformer._transformer_norm_layers, strict=False)
transformer._transformer_norm_layers = None
- if getattr(transformer, "_overwritten_params", None) is not None:
+ if reset_to_overwritten_params and getattr(transformer, "_overwritten_params", None) is not None:
overwritten_params = transformer._overwritten_params
module_names = set()
@@ -2466,7 +2483,9 @@ def _maybe_expand_lora_state_dict(cls, transformer, lora_state_dict):
continue
base_param_name = (
- f"{k.replace(prefix, '')}.base_layer.weight" if is_peft_loaded else f"{k.replace(prefix, '')}.weight"
+ f"{k.replace(prefix, '')}.base_layer.weight"
+ if is_peft_loaded and f"{k.replace(prefix, '')}.base_layer.weight" in transformer_state_dict
+ else f"{k.replace(prefix, '')}.weight"
)
base_weight_param = transformer_state_dict[base_param_name]
lora_A_param = lora_state_dict[f"{prefix}{k}.lora_A.weight"]
@@ -3989,7 +4008,6 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
- # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -4000,7 +4018,7 @@ def lora_state_dict(
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+ We support loading original format HunyuanVideo LoRA checkpoints.
This function is experimental and might change in the future.
@@ -4083,6 +4101,10 @@ def lora_state_dict(
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+ is_original_hunyuan_video = any("img_attn_qkv" in k for k in state_dict)
+ if is_original_hunyuan_video:
+ state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
+
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
@@ -4221,10 +4243,9 @@ def save_lora_weights(
safe_serialization=safe_serialization,
)
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
- components: List[str] = ["transformer", "text_encoder"],
+ components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
@@ -4265,8 +4286,7 @@ def fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
- # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
- def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
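
A short sketch of the new opt-in unload behavior for Flux (the LoRA path is a placeholder; any control LoRA that expands the input projection exercises `_overwritten_params`):

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/flux-control-lora")  # placeholder control LoRA

# Default: LoRA layers are removed, but module shapes expanded while loading are kept.
pipe.unload_lora_weights()

# Opt in to restoring the original parameters recorded in `_overwritten_params`.
pipe.load_lora_weights("path/to/flux-control-lora")
pipe.unload_lora_weights(reset_to_overwritten_params=True)
```
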
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c0cbfc713857..c5c9bea29b8a 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -329,7 +329,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
>>> # Enable float16 and move to GPU
>>> pipeline = StableDiffusionPipeline.from_single_file(
- ... "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
+ ... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
... torch_dtype=torch.float16,
... )
>>> pipeline.to("cuda")
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 79dc2691b9e4..b65069e60d50 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -25,6 +25,7 @@
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
+ convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
@@ -106,6 +107,10 @@
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
"default_subfolder": "transformer",
},
+ "AuraFlowTransformer2DModel": {
+ "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
}
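
With the new mapping registered (and `FromOriginalModelMixin` added to the model class further below), the AuraFlow transformer can be loaded straight from an original checkpoint; a minimal sketch with a placeholder path:

```py
import torch
from diffusers import AuraFlowTransformer2DModel

transformer = AuraFlowTransformer2DModel.from_single_file(
    "path/to/aura_flow_0.3.safetensors",  # placeholder original single-file checkpoint
    torch_dtype=torch.float16,
)
```
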
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index b623576e3990..cefba48275cf 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -94,6 +94,12 @@
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
+ "auraflow": [
+ "double_layers.0.attn.w2q.weight",
+ "double_layers.0.attn.w1q.weight",
+ "cond_seq_linear.weight",
+ "t_embedder.mlp.0.weight",
+ ],
"flux": [
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
@@ -109,6 +115,7 @@
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
+ "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -153,6 +160,7 @@
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
+ "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
@@ -165,6 +173,7 @@
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
+ "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
}
# Use to configure model sample size when original config is provided
@@ -633,6 +642,15 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]):
+ model_type = "auraflow"
+
+ elif (
+ CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
+ and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
+ ):
+ model_type = "instruct-pix2pix"
+
else:
model_type = "v1"
@@ -2082,6 +2100,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
+
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
@@ -2681,3 +2700,95 @@ def update_state_dict_(state_dict, old_key, new_key):
handler_fn_inplace(key, checkpoint)
return checkpoint
+
+
+def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {}
+ state_dict_keys = list(checkpoint.keys())
+
+ # Handle register tokens and positional embeddings
+ converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None)
+
+ # Handle time step projection
+ converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None)
+ converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None)
+ converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None)
+ converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None)
+
+ # Handle context embedder
+ converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None)
+
+ # Calculate the number of layers
+ def calculate_layers(keys, key_prefix):
+ layers = set()
+ for k in keys:
+ if key_prefix in k:
+ layer_num = int(k.split(".")[1]) # get the layer number
+ layers.add(layer_num)
+ return len(layers)
+
+ mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
+ single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
+
+ # MMDiT blocks
+ for i in range(mmdit_layers):
+ # Feed-forward
+ path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
+ weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for orig_k, diffuser_k in path_mapping.items():
+ for k, v in weight_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop(
+ f"double_layers.{i}.{orig_k}.{k}.weight", None
+ )
+
+ # Norms
+ path_mapping = {"modX": "norm1", "modC": "norm1_context"}
+ for orig_k, diffuser_k in path_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop(
+ f"double_layers.{i}.{orig_k}.1.weight", None
+ )
+
+ # Attentions
+ x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
+ context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
+ for attn_mapping in [x_attn_mapping, context_attn_mapping]:
+ for k, v in attn_mapping.items():
+ converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+ f"double_layers.{i}.attn.{k}.weight", None
+ )
+
+ # Single-DiT blocks
+ for i in range(single_dit_layers):
+ # Feed-forward
+ mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
+ for k, v in mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop(
+ f"single_layers.{i}.mlp.{k}.weight", None
+ )
+
+ # Norms
+ converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
+ f"single_layers.{i}.modCX.1.weight", None
+ )
+
+ # Attentions
+ x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
+ for k, v in x_attn_mapping.items():
+ converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop(
+ f"single_layers.{i}.attn.{k}.weight", None
+ )
+ # Final blocks
+ converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None)
+
+ # Handle the final norm layer
+ norm_weight = checkpoint.pop("modF.1.weight", None)
+ if norm_weight is not None:
+ converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None)
+ else:
+ converted_state_dict["norm_out.linear.weight"] = None
+
+ converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding")
+ converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight")
+ converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")
+
+ return converted_state_dict
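
The new `instruct-pix2pix` entry routes single-file checkpoints whose `conv_in` weight has 8 input channels to the `timbrooks/instruct-pix2pix` config. A rough sketch, assuming the InstructPix2Pix pipeline exposes `from_single_file` (the checkpoint path is a placeholder):

```py
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

pipe = StableDiffusionInstructPix2PixPipeline.from_single_file(
    "path/to/instruct-pix2pix.ckpt",  # placeholder checkpoint with an 8-channel conv_in
    torch_dtype=torch.float16,
)
```
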
diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py
index 0162d67a340c..095d154cb4fe 100644
--- a/src/diffusers/loaders/textual_inversion.py
+++ b/src/diffusers/loaders/textual_inversion.py
@@ -333,7 +333,7 @@ def load_textual_inversion(
from diffusers import StableDiffusionPipeline
import torch
- model_id = "runwayml/stable-diffusion-v1-5"
+ model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("sd-concepts-library/cat-toy")
@@ -352,7 +352,7 @@ def load_textual_inversion(
from diffusers import StableDiffusionPipeline
import torch
- model_id = "runwayml/stable-diffusion-v1-5"
+ model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")
@@ -469,7 +469,7 @@ def unload_textual_inversion(
from diffusers import AutoPipelineForText2Image
import torch
- pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
# Example 1
pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 7050968b6de5..d84c52c98440 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -343,6 +343,17 @@ def _process_lora(
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
+
+ if "lora_bias" in lora_config_kwargs:
+ if lora_config_kwargs["lora_bias"]:
+ if is_peft_version("<=", "0.13.2"):
+ raise ValueError(
+ "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
+ )
+ else:
+ if is_peft_version("<=", "0.13.2"):
+ lora_config_kwargs.pop("lora_bias")
+
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
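
For context, `lora_bias` is the PEFT option the new guard protects; a tiny sketch assuming `peft>=0.14.0` is installed:

```py
from peft import LoraConfig

# `lora_bias=True` adds a trainable bias to the LoRA B matrix. The option only exists in
# peft>=0.14.0, which is exactly what the version check above enforces when loading such LoRAs.
config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v"], lora_bias=True)
```
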
diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
index a97249f79473..4759b9141242 100644
--- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py
+++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -60,7 +60,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
>>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
>>> pipe = StableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
... ).to("cuda")
>>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 1768c81ce039..c64b9587be77 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1248,7 +1248,8 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
- freqs_dtype = torch.float32 if is_mps else torch.float64
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i],
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index d6efcc736487..789aeccf9b7f 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -920,14 +920,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else: # else let accelerate handle loading and dispatching.
# Load weights and dispatch according to the device_map
# by default the device_map is None and the weights are loaded on the CPU
- force_hook = True
device_map = _determine_device_map(
model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
)
if device_map is None and is_sharded:
# we load the parameters on the cpu
device_map = {"": "cpu"}
- force_hook = False
try:
accelerate.load_checkpoint_and_dispatch(
model,
@@ -937,7 +935,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
- force_hooks=force_hook,
strict=True,
)
except AttributeError as e:
@@ -967,7 +964,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
- force_hooks=force_hook,
strict=True,
)
model._undo_temp_convert_self_to_deprecated_attention_blocks()
@@ -1214,7 +1210,7 @@ def _get_signature_keys(cls, obj):
# Adapted from `transformers` modeling_utils.py
def _get_no_split_modules(self, device_map: str):
"""
- Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
+ Get the modules of the model that should not be split when using device_map. We iterate through the modules to
get the underlying `_no_split_modules`.
Args:
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index b3f29e6b6224..b35488a89282 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -20,6 +20,7 @@
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
@@ -253,7 +254,7 @@ def forward(
return encoder_hidden_states, hidden_states
-class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
+class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index b47d439774cc..e83c5be75b44 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -210,6 +210,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
_supports_gradient_checkpointing = True
+ _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py
index 027ab5fecefd..bc3877627529 100644
--- a/src/diffusers/models/transformers/sana_transformer.py
+++ b/src/diffusers/models/transformers/sana_transformer.py
@@ -250,7 +250,6 @@ def __init__(
inner_dim = num_attention_heads * attention_head_dim
# 1. Patch Embedding
- interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
self.patch_embed = PatchEmbed(
height=sample_size,
width=sample_size,
@@ -258,6 +257,7 @@ def __init__(
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
+ pos_embed_type="sincos" if interpolation_scale is not None else None,
)
# 2. Additional condition embeddings
diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py
index fe9c7290b063..81039fd49e0d 100644
--- a/src/diffusers/models/transformers/transformer_allegro.py
+++ b/src/diffusers/models/transformers/transformer_allegro.py
@@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
Scaling factor to apply in 3D positional embeddings across time dimension.
"""
+ _supports_gradient_checkpointing = True
+
@register_to_config
def __init__(
self,
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 94d852f6df4b..369509a3a35e 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
+ _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"]
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index dc2eb26f9d30..f5e92700b2f3 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
def forward(
self,
- hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- image_rotary_emb=None,
- joint_attention_kwargs=None,
- ):
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
Reference: https://arxiv.org/abs/2403.03206
- Parameters:
- dim (`int`): The number of channels in the input and output.
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
- attention_head_dim (`int`): The number of channels in each head.
- context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
- processing of `context` conditions.
+ Args:
+ dim (`int`):
+ The embedding dimension of the block.
+ num_attention_heads (`int`):
+ The number of attention heads to use.
+ attention_head_dim (`int`):
+ The number of dimensions to use for each attention head.
+ qk_norm (`str`, defaults to `"rms_norm"`):
+ The normalization to use for the query and key tensors.
+ eps (`float`, defaults to `1e-6`):
+ The epsilon value to use for the normalization.
"""
- def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
def forward(
self,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor,
- temb: torch.FloatTensor,
- image_rotary_emb=None,
- joint_attention_kwargs=None,
- ):
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
- Parameters:
- patch_size (`int`): Patch size to turn the input data into small patches.
- in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
- num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
- num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
- joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
- pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
- guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ Args:
+ patch_size (`int`, defaults to `1`):
+ Patch size to turn the input data into small patches.
+ in_channels (`int`, defaults to `64`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `None`):
+ The number of channels in the output. If not specified, it defaults to `in_channels`.
+ num_layers (`int`, defaults to `19`):
+ The number of layers of dual stream DiT blocks to use.
+ num_single_layers (`int`, defaults to `38`):
+ The number of layers of single stream DiT blocks to use.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of dimensions to use for each attention head.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of attention heads to use.
+ joint_attention_dim (`int`, defaults to `4096`):
+ The number of dimensions to use for the joint attention (embedding/channel dimension of
+ `encoder_hidden_states`).
+ pooled_projection_dim (`int`, defaults to `768`):
+ The number of dimensions to use for the pooled projection.
+ guidance_embeds (`bool`, defaults to `False`):
+ Whether to use guidance embeddings for guidance-distilled variant of the model.
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions to use for the rotary positional embeddings.
"""
_supports_gradient_checkpointing = True
@@ -259,7 +280,7 @@ def __init__(
):
super().__init__()
self.out_channels = out_channels or in_channels
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+ self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
@@ -267,20 +288,20 @@ def __init__(
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
)
- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
- self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
)
- for i in range(self.config.num_layers)
+ for _ in range(num_layers)
]
)
@@ -288,10 +309,10 @@ def __init__(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
- num_attention_heads=self.config.num_attention_heads,
- attention_head_dim=self.config.attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
)
- for i in range(self.config.num_single_layers)
+ for _ in range(num_single_layers)
]
)
@@ -418,16 +439,16 @@ def forward(
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index e3f24d97f3fa..044f2048775f 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"""
_supports_gradient_checkpointing = True
+ _no_split_modules = [
+ "HunyuanVideoTransformerBlock",
+ "HunyuanVideoSingleTransformerBlock",
+ "HunyuanVideoPatchEmbed",
+ "HunyuanVideoTokenRefiner",
+ ]
@register_to_config
def __init__(
@@ -713,14 +719,15 @@ def forward(
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
- batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
- ) # [B, N, N]
+ batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
+ ) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
for i in range(batch_size):
- attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
+ attention_mask[i, : effective_sequence_length[i]] = True
+ attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
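
Declaring `_no_split_modules` is what allows the transformer to be sharded across devices with a `device_map`; a minimal sketch (the `"auto"` placement strategy is assumed to be available through accelerate):

```py
import torch
from diffusers import HunyuanVideoTransformer3DModel

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    device_map="auto",  # blocks listed in _no_split_modules are never split across devices
)
```
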
diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py
index bec62ce5cf45..090357237f46 100644
--- a/src/diffusers/models/unets/unet_2d.py
+++ b/src/diffusers/models/unets/unet_2d.py
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
Tuple of downsample block types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
- Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ def __init__(
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+ mid_block_type: Optional[str] = "UNetMidBlock2D",
up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
layers_per_block: int = 2,
@@ -194,19 +195,22 @@ def __init__(
self.down_blocks.append(down_block)
# mid
- self.mid_block = UNetMidBlock2D(
- in_channels=block_out_channels[-1],
- temb_channels=time_embed_dim,
- dropout=dropout,
- resnet_eps=norm_eps,
- resnet_act_fn=act_fn,
- output_scale_factor=mid_block_scale_factor,
- resnet_time_scale_shift=resnet_time_scale_shift,
- attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
- resnet_groups=norm_num_groups,
- attn_groups=attn_norm_num_groups,
- add_attention=add_attention,
- )
+ if mid_block_type is None:
+ self.mid_block = None
+ else:
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ dropout=dropout,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ attn_groups=attn_norm_num_groups,
+ add_attention=add_attention,
+ )
# up
reversed_block_out_channels = list(reversed(block_out_channels))
@@ -322,7 +326,8 @@ def forward(
down_block_res_samples += res_samples
# 4. mid
- sample = self.mid_block(sample, emb)
+ if self.mid_block is not None:
+ sample = self.mid_block(sample, emb)
# 5. up
skip_sample = None
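
With the mid block now optional, a `UNet2DModel` can be built without one; a brief sketch (the channel and size values are illustrative):

```py
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    mid_block_type=None,  # skip UNetMidBlock2D entirely; forward() now guards for this
)
print(unet.mid_block)  # None
```
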
diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py
index b3650dc6cee1..91aedf2cdbe6 100644
--- a/src/diffusers/pipelines/allegro/pipeline_allegro.py
+++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -33,6 +33,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -41,6 +42,14 @@
from .pipeline_output import AllegroPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
logger = logging.get_logger(__name__)
if is_bs4_available():
@@ -194,10 +203,10 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -921,6 +930,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py
index a8c24b0aeecc..12f7dc7c59d4 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused.py
@@ -20,10 +20,18 @@
from ...image_processor import VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -66,7 +74,9 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
+ )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -297,6 +307,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
output = latents
else:
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
index c74275b414d4..7ac05b39c3a8 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
@@ -20,10 +20,18 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -81,7 +89,9 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
+ )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
@torch.no_grad()
@@ -323,6 +333,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
output = latents
else:
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
index 24801e0ef977..d908c32745c2 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
@@ -21,10 +21,18 @@
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models import UVit2DModel, VQModel
from ...schedulers import AmusedScheduler
-from ...utils import replace_example_docstring
+from ...utils import is_torch_xla_available, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -89,7 +97,9 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8
+ )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
@@ -354,6 +364,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, timestep, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
output = latents
else:
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
index cb6f50f43c4f..5c1d1e2ae0ba 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -34,6 +34,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -47,8 +48,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -139,7 +148,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -844,6 +853,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
index 626e46acbf7f..90c66e9e1973 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py
@@ -32,7 +32,7 @@
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
@@ -41,8 +41,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -180,7 +188,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1090,6 +1098,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
index 6016917537b9..c037c239a3b5 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -48,6 +48,7 @@
)
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -60,8 +61,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -307,7 +316,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -438,7 +447,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -497,8 +508,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1261,6 +1274,9 @@ def __call__(
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
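The SDXL-style `encode_prompt` hunks (here and in the ControlNet SDXL pipelines further down) stop unconditionally overwriting the pooled embeddings: the encoder output is only used as the pooled embedding when none was supplied and when that output is actually a 2-D pooled tensor. A simplified sketch of the guard, with a mocked encoder output rather than the real CLIP call:

```py
import torch


def select_pooled(encoder_outputs, pooled_prompt_embeds=None):
    # encoder_outputs[0] is treated as the pooled embedding only when it is 2-D
    # (batch, dim); anything with a sequence axis is left for the hidden-state path.
    if pooled_prompt_embeds is None and encoder_outputs[0].ndim == 2:
        pooled_prompt_embeds = encoder_outputs[0]
    return pooled_prompt_embeds


pooled = torch.randn(2, 1280)       # looks like a pooled output
hidden = torch.randn(2, 77, 1280)   # looks like per-token hidden states

print(select_pooled((pooled,)).shape)  # torch.Size([2, 1280])
print(select_pooled((hidden,)))        # None: nothing pooled to take
```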
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
index 6dde7d6686ee..42e0c6632632 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py
@@ -30,6 +30,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -42,8 +43,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -188,7 +197,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -994,6 +1003,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 11. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
index b0adbea77445..edac6bfd9e4e 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py
@@ -31,7 +31,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
@@ -40,8 +40,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -243,7 +251,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
def encode_prompt(
@@ -1037,6 +1045,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 10. Post-processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
index 10a27af246f7..1a75d658b3ad 100644
--- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
+++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py
@@ -39,7 +39,7 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
@@ -48,8 +48,16 @@
from .pipeline_output import AnimateDiffPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -270,7 +278,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_video_processor = VideoProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1325,6 +1333,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 11. Post-processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
index 105ca40f773f..14c6d44fc586 100644
--- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
+++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py
@@ -22,13 +22,21 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -94,7 +102,7 @@ def __init__(
scheduler=scheduler,
vocoder=vocoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def _encode_prompt(
self,
@@ -530,6 +538,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
mel_spectrogram = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
index b45771d7de74..63a8b702f5e1 100644
--- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
@@ -48,8 +48,20 @@
if is_librosa_available():
import librosa
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -207,7 +219,7 @@ def __init__(
scheduler=scheduler,
vocoder=vocoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing
def enable_vae_slicing(self):
@@ -1033,6 +1045,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
# 8. Post-processing
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index 0bb3fb7368d8..d3326c54973f 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -146,9 +146,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def check_inputs(
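Many of these hunks replace `hasattr(self, "vae") and self.vae is not None` with the shorter `getattr(self, "vae", None)`, which covers the missing-attribute and `None` cases in one truthiness check. A standalone illustration, with `SimpleNamespace` stand-ins instead of real pipeline objects:

```py
from types import SimpleNamespace

with_vae = SimpleNamespace(
    vae=SimpleNamespace(config=SimpleNamespace(block_out_channels=[128, 256, 512, 512]))
)
without_vae = SimpleNamespace(vae=None)


def spatial_scale_factor(pipe) -> int:
    # Mirrors the diff: derive the factor from the VAE only when one is registered.
    return 2 ** (len(pipe.vae.config.block_out_channels) - 1) if getattr(pipe, "vae", None) else 8


print(spatial_scale_factor(with_vae))     # 8 = 2 ** (4 - 1)
print(spatial_scale_factor(without_vae))  # 8 (fallback when vae is None)
```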
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index f3a05c2c661f..8bbf1ebe9fa5 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -293,7 +293,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -385,7 +385,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
```py
>>> from diffusers import AutoPipelineForText2Image
- >>> pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> image = pipeline(prompt).images[0]
```
"""
@@ -448,7 +448,7 @@ def from_pipe(cls, pipeline, **kwargs):
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
>>> pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", requires_safety_checker=False
... )
>>> pipe_t2i = AutoPipelineForText2Image.from_pipe(pipe_i2i)
@@ -528,7 +528,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in text_2_image_kwargs
}
- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(text_2_image_kwargs.keys())
+ missing_modules = (
+ set(expected_modules) - set(text_2_image_cls._optional_components) - set(text_2_image_kwargs.keys())
+ )
if len(missing_modules) > 0:
raise ValueError(
@@ -587,7 +589,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -679,7 +681,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
```py
>>> from diffusers import AutoPipelineForImage2Image
- >>> pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = AutoPipelineForImage2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> image = pipeline(prompt, image).images[0]
```
"""
@@ -754,7 +756,7 @@ def from_pipe(cls, pipeline, **kwargs):
>>> from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
>>> pipe_t2i = AutoPipelineForText2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", requires_safety_checker=False
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", requires_safety_checker=False
... )
>>> pipe_i2i = AutoPipelineForImage2Image.from_pipe(pipe_t2i)
@@ -838,7 +840,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in image_2_image_kwargs
}
- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(image_2_image_kwargs.keys())
+ missing_modules = (
+ set(expected_modules) - set(image_2_image_cls._optional_components) - set(image_2_image_kwargs.keys())
+ )
if len(missing_modules) > 0:
raise ValueError(
@@ -896,7 +900,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -988,7 +992,7 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
```py
>>> from diffusers import AutoPipelineForInpainting
- >>> pipeline = AutoPipelineForInpainting.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = AutoPipelineForInpainting.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
```
"""
@@ -1141,7 +1145,9 @@ def from_pipe(cls, pipeline, **kwargs):
if k not in inpainting_kwargs
}
- missing_modules = set(expected_modules) - set(pipeline._optional_components) - set(inpainting_kwargs.keys())
+ missing_modules = (
+ set(expected_modules) - set(inpainting_cls._optional_components) - set(inpainting_kwargs.keys())
+ )
if len(missing_modules) > 0:
raise ValueError(
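Besides the checkpoint renames, the `auto_pipeline.py` hunks change which class supplies `_optional_components` when computing `missing_modules` in `from_pipe`: the check now consults the target pipeline class (`text_2_image_cls`, `image_2_image_cls`, `inpainting_cls`) rather than the source pipeline instance, so components that are optional for the destination pipeline are no longer flagged as missing. The set arithmetic, with hypothetical module names for illustration:

```py
def find_missing_modules(expected_modules, target_optional_components, provided_kwargs):
    # Same computation as the diff, using the target class's optional components.
    return set(expected_modules) - set(target_optional_components) - set(provided_kwargs)


expected = {"vae", "unet", "text_encoder", "safety_checker"}  # required by the target pipeline
optional_for_target = {"safety_checker"}                      # optional on the target class
provided = {"vae", "unet", "text_encoder"}                    # carried over from the source pipeline

print(find_missing_modules(expected, optional_for_target, provided))  # set() -> no error raised
```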
diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
index ff23247b5f81..cbd8bef67945 100644
--- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
+++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -20,6 +20,7 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -30,8 +31,16 @@
from .modeling_ctx_clip import ContextCLIPTextModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -336,6 +345,9 @@ def __call__(
latents,
)["prev_sample"]
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index a1555402ccf6..d78d5508dc7f 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -26,12 +26,19 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -183,14 +190,12 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -755,6 +760,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index e4c6ca1206fe..46e7b9ee468e 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -27,12 +27,19 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -190,14 +197,12 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -810,6 +815,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index 6842123ff798..58793902345a 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -29,6 +29,7 @@
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -37,6 +38,13 @@
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -203,14 +211,12 @@ def __init__(
scheduler=scheduler,
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -868,6 +874,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 945f7694caae..333e3418dca2 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -27,12 +27,19 @@
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -206,14 +213,12 @@ def __init__(
)
self.vae_scale_factor_spatial = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
- self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scaling_factor_image = (
- self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@@ -836,6 +841,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 8bed88c275cf..0cd3943fbcd2 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -24,11 +24,18 @@
from ...models import AutoencoderKL, CogView3PlusTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import CogView3PipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -153,9 +160,7 @@ def __init__(
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -656,6 +661,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
index d2f67a698917..f0c71655e628 100644
--- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
+++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
@@ -19,6 +19,7 @@
from ...models import UNet2DModel
from ...schedulers import CMStochasticIterativeScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -26,6 +27,13 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -263,6 +271,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, sample)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 6. Post-process image sample
image = self.postprocess_image(sample, output_type=output_type)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 582f51ab480e..214835062a05 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -80,7 +80,7 @@
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
@@ -198,8 +198,8 @@ class StableDiffusionControlNetPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -254,7 +254,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 86e0ddef663e..88c387d48dd2 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -21,6 +21,7 @@
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -31,8 +32,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -401,6 +410,10 @@ def __call__(
t,
latents,
)["prev_sample"]
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 59ac30d70d77..73ffeeb5e79c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,13 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -71,7 +79,7 @@
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
@@ -168,8 +176,8 @@ class StableDiffusionControlNetImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -224,7 +232,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1294,6 +1302,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 977b852a89c9..875dbed38c4d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -32,6 +32,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -43,6 +44,13 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -83,7 +91,7 @@
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
@@ -141,11 +149,11 @@ class StableDiffusionControlNetInpaintPipeline(
This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
- ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
- default text-to-image Stable Diffusion checkpoints
- ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
- Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
- [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
+ ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
+ as well as default text-to-image Stable Diffusion checkpoints
+ ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
+ Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
+ those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
@@ -167,8 +175,8 @@ class StableDiffusionControlNetInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -223,7 +231,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1476,6 +1484,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index c6c4ce935a1f..38e63f56b2f3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -60,6 +60,16 @@
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -264,7 +274,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -406,7 +416,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -465,8 +477,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1622,7 +1636,7 @@ def denoising_value_valid(dnv):
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1829,6 +1843,9 @@ def denoising_value_valid(dnv):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 536c00ee361c..77d496cf831d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -62,6 +62,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -275,7 +285,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -415,7 +425,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -474,8 +486,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1548,6 +1562,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
index 0c4b250af6e6..86588a5b3851 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py
@@ -62,6 +62,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -267,7 +277,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -408,7 +418,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -467,8 +479,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1608,6 +1622,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index 7012f3b95458..56f6c9149c6e 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -60,6 +60,16 @@
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -246,7 +256,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -388,7 +398,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -447,8 +459,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1755,6 +1769,9 @@ def denoising_value_valid(dnv):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
index dcd885f7d604..a2e50d4f3e09 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
@@ -60,6 +60,17 @@
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -257,7 +268,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -397,7 +408,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -456,8 +469,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1454,6 +1469,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
index 95cf067fce12..d4409c54b01c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
@@ -61,6 +61,17 @@
if is_invisible_watermark_available():
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -281,7 +292,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -422,7 +433,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -481,8 +494,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1573,6 +1588,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
index 8a2cc08dbb2b..3d4b19ea552c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -75,7 +75,10 @@
... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32
... )
>>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ ... controlnet=controlnet,
+ ... revision="flax",
+ ... dtype=jnp.float32,
... )
>>> params["controlnet"] = controlnet_params
@@ -132,8 +135,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -175,7 +178,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_text_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index c8464f8108ea..f01c8cc4674d 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -269,9 +269,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index 1de7ba424d54..d2e3e0f34519 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -17,14 +17,16 @@
import torch
from transformers import (
+ BaseImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
+ PreTrainedModel,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
@@ -138,7 +140,9 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+ DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
Provides additional conditioning to the `unet` during the denoising process. If you set multiple
ControlNets as a list, the outputs from each ControlNet are added together to create one combined
additional conditioning.
+ image_encoder (`PreTrainedModel`, *optional*):
+ Pre-trained Vision Model for IP Adapter.
+ feature_extractor (`BaseImageProcessor`, *optional*):
+ Image processor for IP Adapter.
"""
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
- _optional_components = []
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
def __init__(
@@ -194,6 +202,8 @@ def __init__(
controlnet: Union[
SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
],
+ image_encoder: PreTrainedModel = None,
+ feature_extractor: BaseImageProcessor = None,
):
super().__init__()
if isinstance(controlnet, (list, tuple)):
@@ -223,10 +233,10 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -727,6 +737,84 @@ def num_timesteps(self):
def interrupt(self):
return self._interrupt
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+ def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+ """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+ Args:
+ image (`PipelineImageInput`):
+ Input image to be encoded.
+ device (`torch.device`):
+ Torch device.
+
+ Returns:
+ `torch.Tensor`: The encoded image feature representation.
+ """
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=self.dtype)
+
+ return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ ) -> torch.Tensor:
+ """Prepares image embeddings for use in the IP-Adapter.
+
+ Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+ Args:
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ The input image to extract features from for IP-Adapter.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Precomputed image embeddings.
+ device (`torch.device`, *optional*):
+ Torch device.
+ num_images_per_prompt (`int`, defaults to 1):
+ Number of images that should be generated per prompt.
+ do_classifier_free_guidance (`bool`, defaults to True):
+ Whether to use classifier free guidance or not.
+ """
+ device = device or self._execution_device
+
+ if ip_adapter_image_embeds is not None:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+ else:
+ single_image_embeds = ip_adapter_image_embeds
+ elif ip_adapter_image is not None:
+ single_image_embeds = self.encode_image(ip_adapter_image, device)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+ else:
+ raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.")
+
+ image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+ image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+ return image_embeds.to(device=device)
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+ def enable_sequential_cpu_offload(self, *args, **kwargs):
+ if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+ logger.warning(
+ "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+ "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+ "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+ )
+
+ super().enable_sequential_cpu_offload(*args, **kwargs)
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -754,6 +842,8 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -843,6 +933,12 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+ emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+ `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1040,7 +1136,22 @@ def __call__(
# SD35 official 8b controlnet does not use encoder_hidden_states
controlnet_encoder_hidden_states = None
- # 7. Denoising loop
+ # 7. Prepare image embeddings
+ if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+ ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+ else:
+ self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+ # 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 5d5249922f8d..1040ff265985 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -230,9 +230,7 @@ def __init__(
scheduler=scheduler,
controlnet=controlnet,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True
)
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index ca10e65de8a4..901ca25c576c 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,13 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -178,7 +186,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -884,6 +892,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
index 326cfdab7be7..acf1f5489ec1 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py
@@ -54,6 +54,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -196,7 +206,7 @@ def __init__(
scheduler=scheduler,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -336,7 +346,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -395,8 +407,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1074,6 +1088,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# manually for max memory savings
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
self.upcast_vae()
diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
index bcd36c412b54..ed342f66804a 100644
--- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
+++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py
@@ -17,11 +17,18 @@
import torch
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -146,6 +153,9 @@ def __call__(
# 2. compute previous audio sample: x_t -> x_t-1
audio = self.scheduler.step(model_output, t, audio).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
audio = audio.clamp(-1, 1).float().cpu().numpy()
audio = audio[:, :, :original_sample_size]
diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py
index a3b967ed369b..1b424f5742f2 100644
--- a/src/diffusers/pipelines/ddim/pipeline_ddim.py
+++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py
@@ -17,10 +17,19 @@
import torch
from ...schedulers import DDIMScheduler
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class DDIMPipeline(DiffusionPipeline):
r"""
Pipeline for image generation.
@@ -143,6 +152,9 @@ def __call__(
model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
index bb03a8d66758..e58a53b5b7e8 100644
--- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
+++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py
@@ -17,10 +17,19 @@
import torch
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class DDPMPipeline(DiffusionPipeline):
r"""
Pipeline for image generation.
@@ -116,6 +125,9 @@ def __call__(
# 2. compute previous image: x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
index f545b24bec5c..150978de6e5e 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py
@@ -14,6 +14,7 @@
BACKENDS_MAPPING,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -24,8 +25,16 @@
from .watermark import IFWatermarker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -735,6 +744,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
index 07017912575d..a92d7be6a11c 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py
@@ -17,6 +17,7 @@
PIL_INTERPOLATION,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -27,8 +28,16 @@
from .watermark import IFWatermarker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -856,6 +865,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
index 6685ba6d774a..f39a63f22e11 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py
@@ -35,6 +35,16 @@
import ftfy
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -974,6 +984,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
index 7fca0bc0443c..030821b789aa 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py
@@ -17,6 +17,7 @@
PIL_INTERPOLATION,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -27,8 +28,16 @@
from .watermark import IFWatermarker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -975,6 +984,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
index 4f04a1de2a6e..8ea5e16090c2 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py
@@ -35,6 +35,16 @@
import ftfy
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -1085,6 +1095,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
index 891963f2a904..da3d2ea087e0 100644
--- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
+++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py
@@ -34,6 +34,16 @@
import ftfy
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -831,6 +841,9 @@ def __call__(
if callback is not None and i % callback_steps == 0:
callback(i, t, intermediate_images)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = intermediate_images
if output_type == "pil":
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
index a1930da4180e..705bf3661ffb 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
@@ -210,7 +210,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -224,7 +224,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -284,7 +284,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
index e40b6efd71ab..af77cac3cb8a 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -238,7 +238,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -252,7 +252,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -312,7 +312,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 777be883cb9d..70ad47074ca2 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -184,7 +184,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -243,7 +243,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
index 0aa5e68bfcb4..e9553a8d99b0 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -93,7 +93,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -107,7 +107,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
index ce7ad3b0dfe9..f4483fc47b79 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
@@ -140,7 +140,7 @@ def __init__(
)
deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False)
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -154,7 +154,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -213,7 +213,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index 9e91986896bd..06db871daf62 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -121,7 +121,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
index be21900ab55a..d486a32f6a4c 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
@@ -143,7 +143,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index 2978972200c7..509f25620950 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -365,7 +365,7 @@ def __init__(
caption_generator=caption_generator,
inverse_scheduler=inverse_scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
index c8dc18e2e8ac..4fb437958abd 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -76,7 +76,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
@torch.no_grad()
def image_variation(
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
index 2212651fbb5b..0065279bc0b1 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py
@@ -94,7 +94,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if self.text_unet is not None and (
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 62d3e83a4790..7dfc7e961825 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -77,7 +77,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index de4c2ac9b7f4..1d6771793f39 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -82,7 +82,7 @@ def __init__(
vae=vae,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if self.text_unet is not None:
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 14321b5f33cf..cf5ebbce2ba8 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -24,10 +24,19 @@
from ...models import AutoencoderKL, DiTTransformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class DiTPipeline(DiffusionPipeline):
r"""
Pipeline for image generation based on a Transformer backbone instead of a UNet.
@@ -211,6 +220,9 @@ def __call__(
# compute previous image: x_t -> x_t-1
latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if guidance_scale > 1:
latents, _ = latent_model_input.chunk(2, dim=0)
else:
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 41db7515d0f4..c23b660300db 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -206,9 +206,7 @@ def __init__(
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index 9a43c03fdd5d..8aece8527556 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -213,12 +213,8 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
- self.vae_latent_channels = (
- self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ self.vae_latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
index 63b9a109ef47..c386f41c8827 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -227,9 +227,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
index e6d36af18197..192b690f69e5 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -258,15 +258,14 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 2e0b8fc46eea..30e244bae000 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -229,9 +229,7 @@ def __init__(
scheduler=scheduler,
controlnet=controlnet,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index c5886e8fd757..d8aefc3942e9 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -227,9 +227,7 @@ def __init__(
scheduler=scheduler,
controlnet=controlnet,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index a64da0a36f62..bfc96eeb8dab 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -230,15 +230,14 @@ def __init__(
controlnet=controlnet,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index 3f0875c6cfde..ed8623e31733 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -221,15 +221,14 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index c0ea2c0f9436..a63ecdadbd0c 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -211,9 +211,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index bfc05b2f7079..2be8e75973ef 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -208,15 +208,14 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
- vae_latent_channels=self.vae.config.latent_channels,
+ vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index 3b0956a32da3..5c3d6ce611cc 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -23,15 +23,23 @@
from ...loaders import HunyuanVideoLoraLoaderMixin
from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import logging, replace_example_docstring
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HunyuanVideoPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -184,12 +192,8 @@ def __init__(
tokenizer_2=tokenizer_2,
)
- self.vae_scale_factor_temporal = (
- self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
- )
- self.vae_scale_factor_spatial = (
- self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_llama_prompt_embeds(
@@ -671,6 +675,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index 6f542cb59f46..6a5cf298d2d4 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -240,9 +240,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
index f528b60e6ed7..58d65a190d5b 100644
--- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
+++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py
@@ -27,6 +27,7 @@
from ...schedulers import DDIMScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -35,8 +36,16 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -133,7 +142,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# `do_resize=False` as we do custom resizing.
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False)
@@ -711,6 +720,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
index b2041e101564..b5f4acf5c05a 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py
@@ -22,6 +22,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -30,8 +31,16 @@
from .text_encoder import MultilingualCLIP
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -385,6 +394,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
index ef5241fee5d2..5d56efef9287 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py
@@ -25,6 +25,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -33,8 +34,16 @@
from .text_encoder import MultilingualCLIP
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -478,6 +487,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 7. post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
index 778b6e314c0d..cce5f0b3d5bc 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py
@@ -29,6 +29,7 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -37,8 +38,16 @@
from .text_encoder import MultilingualCLIP
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -613,6 +622,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
index b5152d71cb6b..a348deef8b29 100644
--- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
+++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py
@@ -24,6 +24,7 @@
from ...schedulers import UnCLIPScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -31,8 +32,16 @@
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -519,6 +528,9 @@ def __call__(
prev_timestep=prev_timestep,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
index 471db61556f5..a584674540d8 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py
@@ -18,13 +18,21 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -296,6 +304,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
index 0130c3951b38..bada59080c7b 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py
@@ -19,14 +19,23 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -297,6 +306,10 @@ def __call__(
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
index 12be1534c642..4f6c4188bd48 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py
@@ -22,14 +22,23 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -358,6 +367,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
index 899273a1a736..624748896911 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py
@@ -21,13 +21,21 @@
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -372,6 +380,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
index b5ba7a0011a1..482093a4bb29 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py
@@ -25,13 +25,21 @@
from ... import __version__
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -526,6 +534,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
index f2134b22b40b..d05a7fbdb1b8 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py
@@ -7,6 +7,7 @@
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -15,8 +16,16 @@
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -524,6 +533,9 @@ def __call__(
)
text_mask = callback_outputs.pop("text_mask", text_mask)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
index ec6509bb3cb5..56d326e26e6e 100644
--- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
+++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py
@@ -7,6 +7,7 @@
from ...models import PriorTransformer
from ...schedulers import UnCLIPScheduler
from ...utils import (
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -15,8 +16,16 @@
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -538,6 +547,9 @@ def __call__(
prev_timestep=prev_timestep,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
latents = self.prior.post_process_latents(latents)
image_embeddings = latents
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
index 8dbae2a1909a..5309f94a53c8 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py
@@ -8,6 +8,7 @@
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -15,8 +16,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -549,6 +558,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index 81c45c4fb6f8..fbdad79db445 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -12,6 +12,7 @@
from ...schedulers import DDPMScheduler
from ...utils import (
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -19,8 +20,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -617,6 +626,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 1d2d07572d68..dce060f726a8 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -188,9 +188,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index 6ddda7acf2a8..890a67fb3e25 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -207,9 +207,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index e985648abace..1c59ca7d6d7c 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,6 +41,13 @@
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -226,7 +234,7 @@ def __init__(
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
@@ -952,6 +960,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
denoised = denoised.to(prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index d110cd464522..a3d9917d3376 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -29,6 +29,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -39,8 +40,16 @@
from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -209,7 +218,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -881,6 +890,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
denoised = denoised.to(prompt_embeds.dtype)
if not output_type == "latent":
image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index cd63637b6c2f..d079e71fe38e 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -25,10 +25,19 @@
from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...utils import is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
class LDMTextToImagePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using latent diffusion.
@@ -202,6 +211,9 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# scale and decode the image latents with vae
latents = 1 / self.vqvae.config.scaling_factor * latents
image = self.vqvae.decode(latents).sample
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
index bb72b4d4eb8e..879722e6a0e2 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -15,11 +15,19 @@
LMSDiscreteScheduler,
PNDMScheduler,
)
-from ...utils import PIL_INTERPOLATION
+from ...utils import PIL_INTERPOLATION, is_torch_xla_available
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
def preprocess(image):
w, h = image.size
w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32
@@ -174,6 +182,9 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# decode the image latents with the VQVAE
image = self.vqvae.decode(latents).sample
image = torch.clamp(image, -1.0, 1.0)
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py
index 19c4a6d1ddf9..852a2b7b795e 100644
--- a/src/diffusers/pipelines/latte/pipeline_latte.py
+++ b/src/diffusers/pipelines/latte/pipeline_latte.py
@@ -32,6 +32,7 @@
BaseOutput,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -39,8 +40,16 @@
from ...video_processor import VideoProcessor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -180,7 +189,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -836,6 +845,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latents":
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index f0f71080d0a3..3c1c2924e9db 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -19,6 +19,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -29,26 +30,32 @@
from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
- >>> import PIL
- >>> import requests
>>> import torch
- >>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusion
>>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
... )
+ >>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
- >>> image = load_image(img_url).convert("RGB")
+ >>> image = load_image(img_url).resize((512, 512))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
@@ -152,7 +159,7 @@ def __init__(self, device):
# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -318,7 +325,7 @@ def __init__(
"The scheduler has been changed to DPMSolverMultistepScheduler."
)
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -332,7 +339,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -391,7 +398,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -706,6 +713,35 @@ def clip_skip(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1182,6 +1218,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
@@ -1271,6 +1310,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
+ if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+ raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
@@ -1349,6 +1390,9 @@ def invert(
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1)
zs = zs.flip(0)
self.zs = zs
@@ -1360,6 +1404,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
+        height, width = image.shape[-2:]
+        if height % 32 != 0 or width % 32 != 0:
+            raise ValueError(
+                "Image height and width must be a multiple of 32. "
+                "Consider down-sampling the input using the `height` and `width` parameters"
+            )
resized = self.image_processor.postprocess(image=image, output_type="pil")
if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
index 834445bfcd06..fe45d7f9fa2e 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py
@@ -72,25 +72,18 @@
Examples:
```py
>>> import torch
- >>> import PIL
- >>> import requests
- >>> from io import BytesIO
>>> from diffusers import LEditsPPPipelineStableDiffusionXL
+ >>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
- ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
... )
+ >>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
-
- >>> def download_image(url):
- ... response = requests.get(url)
- ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
-
-
>>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
- >>> image = download_image(img_url)
+ >>> image = load_image(img_url).resize((1024, 1024))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
@@ -197,7 +190,7 @@ def __init__(self, device):
# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -379,7 +372,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler):
@@ -768,6 +761,35 @@ def denoising_end(self):
def num_timesteps(self):
return self._num_timesteps
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+        processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
def prepare_unet(self, attention_store, PnP: bool = False):
attn_procs = {}
@@ -1401,6 +1423,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
+            height, width = image.shape[-2:]
+            if height % 32 != 0 or width % 32 != 0:
+                raise ValueError(
+                    "Image height and width must be a multiple of 32. "
+                    "Consider down-sampling the input using the `height` and `width` parameters"
+                )
resized = self.image_processor.postprocess(image=image, output_type="pil")
if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1467,10 @@ def invert(
crops_coords_top_left: Tuple[int, int] = (0, 0),
num_zero_noise_steps: int = 3,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ resize_mode: Optional[str] = "default",
+ crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
r"""
The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1518,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
+ if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
+ raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1544,14 @@ def invert(
do_classifier_free_guidance = source_guidance_scale > 1.0
# 1. prepare image
- x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+ x0, resized = self.encode_image(
+ image,
+ dtype=self.text_encoder_2.dtype,
+ height=height,
+ width=width,
+ resize_mode=resize_mode,
+ crops_coords=crops_coords,
+ )
width = x0.shape[2] * self.vae_scale_factor
height = x0.shape[3] * self.vae_scale_factor
self.size = (height, width)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index e98aedcc97a5..c49918cb7d21 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -186,16 +186,22 @@ def __init__(
scheduler=scheduler,
)
- self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32
- self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8
- self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
        self.transformer_temporal_patch_size = (
-            self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1
+            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
        )
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
)
def _get_t5_prompt_embeds(
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 691c1bd80376..b1dcc41d887e 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -205,16 +205,22 @@ def __init__(
scheduler=scheduler,
)
- self.vae_spatial_compression_ratio = self.vae.spatial_compression_ratio if hasattr(self, "vae") else 32
- self.vae_temporal_compression_ratio = self.vae.temporal_compression_ratio if hasattr(self, "vae") else 8
- self.transformer_spatial_patch_size = self.transformer.config.patch_size if hasattr(self, "transformer") else 1
+ self.vae_spatial_compression_ratio = (
+ self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
+ )
+ self.vae_temporal_compression_ratio = (
+ self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
+ )
+ self.transformer_spatial_patch_size = (
+ self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
+ )
        self.transformer_temporal_patch_size = (
-            self.transformer.config.patch_size_t if hasattr(self, "transformer") else 1
+            self.transformer.config.patch_size_t if getattr(self, "transformer", None) is not None else 1
        )
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
+ self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
)
self.default_height = 512
diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py
index 0a59d98919f0..52bb6546031d 100644
--- a/src/diffusers/pipelines/lumina/pipeline_lumina.py
+++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py
@@ -31,6 +31,7 @@
BACKENDS_MAPPING,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -38,8 +39,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -874,6 +883,9 @@ def __call__(
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
index a602ba611ea5..e5cd62e35773 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
@@ -37,6 +37,7 @@
)
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -46,6 +47,13 @@
from .marigold_image_processing import MarigoldImageProcessor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -174,7 +182,7 @@ def __init__(
default_processing_resolution=default_processing_resolution,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.scale_invariant = scale_invariant
self.shift_invariant = shift_invariant
@@ -517,6 +525,9 @@ def __call__(
noise, t, batch_pred_latent, generator=generator
).prev_sample # [B,4,h,w]
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
pred_latents.append(batch_pred_latent)
pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
index aa9ad36ffc35..22f155f92022 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
@@ -36,6 +36,7 @@
)
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -44,6 +45,13 @@
from .marigold_image_processing import MarigoldImageProcessor
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -161,7 +169,7 @@ def __init__(
default_processing_resolution=default_processing_resolution,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.use_full_z_range = use_full_z_range
self.default_denoising_steps = default_denoising_steps
@@ -493,6 +501,9 @@ def __call__(
noise, t, batch_pred_latent, generator=generator
).prev_sample # [B,4,h,w]
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
pred_latents.append(batch_pred_latent)
pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
index 728635da6d4d..73837af7d429 100644
--- a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
+++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py
@@ -42,8 +42,20 @@
if is_librosa_available():
import librosa
+
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -111,7 +123,7 @@ def __init__(
scheduler=scheduler,
vocoder=vocoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def _encode_prompt(
self,
@@ -603,6 +615,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
# 8. Post-processing
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index 28c4f3d32b78..bc90073cba77 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -30,6 +30,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -42,6 +43,13 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -251,7 +259,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -1293,6 +1301,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index 3ad9cbf45f0d..bc7a4b57affd 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -31,6 +31,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -43,6 +44,13 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -228,7 +236,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1505,6 +1513,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
index 15a93357470f..83540885bfb2 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py
@@ -62,6 +62,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -280,7 +290,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -421,7 +431,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -480,8 +492,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1560,6 +1574,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
index 19c26b98ba37..b84f5d555914 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py
@@ -62,6 +62,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -270,7 +280,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -413,7 +423,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -472,8 +484,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1626,6 +1640,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
index dea1f12696b2..a6a8deb5883c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
@@ -245,9 +245,7 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = (
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
index 3e84f44adcf7..458a4d4667bf 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py
@@ -202,9 +202,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
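
Several of these constructors replace the verbose `hasattr(self, "vae") and self.vae is not None` check with `getattr(self, "vae", None)` when computing `vae_scale_factor`. A small sketch of the fallback behaviour is shown below; `DummyVAE` and its three-block config are hypothetical stand-ins used only to exercise the expression.

```py
# Sketch of the vae_scale_factor fallback; DummyVAE/DummyConfig are hypothetical.
class DummyConfig:
    block_out_channels = [128, 256, 512]  # hypothetical 3-block VAE


class DummyVAE:
    config = DummyConfig()


class Pipeline:
    def __init__(self, vae=None):
        if vae is not None:
            self.vae = vae
        # getattr(self, "vae", None) is truthy only when the attribute exists
        # and is not None, so the scale factor falls back to 8 otherwise.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )


assert Pipeline(DummyVAE()).vae_scale_factor == 4  # 2 ** (3 - 1)
assert Pipeline().vae_scale_factor == 8            # fallback when no VAE is registered
```
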
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
index b2fbdd683e86..d927a7961a16 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
@@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -43,8 +44,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -172,7 +181,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(pag_applied_layers)
@@ -843,6 +852,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
index 03662bb37158..f363a1a557bc 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py
@@ -30,6 +30,7 @@
BACKENDS_MAPPING,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -43,8 +44,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -162,7 +171,11 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ self.vae_scale_factor = (
+ 2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+ if hasattr(self, "vae") and self.vae is not None
+ else 8
+ )
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(
@@ -863,6 +876,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
image = latents
else:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
index 6220a00f2c22..86c810ab1a10 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
@@ -27,6 +27,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -39,8 +40,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -207,7 +216,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -221,7 +230,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -281,7 +290,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1034,6 +1043,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
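
The scheduler-validation hunks in this file (and the ONNX pipelines further down) switch from `hasattr(...)` to `scheduler is not None and getattr(..., default)` so the deprecation checks tolerate both a missing config key and a `None` scheduler. A standalone sketch of that guard, with a hypothetical `DummyScheduler`, follows.

```py
# Sketch of the None-tolerant scheduler config checks; DummyScheduler is hypothetical.
class DummyConfig:
    steps_offset = 0  # outdated value that should trigger the deprecation path
    # note: no `clip_sample` attribute at all


class DummyScheduler:
    config = DummyConfig()


def check_scheduler(scheduler):
    # getattr with a default keeps the check safe both when the attribute is
    # missing and when the scheduler itself is None.
    if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
        print("steps_offset should be 1; scheduler config is outdated")
    if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
        print("clip_sample should be False")


check_scheduler(DummyScheduler())  # prints the steps_offset warning only
check_scheduler(None)              # no-op instead of an AttributeError
```
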
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index d1b96e75574f..0285239aaa8d 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -200,9 +200,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index 24e31fa4cfc7..121be4ce2c07 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -216,9 +216,7 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
index 1e81fa3a158c..d3a015e569c1 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
@@ -26,6 +26,7 @@
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -147,7 +156,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
self.set_pag_applied_layers(pag_applied_layers)
@@ -847,6 +856,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
index b7a695be17e5..c38fcf86c4a7 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -42,8 +43,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -202,7 +211,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -216,7 +225,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -276,7 +285,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1066,6 +1075,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index ff6ba8a6a853..8fb677e56bbb 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -28,6 +28,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from .pag_utils import PAGMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -234,7 +243,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -248,7 +257,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -308,7 +317,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1318,6 +1327,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
index c2611164a049..856b07102363 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
@@ -275,7 +275,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -415,7 +415,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -474,8 +476,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
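
The `encode_prompt` hunks above only adopt the encoder's first return value as the pooled embedding when the caller has not already supplied one and that value is actually a 2D `(batch, dim)` tensor. A short sketch of the guard, with hypothetical tensors and shapes, is shown below.

```py
# Sketch of the pooled-embedding guard added in encode_prompt; the tensors and
# shapes are hypothetical, only the control flow mirrors the hunks above.
import torch

pooled_prompt_embeds = None                  # nothing passed by the caller
encoder_output = (torch.randn(1, 1280),)     # CLIP-style pooled output: (batch, dim)

# Only take the first output as the pooled embedding when it has not been
# provided already and it really is a 2D (batch, dim) tensor; a 3D tensor
# would be token-level hidden states rather than a pooled vector.
if pooled_prompt_embeds is None and encoder_output[0].ndim == 2:
    pooled_prompt_embeds = encoder_output[0]
```
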
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
index 6d634d524848..93dcca0ea9d6 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
@@ -298,7 +298,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -436,7 +436,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -495,8 +497,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 7f85c13ac561..fdf3df2f4d6a 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -314,7 +314,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -526,7 +526,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -585,8 +587,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index b225fd71edf8..55a9f47145a2 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -23,7 +23,7 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -31,6 +31,13 @@
from .image_encoder import PaintByExampleImageEncoder
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -209,7 +216,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -604,6 +611,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index b7dfcd39edce..df8499ab900a 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -37,6 +37,7 @@
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -48,8 +49,16 @@
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -195,7 +204,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -928,6 +937,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index f7b101124181..5486bc35f035 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -237,15 +237,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of FlaxUNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
```
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- - A string, the *repo id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained pipeline
- hosted on the Hub.
+ - A string, the *repo id* (for example `stable-diffusion-v1-5/stable-diffusion-v1-5`) of a
+ pretrained pipeline hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxDiffusionPipeline.save_pretrained`].
dtype (`str` or `jnp.dtype`, *optional*):
@@ -293,7 +293,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Requires to be logged in to Hugging Face hub,
>>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5",
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
... variant="bf16",
... dtype=jnp.bfloat16,
... )
@@ -301,7 +301,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Download pipeline, but use a different scheduler
>>> from diffusers import FlaxDPMSolverMultistepScheduler
- >>> model_id = "runwayml/stable-diffusion-v1-5"
+ >>> model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
>>> dpmpp, dpmpp_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
... model_id,
... subfolder="scheduler",
@@ -559,7 +559,7 @@ def components(self) -> Dict[str, Any]:
... )
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
... )
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
```
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 0a7a222ec007..23f1279e203d 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -813,9 +813,9 @@ def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
- " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+ " checkpoint: https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
- " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+ " https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index c505c5a262a3..527724d1de1a 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -411,6 +411,13 @@ def module_is_offloaded(module):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
+
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
+ )
+
if device and torch.device(device).type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
raise ValueError(
@@ -422,12 +429,6 @@ def module_is_offloaded(module):
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)
- is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
- if is_pipeline_device_mapped:
- raise ValueError(
- "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
- )
-
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device and torch.device(device).type == "cuda":
@@ -516,7 +517,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
If you get the error message below, you need to finetune the weights for your downstream task:
```
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -643,7 +644,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
- >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler
@@ -1555,7 +1556,7 @@ def components(self) -> Dict[str, Any]:
... StableDiffusionInpaintPipeline,
... )
- >>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> text2img = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
```
@@ -1688,7 +1689,7 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto
>>> from diffusers import StableDiffusionPipeline
>>> pipe = StableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5",
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
... torch_dtype=torch.float16,
... use_safetensors=True,
... )
@@ -1735,7 +1736,7 @@ def from_pipe(cls, pipeline, **kwargs):
```py
>>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
```
"""
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index 391b831166d2..46a7337051ef 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -36,8 +37,16 @@
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -285,7 +294,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
@@ -943,6 +952,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index 64e1e5bae06c..356ba3a29af3 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -29,6 +29,7 @@
deprecate,
is_bs4_available,
is_ftfy_available,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -41,8 +42,16 @@
)
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
if is_bs4_available():
from bs4 import BeautifulSoup
@@ -211,7 +220,7 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
@@ -854,6 +863,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index c90dec4d41b3..afc2f74c9e8f 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -63,6 +63,49 @@
import ftfy
+ASPECT_RATIO_4096_BIN = {
+ "0.25": [2048.0, 8192.0],
+ "0.26": [2048.0, 7936.0],
+ "0.27": [2048.0, 7680.0],
+ "0.28": [2048.0, 7424.0],
+ "0.32": [2304.0, 7168.0],
+ "0.33": [2304.0, 6912.0],
+ "0.35": [2304.0, 6656.0],
+ "0.4": [2560.0, 6400.0],
+ "0.42": [2560.0, 6144.0],
+ "0.48": [2816.0, 5888.0],
+ "0.5": [2816.0, 5632.0],
+ "0.52": [2816.0, 5376.0],
+ "0.57": [3072.0, 5376.0],
+ "0.6": [3072.0, 5120.0],
+ "0.68": [3328.0, 4864.0],
+ "0.72": [3328.0, 4608.0],
+ "0.78": [3584.0, 4608.0],
+ "0.82": [3584.0, 4352.0],
+ "0.88": [3840.0, 4352.0],
+ "0.94": [3840.0, 4096.0],
+ "1.0": [4096.0, 4096.0],
+ "1.07": [4096.0, 3840.0],
+ "1.13": [4352.0, 3840.0],
+ "1.21": [4352.0, 3584.0],
+ "1.29": [4608.0, 3584.0],
+ "1.38": [4608.0, 3328.0],
+ "1.46": [4864.0, 3328.0],
+ "1.67": [5120.0, 3072.0],
+ "1.75": [5376.0, 3072.0],
+ "2.0": [5632.0, 2816.0],
+ "2.09": [5888.0, 2816.0],
+ "2.4": [6144.0, 2560.0],
+ "2.5": [6400.0, 2560.0],
+ "2.89": [6656.0, 2304.0],
+ "3.0": [6912.0, 2304.0],
+ "3.11": [7168.0, 2304.0],
+ "3.62": [7424.0, 2048.0],
+ "3.75": [7680.0, 2048.0],
+ "3.88": [7936.0, 2048.0],
+ "4.0": [8192.0, 2048.0],
+}
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -619,7 +662,7 @@ def __call__(
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
- clean_caption: bool = True,
+ clean_caption: bool = False,
use_resolution_binning: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -734,7 +777,9 @@ def __call__(
# 1. Check inputs. Raise error if not correct
if use_resolution_binning:
- if self.transformer.config.sample_size == 64:
+ if self.transformer.config.sample_size == 128:
+ aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+ elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
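
The Sana hunk adds an `ASPECT_RATIO_4096_BIN` table, flips the `clean_caption` default to `False`, and selects the resolution-binning table from `transformer.config.sample_size` (128 maps to the 4096px bin, 64 to 2048px, 32 to 1024px). A sketch of that dispatch is shown below; the bin dictionaries are abbreviated stand-ins, not the full tables.

```py
# Sketch of the bin selection added to the Sana pipeline; the bin dicts are
# stubbed out here, only the sample_size mapping mirrors the hunk above.
ASPECT_RATIO_4096_BIN = {"1.0": [4096.0, 4096.0]}  # abbreviated stand-in
ASPECT_RATIO_2048_BIN = {"1.0": [2048.0, 2048.0]}  # abbreviated stand-in
ASPECT_RATIO_1024_BIN = {"1.0": [1024.0, 1024.0]}  # abbreviated stand-in


def select_aspect_ratio_bin(sample_size: int) -> dict:
    # Larger transformer sample sizes correspond to larger target resolutions,
    # so resolution binning picks the matching lookup table.
    if sample_size == 128:
        return ASPECT_RATIO_4096_BIN
    elif sample_size == 64:
        return ASPECT_RATIO_2048_BIN
    elif sample_size == 32:
        return ASPECT_RATIO_1024_BIN
    raise ValueError(f"Invalid sample size: {sample_size}")


print(select_aspect_ratio_bin(128)["1.0"])  # [4096.0, 4096.0]
```
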
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index 6f83071f3e85..a8c374259349 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -9,12 +9,19 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import SemanticStableDiffusionPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -87,7 +94,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -701,6 +708,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
index f87f28e06c4a..ef8a95daefa4 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py
@@ -25,6 +25,7 @@
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -33,8 +34,16 @@
from .renderer import ShapERenderer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -291,6 +300,9 @@ def __call__(
sample=latents,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# Offload all models
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
index 7cc145e4c3e2..c0d1e38e0994 100644
--- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
+++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py
@@ -24,6 +24,7 @@
from ...schedulers import HeunDiscreteScheduler
from ...utils import (
BaseOutput,
+ is_torch_xla_available,
logging,
replace_example_docstring,
)
@@ -32,8 +33,16 @@
from .renderer import ShapERenderer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -278,6 +287,9 @@ def __call__(
sample=latents,
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["np", "pil", "latent", "mesh"]:
raise ValueError(
f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index 111ccc40c5a5..e3b9ec44005a 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -19,14 +19,22 @@
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import is_torch_version, logging, replace_example_docstring
+from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -503,6 +511,9 @@ def __call__(
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index 058dbf6b0797..241c454e103e 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -23,13 +23,21 @@
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
EXAMPLE_DOC_STRING = """
@@ -611,6 +619,9 @@ def __call__(
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# Offload all models
self.maybe_free_model_hooks()
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 5d6ffd463cc3..71dbf989bf92 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -55,7 +55,7 @@
>>> from diffusers import FlaxStableDiffusionPipeline
>>> pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="bf16", dtype=jax.numpy.bfloat16
... )
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -100,8 +100,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -141,8 +141,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -162,7 +162,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index 7792bc097595..c2d918156084 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -124,8 +124,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -165,7 +165,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]):
if not isinstance(prompt, (str, list)):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index f6bb0ac299b3..2367ca36fc8e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -127,8 +127,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
[`FlaxDPMSolverMultistepScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -168,8 +168,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -189,7 +189,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(
self,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index 2e34dcb83c01..9917276e0a1f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -57,7 +57,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -71,7 +71,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
index c39409886bd9..92c82d61b8f2 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py
@@ -78,7 +78,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -109,7 +110,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -123,7 +124,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 18d8050826cc..ddd2e27dedaf 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -76,7 +76,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -108,7 +109,7 @@ def __init__(
super().__init__()
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -122,7 +123,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
index cd9ec57fb879..ef84cdd38b6d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -83,7 +83,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -97,7 +97,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index ac6c8253e432..8bfe273b2fb9 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -55,7 +55,9 @@
>>> import torch
>>> from diffusers import StableDiffusionPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionPipeline.from_pretrained(
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... )
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -184,8 +186,8 @@ class StableDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -209,7 +211,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -223,7 +225,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -266,8 +268,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -288,7 +290,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
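
The `vae_scale_factor` change follows the same defensive pattern: the spatial downscaling factor is still derived from the VAE's `block_out_channels`, but the expression now falls back to 8 when no VAE is attached. A small sketch with an assumed SD 1.x-style VAE config (values not taken from the patch):

```py
# Assumed SD 1.x VAE config: four down-blocks -> scale factor 2**3 = 8,
# so a 512x512 image maps to 64x64 latents.
block_out_channels = [128, 256, 512, 512]
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)  # 8

vae = None
# The conditional short-circuits, so the missing VAE is never dereferenced.
fallback = 2 ** (len(vae.config.block_out_channels) - 1) if vae else 8
print(fallback)  # 8
```
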
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index 7801b0d01dff..abd67ae577ea 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -28,11 +28,26 @@
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ PIL_INTERPOLATION,
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -124,8 +139,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -145,7 +160,7 @@ def __init__(
depth_estimator=depth_estimator,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -861,6 +876,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
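
The other change repeated across these files is the module-level `torch_xla` guard plus an `xm.mark_step()` call at the end of each denoising iteration, which forces XLA to compile and execute the lazily traced graph once per step on TPU. An illustrative-only sketch of the pattern, with a placeholder loop body rather than the real denoising step:

```py
# Import guard mirroring the one added at the top of each pipeline module.
try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

def denoising_loop(timesteps, step_fn):
    for i, t in enumerate(timesteps):
        step_fn(i, t)           # stand-in for the unet forward + scheduler.step
        if XLA_AVAILABLE:
            xm.mark_step()      # flush the lazily traced XLA graph once per step

denoising_loop(range(3), lambda i, t: None)
```
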
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
index 93a8bd160318..308a0753b175 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py
@@ -24,13 +24,20 @@
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -57,8 +64,8 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMi
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -106,8 +113,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -126,7 +133,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -401,6 +408,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
self.maybe_free_model_hooks()
if not output_type == "latent":
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 9cd5673c9359..17e8f0eb494f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -32,6 +32,7 @@
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -43,8 +44,16 @@
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -56,7 +65,7 @@
>>> from diffusers import StableDiffusionImg2ImgPipeline
>>> device = "cuda"
- >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
+ >>> model_id_or_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
>>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
@@ -205,8 +214,8 @@ class StableDiffusionImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -230,7 +239,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -244,7 +253,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -282,8 +291,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -304,7 +313,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1120,6 +1129,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 49c38c800942..9d3dfd30607a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -27,13 +27,27 @@
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -146,8 +160,8 @@ class StableDiffusionInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -171,7 +185,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -185,7 +199,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -224,8 +238,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -250,7 +264,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -1014,7 +1028,7 @@ def __call__(
>>> mask_image = download_image(mask_url).resize((512, 512))
>>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -1200,7 +1214,7 @@ def __call__(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
@@ -1303,6 +1317,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
condition_kwargs = {}
if isinstance(self.vae, AsymmetricAutoencoderKL):
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index fd89b195c778..7857bc58a8ad 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -22,16 +22,23 @@
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging
+from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
IPAdapterMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -106,8 +114,8 @@ class StableDiffusionInstructPix2PixPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -157,7 +165,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -457,6 +465,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
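
Beyond the hub-path and XLA updates, this hunk also adds `FromSingleFileMixin` to `StableDiffusionInstructPix2PixPipeline`, which exposes `from_single_file` for loading a monolithic checkpoint. A hedged usage sketch; the checkpoint path below is a placeholder, not a file referenced by the patch:

```py
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline

# Hypothetical local checkpoint path; from_single_file is the loader gained by
# inheriting FromSingleFileMixin in this hunk.
pipe = StableDiffusionInstructPix2PixPipeline.from_single_file(
    "path/to/instruct-pix2pix.safetensors", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
```
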
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index ffe02ae679e5..c6967bc393b5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -25,11 +25,18 @@
from ...loaders import FromSingleFileMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import EulerDiscreteScheduler
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -116,7 +123,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
def _encode_prompt(
@@ -640,6 +647,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 4cbbe17531ef..dae4540ebe00 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -30,12 +30,26 @@
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -149,7 +163,7 @@ def __init__(
watermarker=watermarker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
self.register_to_config(max_noise_level=max_noise_level)
@@ -769,6 +783,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
index 41811f8f2c0e..07d82251d4ba 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py
@@ -28,6 +28,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -38,8 +39,16 @@
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -154,7 +163,7 @@ def __init__(
vae=vae,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder
@@ -924,6 +933,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
else:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
index 2556d5e57b6d..eac9945ff349 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py
@@ -28,6 +28,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -38,8 +39,16 @@
from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -155,7 +164,7 @@ def __init__(
vae=vae,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -829,6 +838,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 2967183e3c5f..dc0d64144e12 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -215,9 +215,7 @@ def __init__(
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
@@ -870,7 +868,8 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
- ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 0fdb6d25ed91..6a3a4abe7696 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -226,10 +226,8 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
- latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
)
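
In the SD3 pipelines the longer `hasattr(self, "vae") and self.vae is not None` guard collapses to `getattr(self, "vae", None)`; the two are interchangeable here because a loaded VAE component is always truthy. A dummy-object sketch (not patch code) checking the equivalence:

```py
class Holder:
    pass

h = Holder()
# No vae attribute: both forms evaluate to False.
assert (hasattr(h, "vae") and h.vae is not None) == bool(getattr(h, "vae", None))

h.vae = object()
# Attached vae: both forms evaluate to True.
assert (hasattr(h, "vae") and h.vae is not None) == bool(getattr(h, "vae", None))
```
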
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index 67afae1d21c0..23cc4983d54f 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -225,10 +225,8 @@ def __init__(
transformer=transformer,
scheduler=scheduler,
)
- self.vae_scale_factor = (
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
- )
- latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels
)
diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
index 8f40fa72a25c..351b146fb423 100644
--- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
+++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py
@@ -30,6 +30,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,14 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
@@ -194,8 +203,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -242,7 +251,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1008,6 +1017,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 2b86470dbff1..bdc9cb80da16 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -33,6 +33,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -44,6 +45,13 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -268,8 +276,8 @@ class StableDiffusionDiffEditPipeline(
A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents.
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -292,7 +300,7 @@ def __init__(
):
super().__init__()
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -306,7 +314,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False:
+ if scheduler is not None and getattr(scheduler.config, "skip_prk_steps", True) is False:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration"
" `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
@@ -345,8 +353,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -367,7 +375,7 @@ def __init__(
feature_extractor=feature_extractor,
inverse_scheduler=inverse_scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1508,6 +1516,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
index 52ccd5612776..4bbb93e44a83 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py
@@ -29,6 +29,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -120,8 +129,8 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -168,7 +177,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -828,6 +837,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
index 6c36ec173539..86ef01784057 100644
--- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
+++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py
@@ -32,7 +32,14 @@
from ...models.attention import GatedSelfAttentionDense
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -40,8 +47,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -172,8 +187,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -226,7 +241,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1010,6 +1025,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index 122701ff923f..24e11bff3052 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -83,7 +83,8 @@ class StableDiffusionKDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -124,7 +125,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
model = ModelWrapper(unet, scheduler.alphas_cumprod)
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
index 45f814fd538f..ddcc77de28f5 100644
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py
@@ -170,7 +170,7 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -321,7 +321,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -380,8 +382,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
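
The `encode_prompt` change in the SDXL k-diffusion pipeline only adopts `prompt_embeds[0]` as the pooled embedding when it is 2-dimensional, so an explicitly supplied `pooled_prompt_embeds` (or a 3-D per-token output) is no longer overwritten. A shape-only sketch with dummy tensors, not patch code:

```py
import torch

pooled_prompt_embeds = None
encoder_output_0 = torch.zeros(1, 1280)              # a pooled output is 2-D (batch, dim)

# Guard added by the patch: only adopt a genuinely pooled tensor.
if pooled_prompt_embeds is None and encoder_output_0.ndim == 2:
    pooled_prompt_embeds = encoder_output_0

per_token_hidden_states = torch.zeros(1, 77, 1280)   # 3-D hidden states fail the ndim check
assert pooled_prompt_embeds.shape == (1, 1280)
```
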
diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
index 81bb0e9a7270..702f3eda5816 100644
--- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
+++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py
@@ -30,6 +30,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -40,8 +41,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```python
@@ -203,8 +212,8 @@ class StableDiffusionLDM3DPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -254,7 +263,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1002,6 +1011,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index 2fc79c0610f0..ccee6d47b47a 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -26,6 +26,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -37,8 +38,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -179,8 +188,8 @@ class StableDiffusionPanoramaPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -230,7 +239,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -1155,6 +1164,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type != "latent":
if circular_padding:
image = self.decode_latents_with_padding(latents)
diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
index cd59cf51869d..6c4513b9a69d 100644
--- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
+++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py
@@ -12,13 +12,20 @@
from ...loaders import IPAdapterMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging
+from ...utils import deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionSafePipelineOutput
from .safety_checker import SafeStableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -46,8 +53,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline, StableDiffusionMixin, IPAda
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -74,7 +81,7 @@ def __init__(
" abuse, brutality, cruelty"
)
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ if scheduler is not None and getattr(scheduler.config, "steps_offset", 1) != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
@@ -88,7 +95,7 @@ def __init__(
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ if scheduler is not None and getattr(scheduler.config, "clip_sample", False) is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
@@ -126,8 +133,8 @@ def __init__(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -149,7 +156,7 @@ def __init__(
image_encoder=image_encoder,
)
self._safety_text_concept = safety_concept
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.register_to_config(requires_safety_checker=requires_safety_checker)
@property
@@ -739,6 +746,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post-processing
image = self.decode_latents(latents)
diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
index c32052d2e4d0..e96422073b19 100644
--- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
+++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py
@@ -27,6 +27,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -38,8 +39,16 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -47,7 +56,7 @@
>>> from diffusers import StableDiffusionSAGPipeline
>>> pipe = StableDiffusionSAGPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -123,8 +132,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -157,7 +166,7 @@ def __init__(
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -840,6 +849,9 @@ def get_map_size(module, input, output):
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
index 77363b2546d7..eb1030f3bb9d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py
@@ -65,7 +65,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
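The `vae_scale_factor` change repeated across these constructors makes the computation safe when a pipeline is assembled without a VAE, for example when components are loaded selectively or swapped out. A minimal sketch of the same guard, where `compute_vae_scale_factor` is an illustrative helper rather than a diffusers API:

```py
def compute_vae_scale_factor(vae, default=8):
    # None-safe equivalent of the expression used in the constructors above. The default
    # of 8 matches a standard Stable Diffusion VAE with four resolution levels.
    if vae is None:
        return default
    return 2 ** (len(vae.config.block_out_channels) - 1)


# A VAE configured with block_out_channels=(128, 256, 512, 512) yields 2 ** 3 == 8,
# while a pipeline constructed with vae=None now falls back to 8 instead of raising.
```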
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index d83fa6201117..18e6d91b3245 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -269,7 +269,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -406,7 +406,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -465,8 +467,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
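The `encode_prompt` changes in this and the following SDXL pipelines only treat the first element of the text-encoder output as the pooled embedding when it is actually a 2D `[batch, dim]` tensor, and they never overwrite a pooled embedding the caller already supplied. A hedged paraphrase of that guard — the helper name is illustrative, not a diffusers function:

```py
def select_pooled_embedding(encoder_output, pooled_prompt_embeds=None):
    # CLIPTextModelWithProjection returns the pooled projection first, shaped [batch, dim].
    # A [batch, seq_len, dim] tensor is a per-token hidden state and must not be mistaken
    # for the pooled output; an explicitly provided pooled embedding always wins.
    candidate = encoder_output[0]
    if pooled_prompt_embeds is None and candidate.ndim == 2:
        pooled_prompt_embeds = candidate
    return pooled_prompt_embeds
```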
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 126f25a41adc..08d0b44d613d 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -291,7 +291,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -427,7 +427,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -486,8 +488,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index a378ae65eb30..920caf4d24a1 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -321,7 +321,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -531,7 +531,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -590,8 +592,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index b59f2312726d..e191565f947e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -199,7 +199,7 @@ def __init__(
scheduler=scheduler,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
self.is_cosxl_edit = is_cosxl_edit
@@ -333,7 +333,9 @@ def encode_prompt(
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
@@ -385,7 +387,8 @@ def encode_prompt(
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
index fb986075aeea..8c1af7863e63 100644
--- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
+++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py
@@ -24,14 +24,22 @@
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
from ...schedulers import EulerDiscreteScheduler
-from ...utils import BaseOutput, logging, replace_example_docstring
+from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -177,7 +185,7 @@ def __init__(
scheduler=scheduler,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
def _encode_image(
@@ -600,6 +608,9 @@ def __call__(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# cast back to fp16 if needed
if needs_upcasting:
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 1a938aaf9423..8520a2e2b741 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -31,6 +31,7 @@
USE_PEFT_BACKEND,
BaseOutput,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -41,6 +42,14 @@
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
@dataclass
class StableDiffusionAdapterPipelineOutput(BaseOutput):
"""
@@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput):
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -208,7 +218,8 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -259,7 +270,7 @@ def __init__(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -914,6 +925,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type == "latent":
image = latents
has_nsfw_concept = None
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index 20569d0adb32..d4cbc3c66bea 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -43,6 +43,7 @@
from ...utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -53,8 +54,16 @@
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -248,7 +257,8 @@ class StableDiffusionXLAdapterPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -292,7 +302,7 @@ def __init__(
image_encoder=image_encoder,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -422,7 +432,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -481,8 +493,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1261,6 +1275,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
index cdd72b97f86b..5c63d66e3133 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -25,6 +25,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -36,8 +37,16 @@
from . import TextToVideoSDPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -105,7 +114,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -627,6 +636,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 8. Post processing
if output_type == "latent":
video = latents
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
index 92bf1d388c13..006c7a79ce0d 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -26,6 +26,7 @@
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
+ is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -37,8 +38,16 @@
from . import TextToVideoSDPipelineOutput
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -140,7 +149,7 @@ def __init__(
unet=unet,
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -679,6 +688,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
index c95c7f1b9625..df85f470a80b 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -11,16 +11,30 @@
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ BaseOutput,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from ..stable_diffusion import StableDiffusionSafetyChecker
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -282,7 +296,11 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
class TextToVideoZeroPipeline(
- DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionLoraLoaderMixin,
+ FromSingleFileMixin,
):
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -304,8 +322,8 @@ class TextToVideoZeroPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`CLIPImageProcessor`]):
A [`CLIPImageProcessor`] to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -340,7 +358,7 @@ def __init__(
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def forward_loop(self, x_t0, t0, t1, generator):
@@ -440,6 +458,10 @@ def backward_loop(
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
return latents.clone().detach()
# Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
index 9ff473cc3a38..a9f7b4b000c2 100644
--- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
+++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py
@@ -42,6 +42,16 @@
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -409,7 +419,7 @@ def __init__(
feature_extractor=feature_extractor,
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = self.unet.config.sample_size
@@ -705,7 +715,9 @@ def encode_prompt(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
- pooled_prompt_embeds = prompt_embeds[0]
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+ pooled_prompt_embeds = prompt_embeds[0]
+
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
@@ -764,8 +776,10 @@ def encode_prompt(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
+
# We are only ALWAYS interested in the pooled output of the final text encoder
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -922,6 +936,10 @@ def backward_loop(
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
return latents.clone().detach()
@torch.no_grad()
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py
index 25c6739d8720..bf42d44f74c1 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py
@@ -22,12 +22,19 @@
from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -474,6 +481,9 @@ def __call__(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = super_res_latents
# done super res
diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
index 2a0e7e90e4d2..8fa0a848f7e7 100644
--- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
+++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py
@@ -27,12 +27,19 @@
from ...models import UNet2DConditionModel, UNet2DModel
from ...schedulers import UnCLIPScheduler
-from ...utils import logging
+from ...utils import is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .text_proj import UnCLIPTextProjModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -400,6 +407,9 @@ def __call__(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
).prev_sample
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
image = super_res_latents
# done super res
diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
index 4f65caf4e610..66d7404fb9a5 100644
--- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
+++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
@@ -18,7 +18,14 @@
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_torch_xla_available,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
@@ -26,6 +33,13 @@
from .modeling_uvit import UniDiffuserModel
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -117,7 +131,7 @@ def __init__(
scheduler=scheduler,
)
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.num_channels_latents = vae.config.latent_channels
@@ -1378,6 +1392,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 9. Post-processing
image = None
text = None
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
index b08421415b23..edc01f0d5c75 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -19,15 +19,23 @@
from transformers import CLIPTextModel, CLIPTokenizer
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import deprecate, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -413,6 +421,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
index 92223ce993a6..8f6ba419721d 100644
--- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -22,14 +22,22 @@
from ...loaders import StableDiffusionLoraLoaderMixin
from ...schedulers import DDPMWuerstchenScheduler
-from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
+from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .modeling_wuerstchen_prior import WuerstchenPrior
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
EXAMPLE_DOC_STRING = """
@@ -502,6 +510,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
# 10. Denormalize the latents
latents = latents * self.config.latent_mean - self.config.latent_std
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index 35e5743fbcf0..9bbb5e4ca266 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -450,7 +450,7 @@ def __init__(
def forward(self, inputs):
weight = dequantize_gguf_tensor(self.weight)
weight = weight.to(self.compute_dtype)
- bias = self.bias.to(self.compute_dtype)
+ bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
output = torch.nn.functional.linear(inputs, weight, bias)
return output
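The one-line GGUF change makes the dequantizing linear layer tolerate a missing bias: previously `self.bias.to(...)` raised an `AttributeError` for layers created with `bias=False`. A self-contained sketch of the fixed forward pass, where `dequantize` stands in for diffusers' GGUF dequantization routine:

```py
import torch.nn.functional as F


def gguf_linear_forward(inputs, quantized_weight, bias, compute_dtype, dequantize):
    weight = dequantize(quantized_weight).to(compute_dtype)
    # Only cast the bias when it exists; F.linear accepts bias=None.
    bias = bias.to(compute_dtype) if bias is not None else None
    return F.linear(inputs, weight, bias)
```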
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 0861160de6aa..ace0ad6b6044 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import gc
import os
import sys
@@ -162,6 +163,105 @@ def test_with_alpha_in_state_dict(self):
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
+ def test_lora_expansion_works_for_absent_keys(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # Modify the config to have a layer which won't be present in the second LoRA we will load.
+ modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+ modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+ pipe.transformer.add_adapter(modified_denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+ self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
+
+ # Modify the state dict to exclude "x_embedder" related LoRA params.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+
+ pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
+ pipe.set_adapters(["one", "two"])
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+ images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertFalse(
+ np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+ "Different LoRAs should lead to different results.",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ def test_lora_expansion_works_for_extra_keys(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # Modify the config to have a layer which won't be present in the first LoRA we will load.
+ modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
+ modified_denoiser_lora_config.target_modules.add("x_embedder")
+
+ pipe.transformer.add_adapter(modified_denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+ self.assertFalse(
+ np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
+ self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+ pipe.unload_lora_weights()
+ # Modify the state dict to exclude "x_embedder" related LoRA params.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
+ pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
+
+ # Load state dict with `x_embedder`.
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
+
+ pipe.set_adapters(["one", "two"])
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
+ images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+ self.assertFalse(
+ np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+ "Different LoRAs should lead to different results.",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
+ "LoRA should lead to different results.",
+ )
+
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@@ -606,7 +706,7 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
- control_pipe.unload_lora_weights()
+ control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
@@ -624,6 +724,65 @@ def test_lora_unload_with_parameter_expanded_shapes(self):
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
+ def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
+ components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+
+ logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+ logger.setLevel(logging.DEBUG)
+
+ # Change the transformer config to mimic a real use case.
+ num_channels_without_control = 4
+ transformer = FluxTransformer2DModel.from_config(
+ components["transformer"].config, in_channels=num_channels_without_control
+ ).to(torch_device)
+ self.assertTrue(
+ transformer.config.in_channels == num_channels_without_control,
+ f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
+ )
+
+ # This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
+ components["transformer"] = transformer
+ pipe = FluxPipeline(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ control_image = inputs.pop("control_image")
+ original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ control_pipe = self.pipeline_class(**components)
+ out_features, in_features = control_pipe.transformer.x_embedder.weight.shape
+ rank = 4
+
+ dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+ dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+ lora_state_dict = {
+ "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+ "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+ }
+ with CaptureLogger(logger) as cap_logger:
+ control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
+
+ inputs["control_image"] = control_image
+ lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
+ self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
+ self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+
+ control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
+ self.assertTrue(
+ control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
+            f"Expected {2 * num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
+ )
+ no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
+ self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
+ self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
+
@unittest.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
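The two new tests above exercise loading LoRAs whose key sets do not fully overlap: one adapter includes `x_embedder` weights, the other omits them, and both stay usable side by side. A hedged usage sketch of the same scenario — the checkpoint id is real, but the LoRA file names are placeholders:

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# One adapter targets x_embedder, the other does not; mixing them should still work.
pipe.load_lora_weights("lora_with_x_embedder.safetensors", adapter_name="one")
pipe.load_lora_weights("lora_without_x_embedder.safetensors", adapter_name="two")
pipe.set_adapters(["one", "two"])

image = pipe("a photo of a cat", num_inference_steps=4, generator=torch.manual_seed(0)).images[0]
```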
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index 8bda98438571..d2015d8b0711 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import gc
import sys
import unittest
+import numpy as np
+import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -26,7 +29,11 @@
)
from diffusers.utils.testing_utils import (
floats_tensor,
+ nightly,
+ numpy_cosine_similarity_distance,
+ require_big_gpu_with_torch_cuda,
require_peft_backend,
+ require_torch_gpu,
skip_mps,
)
@@ -182,3 +189,69 @@ def test_simple_inference_with_text_lora_fused(self):
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+
+
+@nightly
+@require_torch_gpu
+@require_peft_backend
+@require_big_gpu_with_torch_cuda
+@pytest.mark.big_gpu_with_torch_cuda
+class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
+ """internal note: The integration slices were obtained on DGX.
+
+ torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
+ assertions to pass.
+ """
+
+ num_inference_steps = 10
+ seed = 0
+
+ def setUp(self):
+ super().setUp()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ model_id = "hunyuanvideo-community/HunyuanVideo"
+ transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+ model_id, subfolder="transformer", torch_dtype=torch.bfloat16
+ )
+ self.pipeline = HunyuanVideoPipeline.from_pretrained(
+ model_id, transformer=transformer, torch_dtype=torch.float16
+ ).to("cuda")
+
+ def tearDown(self):
+ super().tearDown()
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_original_format_cseti(self):
+ self.pipeline.load_lora_weights(
+ "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
+ )
+ self.pipeline.fuse_lora()
+ self.pipeline.unload_lora_weights()
+ self.pipeline.vae.enable_tiling()
+
+ prompt = "CSETIARCANE. A cat walks on the grass, realistic"
+
+ out = self.pipeline(
+ prompt=prompt,
+ height=320,
+ width=512,
+ num_frames=9,
+ num_inference_steps=self.num_inference_steps,
+ output_type="np",
+ generator=torch.manual_seed(self.seed),
+ ).frames[0]
+ out = out.flatten()
+ out_slice = np.concatenate((out[:8], out[-8:]))
+
+ # fmt: off
+ expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
+ # fmt: on
+
+ max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+
+ assert max_diff < 1e-3
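The new HunyuanVideo integration test relies on the fuse-then-unload flow: `fuse_lora()` merges the adapter into the base weights, after which `unload_lora_weights()` removes the adapter bookkeeping while the fused result stays in effect. A simplified sketch of that setup (the test itself loads the transformer separately and mixes fp16/bf16 dtypes):

```py
import torch
from diffusers import HunyuanVideoPipeline

pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
).to("cuda")

pipe.load_lora_weights(
    "Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
pipe.fuse_lora()            # merge LoRA deltas into the base weights
pipe.unload_lora_weights()  # drop adapter modules; fused weights remain
pipe.vae.enable_tiling()    # reduce VAE decode memory for video frames
```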
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 8c42f9c86ee9..40383e3f1ee3 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -29,9 +29,11 @@
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
+ nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
+ slow,
torch_device,
)
@@ -126,6 +128,8 @@ def test_modify_padding_mode(self):
pass
+@slow
+@nightly
@require_torch_gpu
@require_peft_backend
class LoraSD3IntegrationTests(unittest.TestCase):
diff --git a/tests/models/transformers/test_models_prior.py b/tests/models/transformers/test_models_prior.py
index d2ed10dfa1f6..471c1084c00c 100644
--- a/tests/models/transformers/test_models_prior.py
+++ b/tests/models/transformers/test_models_prior.py
@@ -132,7 +132,6 @@ def test_output_pretrained(self):
output = model(**input)[0]
output_slice = output[0, :5].flatten().cpu()
- print(output_slice)
# Since the VAE Gaussian prior's generator is seeded on the appropriate device,
# the expected output slices are not the same for CPU and GPU.
@@ -182,7 +181,6 @@ def test_kandinsky_prior(self, seed, expected_slice):
assert list(sample.shape) == [1, 768]
output_slice = sample[0, :8].flatten().cpu()
- print(output_slice)
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py
index 4c13b54e0620..73b83b9eb514 100644
--- a/tests/models/transformers/test_models_transformer_cogvideox.py
+++ b/tests/models/transformers/test_models_transformer_cogvideox.py
@@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
- "num_layers": 1,
+ "num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
@@ -130,7 +130,7 @@ def prepare_init_args_and_inputs_for_common(self):
"out_channels": 4,
"time_embed_dim": 2,
"text_embed_dim": 8,
- "num_layers": 1,
+ "num_layers": 2,
"sample_width": 8,
"sample_height": 8,
"sample_frames": 8,
diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py
index eda9813808e9..ec6c58a6734c 100644
--- a/tests/models/transformers/test_models_transformer_cogview3plus.py
+++ b/tests/models/transformers/test_models_transformer_cogview3plus.py
@@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
- "num_layers": 1,
+ "num_layers": 2,
"attention_head_dim": 4,
"num_attention_heads": 2,
"out_channels": 4,
diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py
index ddf5f53511f7..a39b36ee20cc 100644
--- a/tests/models/unets/test_models_unet_2d.py
+++ b/tests/models/unets/test_models_unet_2d.py
@@ -105,6 +105,35 @@ def test_mid_block_attn_groups(self):
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+ def test_mid_block_none(self):
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ mid_none_init_dict["mid_block_type"] = None
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ mid_none_model = self.model_class(**mid_none_init_dict)
+ mid_none_model.to(torch_device)
+ mid_none_model.eval()
+
+ self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ with torch.no_grad():
+ mid_none_output = mid_none_model(**mid_none_inputs_dict)
+
+ if isinstance(mid_none_output, dict):
+ mid_none_output = mid_none_output.to_tuple()[0]
+
+ self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
+
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 8ec5b6e9a5e4..57f6e4ee440b 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -175,8 +175,7 @@ def create_ip_adapter_plus_state_dict(model):
)
ip_image_projection_state_dict = OrderedDict()
- keys = [k for k in image_projection.state_dict() if "layers." in k]
- print(keys)
+
for k, v in image_projection.state_dict().items():
if "2.to" in k:
k = k.replace("2.to", "0.to")
diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py
index bf5564e810ef..c71116dc7927 100644
--- a/tests/pipelines/controlnet/test_flax_controlnet.py
+++ b/tests/pipelines/controlnet/test_flax_controlnet.py
@@ -78,7 +78,7 @@ def test_canny(self):
expected_slice = jnp.array(
[0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078]
)
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
def test_pose(self):
@@ -123,5 +123,5 @@ def test_pose(self):
expected_slice = jnp.array(
[[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]]
)
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 5c547164c29a..7527d17af32a 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -150,6 +150,8 @@ def get_dummy_components(
"transformer": transformer,
"vae": vae,
"controlnet": controlnet,
+ "image_encoder": None,
+ "feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index 607a47e08e58..a7f861565cc9 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -308,8 +308,6 @@ def test_kandinsky(self):
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
- print(image_from_tuple_slice)
-
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593])
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index effea2619749..4aa48a920fad 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -146,7 +146,7 @@ def test_ledits_pp_inversion(self):
)
latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([-0.9084, -0.0367, 0.2940, 0.0839, 0.6890, 0.2651, -0.7104, 2.1090, -0.7822])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
@@ -167,12 +167,12 @@ def test_ledits_pp_inversion_batch(self):
)
latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5657, -1.0286, -0.9961, 0.5933, 1.1173])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([-0.0796, 2.0583, 0.5501, 0.5358, 0.0282, -0.2803, -1.0470, 0.7023, -0.0072])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
index fcfd0aa51b9f..da694175a9f1 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py
@@ -216,14 +216,14 @@ def test_ledits_pp_inversion_batch(self):
)
latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5656, -1.0286, -0.9961, 0.5933, 1.1172])
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device)
- print(latent_slice.flatten())
+
expected_slice = np.array([-0.0796, 2.0583, 0.5500, 0.5358, 0.0282, -0.2803, -1.0470, 0.7024, -0.0072])
- print(latent_slice.flatten())
+
assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3
def test_ledits_pp_warmup_steps(self):
diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py
index 3979bb170e0b..17e3f7038439 100644
--- a/tests/pipelines/pag/test_pag_sd.py
+++ b/tests/pipelines/pag/test_pag_sd.py
@@ -318,7 +318,7 @@ def test_pag_cfg(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- print(image_slice.flatten())
+
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
@@ -339,7 +339,6 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- print(image_slice.flatten())
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py
index ec8cde23c31d..f44204f82486 100644
--- a/tests/pipelines/pag/test_pag_sd_img2img.py
+++ b/tests/pipelines/pag/test_pag_sd_img2img.py
@@ -255,7 +255,7 @@ def test_pag_cfg(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- print(image_slice.flatten())
+
expected_slice = np.array(
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
)
@@ -276,7 +276,7 @@ def test_pag_uncond(self):
expected_slice = np.array(
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
)
- print(image_slice.flatten())
+
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
), f"output is different from expected, {image_slice.flatten()}"
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index cd175c600d47..a528b66cc72a 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -292,7 +292,7 @@ def test_pag_cfg(self):
image_slice = image[0, -3:, -3:, -1].flatten()
assert image.shape == (1, 512, 512, 3)
- print(image_slice.flatten())
+
expected_slice = np.array(
[0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625]
)
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
index b9b061c060c0..5690caa257b7 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
@@ -206,9 +206,6 @@ def test_stable_diffusion_pix2pix_euler(self):
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]
- slice = [round(x, 4) for x in image_slice.flatten().tolist()]
- print(",".join([str(x) for x in slice]))
-
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.7417, 0.3842, 0.4732, 0.5776, 0.5891, 0.5139, 0.4052, 0.5673, 0.4986])
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
index dc855f44b817..9e4fa767085f 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py
@@ -62,7 +62,7 @@ def test_stable_diffusion_flax(self):
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512])
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
@@ -104,5 +104,5 @@ def test_stable_diffusion_dpm_flax(self):
output_slice = jnp.asarray(jax.device_get(image_slice.flatten()))
expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297])
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
index 8f039980ec24..eeec52dab51d 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py
@@ -78,5 +78,5 @@ def test_stable_diffusion_inpaint_pipeline(self):
expected_slice = jnp.array(
[0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084]
)
- print(f"output_slice: {output_slice}")
+
assert jnp.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index 2091af9c0383..7c7b03786563 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -642,9 +642,6 @@ def test_adapter_sdxl_lcm(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
- debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
- print(",".join(debug))
-
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_adapter_sdxl_lcm_custom_timesteps(self):
@@ -667,7 +664,4 @@ def test_adapter_sdxl_lcm_custom_timesteps(self):
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
- debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
- print(",".join(debug))
-
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 764be1890cc5..f5494fbade2e 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1192,7 +1192,6 @@ def _test_inference_batch_consistent(
logger.setLevel(level=diffusers.logging.WARNING)
for batch_size, batched_input in zip(batch_sizes, batched_inputs):
- print(batch_size, batched_input)
output = pipe(**batched_input)
assert len(output[0]) == batch_size
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 8ac4c9915c27..8f768b10e846 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -6,6 +6,8 @@
import torch.nn as nn
from diffusers import (
+ AuraFlowPipeline,
+ AuraFlowTransformer2DModel,
FluxPipeline,
FluxTransformer2DModel,
GGUFQuantizationConfig,
@@ -54,7 +56,8 @@ def test_gguf_linear_layers(self):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"):
assert module.weight.dtype == torch.uint8
- assert module.bias.dtype == torch.float32
+ if module.bias is not None:
+ assert module.bias.dtype == torch.float32
def test_gguf_memory_usage(self):
quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
@@ -377,3 +380,79 @@ def test_pipeline_inference(self):
)
max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
assert max_diff < 1e-4
+
+
+class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
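+    # Runs the shared GGUFSingleFileTesterMixin checks against the AuraFlow transformer,
+    # loading the Q2_K-quantized single-file checkpoint referenced in ckpt_path below.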
+ ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = AuraFlowTransformer2DModel
+ expected_memory_use_in_gb = 4
+
+ def setUp(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def tearDown(self):
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def get_dummy_inputs(self):
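+        # Minimal dummy tensors sized for AuraFlow: 4-channel 64x64 latents and 512 text tokens of width 2048.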
+ return {
+ "hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 2048),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
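+    # End-to-end check: load the GGUF-quantized transformer into AuraFlowPipeline and
+    # compare a small slice of the generated image against reference values.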
+ def test_pipeline_inference(self):
+ quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+ transformer = self.model_cls.from_single_file(
+ self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = AuraFlowPipeline.from_pretrained(
+ "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ pipe.enable_model_cpu_offload()
+
+ prompt = "a pony holding a sign that says hello"
+ output = pipe(
+ prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np"
+ ).images[0]
+ output_slice = output[:3, :3, :].flatten()
+ expected_slice = np.array(
+ [
+ 0.46484375,
+ 0.546875,
+ 0.64453125,
+ 0.48242188,
+ 0.53515625,
+ 0.59765625,
+ 0.47070312,
+ 0.5078125,
+ 0.5703125,
+ 0.42773438,
+ 0.50390625,
+ 0.5703125,
+ 0.47070312,
+ 0.515625,
+ 0.57421875,
+ 0.45898438,
+ 0.48632812,
+ 0.53515625,
+ 0.4453125,
+ 0.5078125,
+ 0.56640625,
+ 0.47851562,
+ 0.5234375,
+ 0.57421875,
+ 0.48632812,
+ 0.5234375,
+ 0.56640625,
+ ]
+ )
+ max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice)
+ assert max_diff < 1e-4
diff --git a/tests/schedulers/test_scheduler_sasolver.py b/tests/schedulers/test_scheduler_sasolver.py
index d6d7c029b019..baa2736b2fcc 100644
--- a/tests/schedulers/test_scheduler_sasolver.py
+++ b/tests/schedulers/test_scheduler_sasolver.py
@@ -103,8 +103,6 @@ def test_full_loop_no_noise(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 329.1999816894531) < 1e-2
assert abs(result_mean.item() - 0.4286458194255829) < 1e-3
- else:
- print("None")
def test_full_loop_with_v_prediction(self):
scheduler_class = self.scheduler_classes[0]
@@ -135,8 +133,6 @@ def test_full_loop_with_v_prediction(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 193.4154052734375) < 1e-2
assert abs(result_mean.item() - 0.2518429756164551) < 1e-3
- else:
- print("None")
def test_full_loop_device(self):
scheduler_class = self.scheduler_classes[0]
@@ -166,8 +162,6 @@ def test_full_loop_device(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 337.394287109375) < 1e-2
assert abs(result_mean.item() - 0.4393154978752136) < 1e-3
- else:
- print("None")
def test_full_loop_device_karras_sigmas(self):
scheduler_class = self.scheduler_classes[0]
@@ -198,8 +192,6 @@ def test_full_loop_device_karras_sigmas(self):
elif torch_device in ["cuda"]:
assert abs(result_sum.item() - 837.25537109375) < 1e-2
assert abs(result_mean.item() - 1.0901763439178467) < 1e-2
- else:
- print("None")
def test_beta_sigmas(self):
self.check_over_configs(use_beta_sigmas=True)
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index 71afda1b80bb..dd15a5c7c071 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -4,11 +4,13 @@
import torch
-from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
+from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
+ nightly,
require_torch_accelerator,
slow,
torch_device,
@@ -118,3 +120,44 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0
def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)
+
+
+@nightly
+@slow
+@require_torch_accelerator
+class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
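+    # Verifies that loading InstructPix2Pix from a single .safetensors checkpoint matches
+    # the output of the pretrained repo weights.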
+ pipeline_class = StableDiffusionInstructPix2PixPipeline
+ ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
+ original_config = (
+ "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml"
+ )
+ repo_id = "timbrooks/instruct-pix2pix"
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
+ generator = torch.Generator(device=generator_device).manual_seed(seed)
+ image = load_image(
+ "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
+ )
+ inputs = {
+ "prompt": "turn him into a cyborg",
+ "image": image,
+ "generator": generator,
+ "num_inference_steps": 3,
+ "guidance_scale": 7.5,
+ "image_guidance_scale": 1.0,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_single_file_format_inference_is_same_as_pretrained(self):
+ super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)