diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index fb5288c1145f..ceaaddbdf189 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -272,7 +272,7 @@ jobs: python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ --make-reports=tests_torch_minimum_version_cuda \ - tests/models/test_modelling_common.py \ + tests/models/test_modeling_common.py \ tests/pipelines/test_pipelines_common.py \ tests/pipelines/test_pipeline_utils.py \ tests/pipelines/test_pipelines.py \ diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index 025787606a9c..8d17380b4a49 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -266,6 +266,7 @@ jobs: # TODO (sayakpaul, DN6): revisit `--no-deps` python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps + python -m uv pip install -U tokenizers pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps - name: Environment diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml index bd0b58256d65..7f1a0ecd1089 100644 --- a/.github/workflows/release_tests_fast.yml +++ b/.github/workflows/release_tests_fast.yml @@ -193,7 +193,7 @@ jobs: python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ --make-reports=tests_torch_minimum_cuda \ - tests/models/test_modelling_common.py \ + tests/models/test_modeling_common.py \ tests/pipelines/test_pipelines_common.py \ tests/pipelines/test_pipeline_utils.py \ tests/pipelines/test_pipelines.py \ diff --git a/docs/source/en/api/pipelines/aura_flow.md b/docs/source/en/api/pipelines/aura_flow.md index c1cf6aa263a7..5d58690505b3 100644 --- a/docs/source/en/api/pipelines/aura_flow.md +++ b/docs/source/en/api/pipelines/aura_flow.md @@ -62,6 +62,33 @@ image = pipeline(prompt).images[0] image.save("auraflow.png") ``` +Loading [GGUF checkpoints](https://huggingface.co/docs/diffusers/quantization/gguf) is also supported: + +```py +import torch +from diffusers import ( + AuraFlowPipeline, + GGUFQuantizationConfig, + AuraFlowTransformer2DModel, +) + +transformer = AuraFlowTransformer2DModel.from_single_file( + "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf", + quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16), + torch_dtype=torch.bfloat16, +) + +pipeline = AuraFlowPipeline.from_pretrained( + "fal/AuraFlow-v0.3", + transformer=transformer, + torch_dtype=torch.bfloat16, +) + +prompt = "a cute pony in a field of flowers" +image = pipeline(prompt).images[0] +image.save("auraflow.png") +``` + ## AuraFlowPipeline [[autodoc]] AuraFlowPipeline diff --git a/docs/source/en/api/pipelines/sana.md b/docs/source/en/api/pipelines/sana.md index 50eb79088c80..b530d6ecd4a4 100644 --- a/docs/source/en/api/pipelines/sana.md +++ b/docs/source/en/api/pipelines/sana.md @@ -59,10 +59,10 @@ Refer to the [Quantization](../../quantization/overview) overview to learn more ```py import torch from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaPipeline -from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModelForCausalLM +from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel 
quant_config = BitsAndBytesConfig(load_in_8bit=True) -text_encoder_8bit = AutoModelForCausalLM.from_pretrained( +text_encoder_8bit = AutoModel.from_pretrained( "Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="text_encoder", quantization_config=quant_config, diff --git a/examples/advanced_diffusion_training/README.md b/examples/advanced_diffusion_training/README.md index cd8c5feda9f0..504ae1471f44 100644 --- a/examples/advanced_diffusion_training/README.md +++ b/examples/advanced_diffusion_training/README.md @@ -67,6 +67,17 @@ write_basic_config() When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub: +```bash +huggingface-cli login +``` +This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens), and press Enter. + +> [!NOTE] +> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`: +> `pip install wandb` +> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`. + ### Pivotal Tuning **Training with text encoder(s)** diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md index 8817431bede5..1f83235ad50a 100644 --- a/examples/advanced_diffusion_training/README_flux.md +++ b/examples/advanced_diffusion_training/README_flux.md @@ -65,6 +65,17 @@ write_basic_config() When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment. +Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub: +```bash +huggingface-cli login +``` +This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens), and press Enter. + +> [!NOTE] +> In the examples below we use `wandb` to document the training runs. To do the same, make sure to install `wandb`: +> `pip install wandb` +> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`. + ### Target Modules When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them. More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). 
With this change, we may also want to explore diff --git a/examples/community/README.md b/examples/community/README.md index 611a278af88e..c7c40c46ef2d 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -33,12 +33,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) | | K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | | Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | -| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | +| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_comparison.ipynb) | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | | MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | -| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) | -| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | +| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). 
| [Stable UnCLIP](#stable-unclip) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_unclip.ipynb) | [Ray Wang](https://wrong.wang) | +| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_text_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | | UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | -| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) | +| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) | | CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | | TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | | EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) | @@ -50,7 +50,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif | IADB Pipeline | Implementation of [Iterative α-(de)Blending: a Minimalist Deterministic Diffusion Model](https://arxiv.org/abs/2305.03486) | [IADB Pipeline](#iadb-pipeline) | - | [Thomas Chambon](https://github.com/tchambon) | Zero1to3 Pipeline | Implementation of [Zero-1-to-3: Zero-shot One Image to 3D Object](https://arxiv.org/abs/2303.11328) | [Zero1to3 Pipeline](#zero1to3-pipeline) | - | [Xin Kong](https://github.com/kxhit) | | Stable Diffusion XL Long Weighted Prompt Pipeline | A pipeline support unlimited length of prompt and negative prompt, use A1111 style of prompt weighting | [Stable Diffusion XL Long Weighted Prompt Pipeline](#stable-diffusion-xl-long-weighted-prompt-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1LsqilswLR40XLLcp6XFOl5nKb_wOe26W?usp=sharing) | 
[Andrew Zhu](https://xhinker.medium.com/) | -| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | - | [Shauray Singh](https://shauray8.github.io/about_shauray/) | +| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) | | sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) | | sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) | | prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) | diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py index 5e74f6c1127d..df736956485b 100644 --- a/examples/community/adaptive_mask_inpainting.py +++ b/examples/community/adaptive_mask_inpainting.py @@ -416,10 +416,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -438,7 +442,7 @@ def __init__( unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 - if unet.config.in_channels != 9: + if unet is not None and unet.config.in_channels != 9: logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( @@ -450,7 +454,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index da6c1d2356be..024818daf186 
100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -132,10 +132,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -162,7 +166,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py index ac977f79abec..a7bc892ddf93 100644 --- a/examples/community/edict_pipeline.py +++ b/examples/community/edict_pipeline.py @@ -35,7 +35,7 @@ def __init__( scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def _encode_prompt( diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py index ab191ecf0d81..2784e2f238f6 100644 --- a/examples/community/fresco_v2v.py +++ b/examples/community/fresco_v2v.py @@ -1342,7 +1342,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False diff --git a/examples/community/gluegen.py b/examples/community/gluegen.py index 91026c5d966f..54cc562d5583 100644 --- a/examples/community/gluegen.py +++ b/examples/community/gluegen.py @@ -221,7 +221,7 @@ def __init__( language_adapter=language_adapter, tensor_norm=tensor_norm, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/instaflow_one_step.py 
b/examples/community/instaflow_one_step.py index 1fac74b3c5a5..e726b42756ee 100644 --- a/examples/community/instaflow_one_step.py +++ b/examples/community/instaflow_one_step.py @@ -152,10 +152,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -182,7 +186,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py index e05a27abb281..648bf2933145 100644 --- a/examples/community/ip_adapter_face_id.py +++ b/examples/community/ip_adapter_face_id.py @@ -234,10 +234,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -265,7 +269,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/kohya_hires_fix.py b/examples/community/kohya_hires_fix.py index 0e36f32b19a3..ddbb28896e13 100644 --- a/examples/community/kohya_hires_fix.py +++ b/examples/community/kohya_hires_fix.py @@ -463,6 +463,6 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/latent_consistency_img2img.py b/examples/community/latent_consistency_img2img.py index 5fe53ab6b830..6c532c7f76c1 100644 --- a/examples/community/latent_consistency_img2img.py +++ b/examples/community/latent_consistency_img2img.py @@ -69,7 +69,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def _encode_prompt( diff --git a/examples/community/latent_consistency_interpolate.py b/examples/community/latent_consistency_interpolate.py index 84adc125b191..34cdb0fec73b 100644 --- a/examples/community/latent_consistency_interpolate.py +++ b/examples/community/latent_consistency_interpolate.py @@ -273,7 +273,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/latent_consistency_txt2img.py b/examples/community/latent_consistency_txt2img.py index 9f25a6db2722..7b60f5bb875c 100755 --- a/examples/community/latent_consistency_txt2img.py +++ 
b/examples/community/latent_consistency_txt2img.py @@ -67,7 +67,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def _encode_prompt( diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py index 9c2cf984f14b..129793dae6b0 100644 --- a/examples/community/llm_grounded_diffusion.py +++ b/examples/community/llm_grounded_diffusion.py @@ -379,10 +379,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -410,7 +414,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index 4e9c5d1f6a40..32baf500d456 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -539,10 +539,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -568,7 +572,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config( diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py index 13d1e2a1156a..4bcef10f97c2 100644 --- a/examples/community/lpw_stable_diffusion_xl.py +++ b/examples/community/lpw_stable_diffusion_xl.py @@ -673,12 +673,16 @@ def __init__( image_encoder=image_encoder, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -827,7 +831,9 @@ def encode_prompt( ) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds_list.append(prompt_embeds) @@ -879,7 +885,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py index 0cd85ced59a1..f80b29456c60 100644 --- a/examples/community/matryoshka.py +++ b/examples/community/matryoshka.py @@ -3793,10 +3793,14 @@ def __init__( # new_config["clip_sample"] = False # scheduler._internal_dict = FrozenDict(new_config) - 
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" diff --git a/examples/community/pipeline_animatediff_controlnet.py b/examples/community/pipeline_animatediff_controlnet.py index bedf002d024c..9f99ad248be2 100644 --- a/examples/community/pipeline_animatediff_controlnet.py +++ b/examples/community/pipeline_animatediff_controlnet.py @@ -188,7 +188,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False diff --git a/examples/community/pipeline_animatediff_img2video.py b/examples/community/pipeline_animatediff_img2video.py index 0a578d4b8ef6..f7f0cf31c5dd 100644 --- a/examples/community/pipeline_animatediff_img2video.py +++ b/examples/community/pipeline_animatediff_img2video.py @@ -308,7 +308,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt diff --git a/examples/community/pipeline_animatediff_ipex.py b/examples/community/pipeline_animatediff_ipex.py index dc65e76bc43b..06508f217c4c 100644 --- a/examples/community/pipeline_animatediff_ipex.py +++ b/examples/community/pipeline_animatediff_ipex.py @@ -162,7 +162,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt diff --git a/examples/community/pipeline_demofusion_sdxl.py b/examples/community/pipeline_demofusion_sdxl.py index f83d1b401420..624b2bd1ed81 100644 --- a/examples/community/pipeline_demofusion_sdxl.py +++ b/examples/community/pipeline_demofusion_sdxl.py @@ -166,9 +166,13 @@ def __init__( scheduler=scheduler, ) 
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -290,7 +294,9 @@ def encode_prompt( ) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds_list.append(prompt_embeds) @@ -342,7 +348,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/pipeline_fabric.py b/examples/community/pipeline_fabric.py index 02fdcd04c103..30847f875bda 100644 --- a/examples/community/pipeline_fabric.py +++ b/examples/community/pipeline_fabric.py @@ -150,10 +150,14 @@ def __init__( ): super().__init__() - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -179,7 +183,7 @@ def __init__( tokenizer=tokenizer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py index 68cb69115bde..a66e2b1c7c8a 100644 --- a/examples/community/pipeline_flux_differential_img2img.py +++ b/examples/community/pipeline_flux_differential_img2img.py @@ -221,13 +221,12 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, 
"vae") and self.vae is not None else 16 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=False, do_convert_grayscale=True, @@ -876,10 +875,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index c8a87a426dc0..42fed90762da 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -219,9 +219,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 @@ -419,7 +417,7 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode=" ) image = image.to(dtype) - x0 = self.vae.encode(image.to(self.device)).latent_dist.sample() + x0 = self.vae.encode(image.to(self._execution_device)).latent_dist.sample() x0 = (x0 - self.vae.config.shift_factor) * self.vae.config.scaling_factor x0 = x0.to(dtype) return x0, resized @@ -822,10 +820,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, @@ -992,10 +990,10 @@ def invert( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, 
num_inversion_steps = retrieve_timesteps( self.scheduler, diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py index 06da6da899cd..0b27fd2bcddf 100644 --- a/examples/community/pipeline_flux_with_cfg.py +++ b/examples/community/pipeline_flux_with_cfg.py @@ -64,6 +64,7 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -189,9 +190,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 @@ -757,10 +756,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py index 8cf2830f25ab..a294ff782450 100644 --- a/examples/community/pipeline_hunyuandit_differential_img2img.py +++ b/examples/community/pipeline_hunyuandit_differential_img2img.py @@ -327,9 +327,7 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, diff --git a/examples/community/pipeline_kolors_differential_img2img.py b/examples/community/pipeline_kolors_differential_img2img.py index e5570248d22b..dfef872d1c30 100644 --- a/examples/community/pipeline_kolors_differential_img2img.py +++ b/examples/community/pipeline_kolors_differential_img2img.py @@ -209,16 +209,18 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True ) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt def encode_prompt( diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py index 3a193fb5bc9c..736f00799eae 100644 --- a/examples/community/pipeline_prompt2prompt.py +++ b/examples/community/pipeline_prompt2prompt.py @@ -174,10 +174,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -205,7 +209,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py index 8328bc2caed9..9377caf7ba2e 100644 --- a/examples/community/pipeline_sdxl_style_aligned.py +++ b/examples/community/pipeline_sdxl_style_aligned.py @@ -488,13 +488,17 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -628,7 +632,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -688,7 +694,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py index 8cee5ecbc141..50952304fc1e 100644 --- 
a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py +++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py @@ -207,7 +207,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels ) diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py index fe32ae7db7e4..bd58a65ce787 100644 --- a/examples/community/pipeline_stable_diffusion_boxdiff.py +++ b/examples/community/pipeline_stable_diffusion_boxdiff.py @@ -460,10 +460,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -491,7 +495,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py index 12a40d44aaec..874303e0ad6c 100644 --- a/examples/community/pipeline_stable_diffusion_pag.py +++ b/examples/community/pipeline_stable_diffusion_pag.py @@ -427,10 +427,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -458,7 +462,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py index 1ac651a1fe60..8a709ab46757 100644 --- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py +++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py @@ -151,7 +151,7 @@ def __init__( watermarker=watermarker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear") # self.register_to_config(requires_safety_checker=requires_safety_checker) self.register_to_config(max_noise_level=max_noise_level) diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py index ae495979f366..e55be92962f2 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py @@ -226,12 +226,16 @@ def __init__( scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -359,7 +363,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the 
final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -419,7 +425,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py index 94ca71cf7b1b..8480117866cc 100644 --- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py +++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py @@ -374,12 +374,16 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -507,7 +511,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -567,7 +573,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py index 584820e86254..e74ea263017f 100644 --- a/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py +++ b/examples/community/pipeline_stable_diffusion_xl_differential_img2img.py @@ -258,7 +258,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - 
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -394,7 +394,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -454,7 +456,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/pipeline_stable_diffusion_xl_ipex.py b/examples/community/pipeline_stable_diffusion_xl_ipex.py index 022dfb1abf82..f43726b1b5b8 100644 --- a/examples/community/pipeline_stable_diffusion_xl_ipex.py +++ b/examples/community/pipeline_stable_diffusion_xl_ipex.py @@ -253,10 +253,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -390,7 +394,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -450,7 +456,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py index 0f7fdf627136..9a34f91bf841 100644 --- a/examples/community/pipeline_zero1to3.py +++ b/examples/community/pipeline_zero1to3.py @@ -151,10 +151,14 @@ def __init__( " checker. 
If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -181,7 +185,7 @@ def __init__( feature_extractor=feature_extractor, cc_projection=cc_projection, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) # self.model_mode = None diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index c421acf354c8..7e66bff51d3b 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -352,7 +352,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -632,7 +632,7 @@ def __call__( The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. frames (`List[np.ndarray]` or `torch.Tensor`): The input images to be used as the starting point for the image generation process. - control_frames (`List[np.ndarray]` or `torch.Tensor`): The ControlNet input images condition to provide guidance to the `unet` for generation. + control_frames (`List[np.ndarray]` or `torch.Tensor` or `Callable`): The ControlNet input images condition to provide guidance to the `unet` for generation or any callable object to convert frame to control_frame. strength ('float'): SDEdit strength. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the @@ -789,7 +789,7 @@ def __call__( # Currently we only support single control if isinstance(controlnet, ControlNetModel): control_image = self.prepare_control_image( - image=control_frames[0], + image=control_frames(frames[0]) if callable(control_frames) else control_frames[0], width=width, height=height, batch_size=batch_size, @@ -908,6 +908,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: @@ -924,7 +927,7 @@ def __call__( for idx in range(1, len(frames)): image = frames[idx] prev_image = frames[idx - 1] - control_image = control_frames[idx] + control_image = control_frames(image) if callable(control_frames) else control_frames[idx] # 5.1 prepare frames image = self.image_processor.preprocess(image).to(dtype=self.dtype) prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype) diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py index c7c88d6fdcc7..6aa4067d695d 100644 --- a/examples/community/stable_diffusion_controlnet_img2img.py +++ b/examples/community/stable_diffusion_controlnet_img2img.py @@ -179,7 +179,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index b473ffe79933..2d19e26b4220 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -278,7 +278,7 @@ def __init__( feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py index 8928f34239e3..4363a2294b63 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py +++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py @@ -263,7 +263,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) def _encode_prompt( diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py index ecd38ce345c5..b2d4541797f5 100644 --- a/examples/community/stable_diffusion_ipex.py +++ b/examples/community/stable_diffusion_ipex.py @@ -148,10 +148,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -178,7 +182,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1): diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 1c705f5c768e..9ef95a52051d 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -181,10 +181,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -202,7 +206,7 @@ def __init__( new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 - if unet.config.in_channels != 4: + if unet is not None and unet.config.in_channels != 4: logger.warning( f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`," @@ -219,7 +223,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/stable_diffusion_repaint.py 
b/examples/community/stable_diffusion_repaint.py index a2b221b84969..0bc28eca15cc 100644 --- a/examples/community/stable_diffusion_repaint.py +++ b/examples/community/stable_diffusion_repaint.py @@ -236,10 +236,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -257,7 +261,7 @@ def __init__( new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 - if unet.config.in_channels != 4: + if unet is not None and unet.config.in_channels != 4: logger.warning( f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`," @@ -274,7 +278,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py index 87a9d7cb84ec..ae12cd94f9b0 100755 --- a/examples/community/stable_diffusion_tensorrt_img2img.py +++ b/examples/community/stable_diffusion_tensorrt_img2img.py @@ -753,10 +753,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -806,7 +810,7 @@ def __init__( self.engine = {} # loaded in build_engines() self.vae.forward = self.vae.decode - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py index d6b1331adac1..557aabdacfb8 100755 --- a/examples/community/stable_diffusion_tensorrt_inpaint.py +++ b/examples/community/stable_diffusion_tensorrt_inpaint.py @@ -757,10 +757,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -810,7 +814,7 @@ def __init__( self.engine = {} # loaded in build_engines() self.vae.forward = self.vae.decode - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index b008b3bae944..595c5f5ea830 100755 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -669,10 +669,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -722,7 +726,7 @@ def __init__( self.engine = {} # loaded in build_engines() self.vae.forward = self.vae.decode - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py index aace66f9c18e..d7f882974a22 100644 --- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py +++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py @@ -310,7 +310,7 @@ def __init__( controlnet=controlnet, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py index cb4260d4653f..19c1f30d82da 100644 --- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py +++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py @@ -233,7 +233,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False diff --git a/examples/research_projects/rdm/pipeline_rdm.py b/examples/research_projects/rdm/pipeline_rdm.py index f8093a3f217d..e84568786f50 100644 --- a/examples/research_projects/rdm/pipeline_rdm.py +++ b/examples/research_projects/rdm/pipeline_rdm.py @@ -78,7 +78,7 @@ def __init__( feature_extractor=feature_extractor, ) # Copy from statement here and all the methods we take from stable_diffusion_pipeline - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.retriever = retriever diff --git a/examples/research_projects/realfill/requirements.txt b/examples/research_projects/realfill/requirements.txt index 8fbaf908a2c8..96f504ece1f3 100644 --- a/examples/research_projects/realfill/requirements.txt +++ b/examples/research_projects/realfill/requirements.txt @@ -6,4 +6,4 @@ torch==2.2.0 torchvision>=0.16 ftfy==6.1.1 tensorboard==2.14.0 -Jinja2==3.1.4 +Jinja2==3.1.5 diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py index 163ff8f08931..e883d8ef95a7 100644 --- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py +++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py @@ -765,7 +765,7 @@ def load_model_hook(models, input_dir): lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir) transformer_state_dict = { - f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.") } transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict) incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default") diff --git a/scripts/convert_sana_to_diffusers.py b/scripts/convert_sana_to_diffusers.py index 2f1732817be3..99a9ff322251 100644 --- a/scripts/convert_sana_to_diffusers.py +++ b/scripts/convert_sana_to_diffusers.py @@ -25,6 +25,7 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext ckpt_ids = [ + "Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth", "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth", "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth", @@ -89,7 +90,10 @@ def main(args): converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") # scheduler - flow_shift = 3.0 + if args.image_size == 4096: + flow_shift = 6.0 + else: + flow_shift = 3.0 # model config if args.model_type == "SanaMS_1600M_P1_D20": @@ -99,7 +103,7 @@ def main(args): else: raise ValueError(f"{args.model_type} is not supported.") # Positional embedding interpolation scale. - interpolation_scale = {512: None, 1024: None, 2048: 1.0} + interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0} for depth in range(layer_num): # Transformer blocks. 
@@ -272,9 +276,9 @@ def main(args): "--image_size", default=1024, type=int, - choices=[512, 1024, 2048], + choices=[512, 1024, 2048, 4096], required=False, - help="Image size of pretrained model, 512, 1024 or 2048.", + help="Image size of pretrained model, 512, 1024, 2048 or 4096.", ) parser.add_argument( "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] diff --git a/setup.py b/setup.py index 35ce34920f2a..d696c14ca842 100644 --- a/setup.py +++ b/setup.py @@ -135,6 +135,7 @@ "transformers>=4.41.2", "urllib3<=2.0.0", "black", + "phonemizer", ] # this is a lookup table with items like: @@ -227,6 +228,7 @@ def run(self): "scipy", "torchvision", "transformers", + "phonemizer", ) extras["torch"] = deps_list("torch", "accelerate") diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 9e7bf242eca7..bb5a54f73419 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -43,4 +43,5 @@ "transformers": "transformers>=4.41.2", "urllib3": "urllib3<=2.0.0", "black": "black", + "phonemizer": "phonemizer", } diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py index 286d0a12bc71..0c584777affc 100644 --- a/src/diffusers/loaders/lora_base.py +++ b/src/diffusers/loaders/lora_base.py @@ -28,13 +28,20 @@ from ..utils import ( USE_PEFT_BACKEND, _get_model_file, + convert_state_dict_to_diffusers, + convert_state_dict_to_peft, delete_adapter_layers, deprecate, + get_adapter_name, + get_peft_kwargs, is_accelerate_available, is_peft_available, + is_peft_version, is_transformers_available, + is_transformers_version, logging, recurse_remove_peft_layers, + scale_lora_layers, set_adapter_layers, set_weights_and_activate_adapters, ) @@ -43,6 +50,8 @@ if is_transformers_available(): from transformers import PreTrainedModel + from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules + if is_peft_available(): from peft.tuners.tuners_utils import BaseTunerLayer @@ -297,6 +306,152 @@ def _best_guess_weight_name( return weight_name +def _load_lora_into_text_encoder( + state_dict, + network_alphas, + text_encoder, + prefix=None, + lora_scale=1.0, + text_encoder_name="text_encoder", + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, +): + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + peft_kwargs = {} + if low_cpu_mem_usage: + if not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + if not is_transformers_version(">", "4.45.2"): + # Note from sayakpaul: It's not in `transformers` stable yet. + # https://github.com/huggingface/transformers/pull/33725/ + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." + ) + peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage + + from peft import LoraConfig + + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), + # then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as + # their prefixes. + keys = list(state_dict.keys()) + prefix = text_encoder_name if prefix is None else prefix + + # Safe prefix to check with. 
+ if any(text_encoder_name in key for key in keys): + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix] + text_encoder_lora_state_dict = { + k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys + } + + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {prefix}.") + rank = {} + text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) + + # convert state dict + text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) + + for name, _ in text_encoder_attn_modules(text_encoder): + for module in ("out_proj", "q_proj", "k_proj", "v_proj"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + for name, _ in text_encoder_mlp_modules(text_encoder): + for module in ("fc1", "fc2"): + rank_key = f"{name}.{module}.lora_B.weight" + if rank_key not in text_encoder_lora_state_dict: + continue + rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] + network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) + + if "use_dora" in lora_config_kwargs: + if lora_config_kwargs["use_dora"]: + if is_peft_version("<", "0.9.0"): + raise ValueError( + "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<", "0.9.0"): + lora_config_kwargs.pop("use_dora") + + if "lora_bias" in lora_config_kwargs: + if lora_config_kwargs["lora_bias"]: + if is_peft_version("<=", "0.13.2"): + raise ValueError( + "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." + ) + else: + if is_peft_version("<=", "0.13.2"): + lora_config_kwargs.pop("lora_bias") + + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(text_encoder) + + is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline) + + # inject LoRA layers and load the state dict + # in transformers we automatically check whether the adapter name is already in use or not + text_encoder.load_adapter( + adapter_name=adapter_name, + adapter_state_dict=text_encoder_lora_state_dict, + peft_config=lora_config, + **peft_kwargs, + ) + + # scale LoRA layers with `lora_scale` + scale_lora_layers(text_encoder, weight=lora_scale) + + text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) + + # Offload back. 
+ if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + + +def _func_optionally_disable_offloading(_pipeline): + is_model_cpu_offload = False + is_sequential_cpu_offload = False + + if _pipeline is not None and _pipeline.hf_device_map is None: + for _, component in _pipeline.components.items(): + if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): + if not is_model_cpu_offload: + is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) + if not is_sequential_cpu_offload: + is_sequential_cpu_offload = ( + isinstance(component._hf_hook, AlignDevicesHook) + or hasattr(component._hf_hook, "hooks") + and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) + ) + + logger.info( + "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." + ) + remove_hook_from_module(component, recurse=is_sequential_cpu_offload) + + return (is_model_cpu_offload, is_sequential_cpu_offload) + + class LoraBaseMixin: """Utility class for handling LoRAs.""" @@ -327,27 +482,7 @@ def _optionally_disable_offloading(cls, _pipeline): tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. """ - is_model_cpu_offload = False - is_sequential_cpu_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
- ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - - return (is_model_cpu_offload, is_sequential_cpu_offload) + return _func_optionally_disable_offloading(_pipeline=_pipeline) @classmethod def _fetch_state_dict(cls, *args, **kwargs): diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index c866fd603095..d0d2d36bea9f 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -20,20 +20,21 @@ from ..utils import ( USE_PEFT_BACKEND, - convert_state_dict_to_diffusers, - convert_state_dict_to_peft, deprecate, - get_adapter_name, - get_peft_kwargs, is_peft_available, is_peft_version, is_torch_version, is_transformers_available, is_transformers_version, logging, - scale_lora_layers, ) -from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict # noqa +from .lora_base import ( # noqa + LORA_WEIGHT_NAME, + LORA_WEIGHT_NAME_SAFE, + LoraBaseMixin, + _fetch_state_dict, + _load_lora_into_text_encoder, +) from .lora_conversion_utils import ( _convert_bfl_flux_control_lora_to_diffusers, _convert_hunyuan_video_lora_to_diffusers, @@ -55,9 +56,6 @@ _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True -if is_transformers_available(): - from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules - logger = logging.get_logger(__name__) TEXT_ENCODER_NAME = "text_encoder" @@ -345,113 +343,17 @@ def load_lora_into_text_encoder( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - peft_kwargs = {} - if low_cpu_mem_usage: - if not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - if not is_transformers_version(">", "4.45.2"): - # Note from sayakpaul: It's not in `transformers` stable yet. - # https://github.com/huggingface/transformers/pull/33725/ - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." - ) - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. 
- prefix = cls.text_encoder_name if prefix is None else prefix - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.") - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - **peft_kwargs, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - else: - logger.info( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. 
Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new" - ) + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def save_lora_weights( @@ -873,113 +775,17 @@ def load_lora_into_text_encoder( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - peft_kwargs = {} - if low_cpu_mem_usage: - if not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - if not is_transformers_version(">", "4.45.2"): - # Note from sayakpaul: It's not in `transformers` stable yet. - # https://github.com/huggingface/transformers/pull/33725/ - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." - ) - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. - prefix = cls.text_encoder_name if prefix is None else prefix - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.") - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
- ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - **peft_kwargs, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - else: - logger.info( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new" - ) + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def save_lora_weights( @@ -1366,113 +1172,17 @@ def load_lora_into_text_encoder( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - peft_kwargs = {} - if low_cpu_mem_usage: - if not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - if not is_transformers_version(">", "4.45.2"): - # Note from sayakpaul: It's not in `transformers` stable yet. - # https://github.com/huggingface/transformers/pull/33725/ - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." - ) - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. 
- prefix = cls.text_encoder_name if prefix is None else prefix - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.") - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - **peft_kwargs, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - else: - logger.info( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. 
Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new" - ) + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def save_lora_weights( @@ -1994,113 +1704,17 @@ def load_lora_into_text_encoder( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - peft_kwargs = {} - if low_cpu_mem_usage: - if not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - if not is_transformers_version(">", "4.45.2"): - # Note from sayakpaul: It's not in `transformers` stable yet. - # https://github.com/huggingface/transformers/pull/33725/ - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." - ) - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. - prefix = cls.text_encoder_name if prefix is None else prefix - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.") - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." 
- ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - **peft_kwargs, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - else: - logger.info( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new" - ) + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights with unet->transformer @@ -2159,7 +1773,7 @@ def save_lora_weights( def fuse_lora( self, - components: List[str] = ["transformer", "text_encoder"], + components: List[str] = ["transformer"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, @@ -2550,113 +2164,17 @@ def load_lora_into_text_encoder( Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. """ - if not USE_PEFT_BACKEND: - raise ValueError("PEFT backend is required for this method.") - - peft_kwargs = {} - if low_cpu_mem_usage: - if not is_peft_version(">=", "0.13.1"): - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." - ) - if not is_transformers_version(">", "4.45.2"): - # Note from sayakpaul: It's not in `transformers` stable yet. - # https://github.com/huggingface/transformers/pull/33725/ - raise ValueError( - "`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`." - ) - peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage - - from peft import LoraConfig - - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), - # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as - # their prefixes. 
- prefix = cls.text_encoder_name if prefix is None else prefix - text_encoder_lora_state_dict = { - k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k.startswith(f"{prefix}.") - } - - if len(text_encoder_lora_state_dict) > 0: - logger.info(f"Loading {prefix}.") - rank = {} - text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict) - # convert state dict - text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict) - - for name, _ in text_encoder_attn_modules(text_encoder): - for module in ("out_proj", "q_proj", "k_proj", "v_proj"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - for name, _ in text_encoder_mlp_modules(text_encoder): - for module in ("fc1", "fc2"): - rank_key = f"{name}.{module}.lora_B.weight" - if rank_key not in text_encoder_lora_state_dict: - continue - rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1] - - if network_alphas is not None: - alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix] - network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys} - - lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False) - - if "use_dora" in lora_config_kwargs: - if lora_config_kwargs["use_dora"]: - if is_peft_version("<", "0.9.0"): - raise ValueError( - "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<", "0.9.0"): - lora_config_kwargs.pop("use_dora") - - if "lora_bias" in lora_config_kwargs: - if lora_config_kwargs["lora_bias"]: - if is_peft_version("<=", "0.13.2"): - raise ValueError( - "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`." - ) - else: - if is_peft_version("<=", "0.13.2"): - lora_config_kwargs.pop("lora_bias") - - lora_config = LoraConfig(**lora_config_kwargs) - - # adapter_name - if adapter_name is None: - adapter_name = get_adapter_name(text_encoder) - - is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) - - # inject LoRA layers and load the state dict - # in transformers we automatically check whether the adapter name is already in use or not - text_encoder.load_adapter( - adapter_name=adapter_name, - adapter_state_dict=text_encoder_lora_state_dict, - peft_config=lora_config, - **peft_kwargs, - ) - - # scale LoRA layers with `lora_scale` - scale_lora_layers(text_encoder, weight=lora_scale) - - text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) - - # Offload back. - if is_model_cpu_offload: - _pipeline.enable_model_cpu_offload() - elif is_sequential_cpu_offload: - _pipeline.enable_sequential_cpu_offload() - # Unsafe code /> - - else: - logger.info( - f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. 
Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new" - ) + _load_lora_into_text_encoder( + state_dict=state_dict, + network_alphas=network_alphas, + lora_scale=lora_scale, + text_encoder=text_encoder, + prefix=prefix, + text_encoder_name=cls.text_encoder_name, + adapter_name=adapter_name, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + ) @classmethod def save_lora_weights( @@ -2954,10 +2472,9 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, - components: List[str] = ["transformer", "text_encoder"], + components: List[str] = ["transformer"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, @@ -2998,8 +2515,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). @@ -3013,9 +2529,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -3262,10 +2775,9 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, - components: List[str] = ["transformer", "text_encoder"], + components: List[str] = ["transformer"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, @@ -3306,8 +2818,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). @@ -3321,9 +2832,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. 
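For these transformer-only LoRA mixins, `fuse_lora()` and `unfuse_lora()` now default to `components=["transformer"]`, so the text encoder is left untouched unless explicitly requested. A minimal sketch of the call pattern after this change (the repo ids are placeholders, not taken from this diff):

```py
import torch
from diffusers import DiffusionPipeline

# placeholder ids: any pipeline whose LoRA loader mixin is affected by this change
pipe = DiffusionPipeline.from_pretrained("<base-model>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<lora-repo>", adapter_name="style")

pipe.fuse_lora(lora_scale=1.0)   # now equivalent to components=["transformer"]
# ... run inference with the fused weights ...
pipe.unfuse_lora()               # likewise defaults to ["transformer"]
```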
""" super().unfuse_lora(components=components) @@ -3570,10 +3078,9 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, - components: List[str] = ["transformer", "text_encoder"], + components: List[str] = ["transformer"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, @@ -3614,8 +3121,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). @@ -3629,9 +3135,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -3878,10 +3381,9 @@ def save_lora_weights( safe_serialization=safe_serialization, ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer def fuse_lora( self, - components: List[str] = ["transformer", "text_encoder"], + components: List[str] = ["transformer"], lora_scale: float = 1.0, safe_fusing: bool = False, adapter_names: Optional[List[str]] = None, @@ -3922,8 +3424,7 @@ def fuse_lora( components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names ) - # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer - def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs): + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): r""" Reverses the effect of [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora). @@ -3937,9 +3438,6 @@ def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], * Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) @@ -4246,9 +3744,6 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): Args: components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. - unfuse_text_encoder (`bool`, defaults to `True`): - Whether to unfuse the text encoder LoRA parameters. 
If the text encoder wasn't monkey-patched with the - LoRA parameters then it won't have any effect. """ super().unfuse_lora(components=components) diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index bcd3c57762dd..139d7430b035 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -20,7 +20,6 @@ import safetensors import torch -import torch.nn as nn from ..utils import ( MIN_PEFT_VERSION, @@ -30,20 +29,16 @@ delete_adapter_layers, get_adapter_name, get_peft_kwargs, - is_accelerate_available, is_peft_available, is_peft_version, logging, set_adapter_layers, set_weights_and_activate_adapters, ) -from .lora_base import _fetch_state_dict +from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading from .unet_loader_utils import _maybe_expand_lora_scales -if is_accelerate_available(): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - logger = logging.get_logger(__name__) _SET_ADAPTER_SCALE_FN_MAPPING = { @@ -140,27 +135,7 @@ def _optionally_disable_offloading(cls, _pipeline): tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. """ - is_model_cpu_offload = False - is_sequential_cpu_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." - ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - - return (is_model_cpu_offload, is_sequential_cpu_offload) + return _func_optionally_disable_offloading(_pipeline=_pipeline) def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs): r""" diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index c5c9bea29b8a..007332f73409 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -60,6 +60,7 @@ def load_single_file_sub_model( local_files_only=False, torch_dtype=None, is_legacy_loading=False, + disable_mmap=False, **kwargs, ): if is_pipeline_module: @@ -106,6 +107,7 @@ def load_single_file_sub_model( subfolder=name, torch_dtype=torch_dtype, local_files_only=local_files_only, + disable_mmap=disable_mmap, **kwargs, ) @@ -308,6 +310,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): hosted on the Hub. - A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component configs in Diffusers format. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. 
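The pipeline-level `from_single_file()` also gains the new `disable_mmap` flag and threads it down to the checkpoint loader. A rough usage sketch, with the checkpoint path as a placeholder:

```py
import torch
from diffusers import StableDiffusionPipeline

# placeholder path: any single-file Stable Diffusion checkpoint works here
pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.safetensors",
    torch_dtype=torch.float16,
    disable_mmap=True,  # read the file into memory instead of mmap-ing it
)
```

Skipping mmap mainly helps when the checkpoint sits on a network mount or a slow disk, as the added docstring notes.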
See example @@ -355,6 +360,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): local_files_only = kwargs.pop("local_files_only", False) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) + disable_mmap = kwargs.pop("disable_mmap", False) is_legacy_loading = False @@ -383,6 +389,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs): cache_dir=cache_dir, local_files_only=local_files_only, revision=revision, + disable_mmap=disable_mmap, ) if config is None: @@ -504,6 +511,7 @@ def load_module(name, value): original_config=original_config, local_files_only=local_files_only, is_legacy_loading=is_legacy_loading, + disable_mmap=disable_mmap, **kwargs, ) except SingleFileComponentError as e: diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 79dc2691b9e4..69ab8b6bad20 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -25,6 +25,7 @@ from .single_file_utils import ( SingleFileComponentError, convert_animatediff_checkpoint_to_diffusers, + convert_auraflow_transformer_checkpoint_to_diffusers, convert_autoencoder_dc_checkpoint_to_diffusers, convert_controlnet_checkpoint, convert_flux_transformer_checkpoint_to_diffusers, @@ -106,6 +107,10 @@ "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers, "default_subfolder": "transformer", }, + "AuraFlowTransformer2DModel": { + "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, } @@ -182,6 +187,9 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier allowed by Git. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (for example the pipeline components of the specific pipeline class). 
The overwritten components are directly passed to the pipelines `__init__` @@ -229,6 +237,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = torch_dtype = kwargs.pop("torch_dtype", None) quantization_config = kwargs.pop("quantization_config", None) device = kwargs.pop("device", None) + disable_mmap = kwargs.pop("disable_mmap", False) if isinstance(pretrained_model_link_or_path_or_dict, dict): checkpoint = pretrained_model_link_or_path_or_dict @@ -241,6 +250,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = cache_dir=cache_dir, local_files_only=local_files_only, revision=revision, + disable_mmap=disable_mmap, ) if quantization_config is not None: hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 1fa1bdf259cc..9766098d8584 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -94,6 +94,12 @@ "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", "animatediff_rgb": "controlnet_cond_embedding.weight", + "auraflow": [ + "double_layers.0.attn.w2q.weight", + "double_layers.0.attn.w1q.weight", + "cond_seq_linear.weight", + "t_embedder.mlp.0.weight", + ], "flux": [ "double_blocks.0.img_attn.norm.key_norm.scale", "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", @@ -154,6 +160,7 @@ "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"}, "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, + "auraflow": {"pretrained_model_name_or_path": "fal/AuraFlow-v0.3"}, "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, "flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"}, "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, @@ -179,6 +186,7 @@ "inpainting": 512, "inpainting_v2": 512, "controlnet": 512, + "instruct-pix2pix": 512, "v2": 768, "v1": 512, } @@ -380,6 +388,7 @@ def load_single_file_checkpoint( cache_dir=None, local_files_only=None, revision=None, + disable_mmap=False, ): if os.path.isfile(pretrained_model_link_or_path): pretrained_model_link_or_path = pretrained_model_link_or_path @@ -397,7 +406,7 @@ def load_single_file_checkpoint( revision=revision, ) - checkpoint = load_state_dict(pretrained_model_link_or_path) + checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap) # some checkpoints contain the model state dict under a "state_dict" key while "state_dict" in checkpoint: @@ -635,6 +644,9 @@ def infer_diffusers_model_type(checkpoint): elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint: model_type = "hunyuan-video" + elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["auraflow"]): + model_type = "auraflow" + elif ( CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8 @@ -2090,6 +2102,7 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs): def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} keys = list(checkpoint.keys()) + for k in keys: if "model.diffusion_model." 
in k: checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) @@ -2689,3 +2702,95 @@ def update_state_dict_(state_dict, old_key, new_key): handler_fn_inplace(key, checkpoint) return checkpoint + + +def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + state_dict_keys = list(checkpoint.keys()) + + # Handle register tokens and positional embeddings + converted_state_dict["register_tokens"] = checkpoint.pop("register_tokens", None) + + # Handle time step projection + converted_state_dict["time_step_proj.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight", None) + converted_state_dict["time_step_proj.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias", None) + converted_state_dict["time_step_proj.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight", None) + converted_state_dict["time_step_proj.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias", None) + + # Handle context embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("cond_seq_linear.weight", None) + + # Calculate the number of layers + def calculate_layers(keys, key_prefix): + layers = set() + for k in keys: + if key_prefix in k: + layer_num = int(k.split(".")[1]) # get the layer number + layers.add(layer_num) + return len(layers) + + mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers") + single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers") + + # MMDiT blocks + for i in range(mmdit_layers): + # Feed-forward + path_mapping = {"mlpX": "ff", "mlpC": "ff_context"} + weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for orig_k, diffuser_k in path_mapping.items(): + for k, v in weight_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = checkpoint.pop( + f"double_layers.{i}.{orig_k}.{k}.weight", None + ) + + # Norms + path_mapping = {"modX": "norm1", "modC": "norm1_context"} + for orig_k, diffuser_k in path_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = checkpoint.pop( + f"double_layers.{i}.{orig_k}.1.weight", None + ) + + # Attentions + x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"} + context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"} + for attn_mapping in [x_attn_mapping, context_attn_mapping]: + for k, v in attn_mapping.items(): + converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( + f"double_layers.{i}.attn.{k}.weight", None + ) + + # Single-DiT blocks + for i in range(single_dit_layers): + # Feed-forward + mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"} + for k, v in mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = checkpoint.pop( + f"single_layers.{i}.mlp.{k}.weight", None + ) + + # Norms + converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop( + f"single_layers.{i}.modCX.1.weight", None + ) + + # Attentions + x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"} + for k, v in x_attn_mapping.items(): + converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = checkpoint.pop( + f"single_layers.{i}.attn.{k}.weight", None + ) + # Final blocks + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_linear.weight", None) + + # Handle the final 
norm layer + norm_weight = checkpoint.pop("modF.1.weight", None) + if norm_weight is not None: + converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(norm_weight, dim=None) + else: + converted_state_dict["norm_out.linear.weight"] = None + + converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("positional_encoding") + converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("init_x_linear.weight") + converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias") + + return converted_state_dict diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index d84c52c98440..c68349c36dba 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -21,7 +21,6 @@ import torch import torch.nn.functional as F from huggingface_hub.utils import validate_hf_hub_args -from torch import nn from ..models.embeddings import ( ImageProjection, @@ -44,13 +43,11 @@ is_torch_version, logging, ) +from .lora_base import _func_optionally_disable_offloading from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME from .utils import AttnProcsLayers -if is_accelerate_available(): - from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module - logger = logging.get_logger(__name__) @@ -411,27 +408,7 @@ def _optionally_disable_offloading(cls, _pipeline): tuple: A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True. """ - is_model_cpu_offload = False - is_sequential_cpu_offload = False - - if _pipeline is not None and _pipeline.hf_device_map is None: - for _, component in _pipeline.components.items(): - if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): - if not is_model_cpu_offload: - is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) - if not is_sequential_cpu_offload: - is_sequential_cpu_offload = ( - isinstance(component._hf_hook, AlignDevicesHook) - or hasattr(component._hf_hook, "hooks") - and isinstance(component._hf_hook.hooks[0], AlignDevicesHook) - ) - - logger.info( - "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again." 
- ) - remove_hook_from_module(component, recurse=is_sequential_cpu_offload) - - return (is_model_cpu_offload, is_sequential_cpu_offload) + return _func_optionally_disable_offloading(_pipeline=_pipeline) def save_attn_procs( self, diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py index 109e37c23e1b..1e6a26dddca8 100644 --- a/src/diffusers/models/autoencoders/autoencoder_dc.py +++ b/src/diffusers/models/autoencoders/autoencoder_dc.py @@ -486,6 +486,9 @@ def __init__( self.tile_sample_stride_height = 448 self.tile_sample_stride_width = 448 + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + def enable_tiling( self, tile_sample_min_height: Optional[int] = None, @@ -515,6 +518,8 @@ def enable_tiling( self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio def disable_tiling(self) -> None: r""" @@ -606,11 +611,106 @@ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutp return (decoded,) return DecoderOutput(sample=decoded) + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor: - raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.") + batch_size, num_channels, height, width = x.shape + latent_height = height // self.spatial_compression_ratio + latent_width = width // self.spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + # Split x into overlapping tiles and encode them separately. + # The tiles have an overlap to avoid seams between tiles. 
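The new `blend_v`/`blend_h` helpers cross-fade the overlap region between neighbouring tiles with a linear ramp so tile seams are not visible. A self-contained illustration of that weighting (toy tensors, not part of this diff):

```py
import torch

a = torch.ones(1, 1, 4, 8)    # stand-in for the previous (left) tile
b = torch.zeros(1, 1, 4, 8)   # stand-in for the current tile
blend_extent = 4

# same per-column weighting as blend_h: the weight shifts linearly from a to b
for x in range(blend_extent):
    w = x / blend_extent
    b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, x] * w

print(b[0, 0, 0, :blend_extent])  # tensor([1.0000, 0.7500, 0.5000, 0.2500])
```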
+ rows = [] + for i in range(0, x.shape[2], self.tile_sample_stride_height): + row = [] + for j in range(0, x.shape[3], self.tile_sample_stride_width): + tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + if ( + tile.shape[2] % self.spatial_compression_ratio != 0 + or tile.shape[3] % self.spatial_compression_ratio != 0 + ): + pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio + pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio + tile = F.pad(tile, (0, pad_w, 0, pad_h)) + tile = self.encoder(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width] + + if not return_dict: + return (encoded,) + return EncoderOutput(latent=encoded) def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: - raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.") + batch_size, num_channels, height, width = z.shape + + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + + blend_height = self.tile_sample_min_height - self.tile_sample_stride_height + blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + + # Split z into overlapping tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. 
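With `tiled_encode`/`tiled_decode` implemented, tiling can now actually be used on `AutoencoderDC` for large inputs instead of raising `NotImplementedError`. A hedged usage sketch (the repo id is a placeholder for any DC-AE checkpoint, and the exact dispatch threshold depends on the configured tile sizes):

```py
import torch
from diffusers import AutoencoderDC

vae = AutoencoderDC.from_pretrained("<dc-ae-checkpoint>", torch_dtype=torch.float32)
vae.enable_tiling(tile_sample_min_height=512, tile_sample_min_width=512)

image = torch.randn(1, 3, 1024, 1024)  # larger than one tile, so the tiled paths are used
with torch.no_grad():
    latent = vae.encode(image).latent
    reconstruction = vae.decode(latent).sample
```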
+ rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=3)) + + decoded = torch.cat(result_rows, dim=2) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor: encoded = self.encode(sample, return_dict=False)[0] diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index 5f5ea2351709..a3d006f18994 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -131,7 +131,9 @@ def _fetch_remapped_cls_from_config(config, old_class): return old_class -def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None): +def load_state_dict( + checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False +): """ Reads a checkpoint file, returning properly formatted errors if they arise. """ @@ -142,7 +144,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[ try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: - return safetensors.torch.load_file(checkpoint_file, device="cpu") + if disable_mmap: + return safetensors.torch.load(open(checkpoint_file, "rb").read()) + else: + return safetensors.torch.load_file(checkpoint_file, device="cpu") elif file_extension == GGUF_FILE_EXTENSION: return load_gguf_checkpoint(checkpoint_file) else: diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index d6efcc736487..17e9d2043150 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -559,6 +559,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` weights. If set to `False`, `safetensors` weights are not loaded. + disable_mmap ('bool', *optional*, defaults to 'False'): + Whether to disable mmap when loading a Safetensors model. This option can perform better when the model + is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well. 
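`ModelMixin.from_pretrained()` gains the same `disable_mmap` switch; internally `load_state_dict()` then reads the safetensors file into memory with `safetensors.torch.load()` rather than memory-mapping it via `load_file()`. A short sketch of the user-facing call (the repo id is illustrative):

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # illustrative repo id
    subfolder="unet",
    torch_dtype=torch.float16,
    disable_mmap=True,  # avoid mmap on network mounts / slow drives
)
```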
@@ -604,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) quantization_config = kwargs.pop("quantization_config", None) + disable_mmap = kwargs.pop("disable_mmap", False) allow_pickle = False if use_safetensors is None: @@ -883,7 +887,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # TODO (sayakpaul, SunMarc): remove this after model loading refactor else: param_device = torch.device(torch.cuda.current_device()) - state_dict = load_state_dict(model_file, variant=variant) + state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu @@ -920,14 +924,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU - force_hook = True device_map = _determine_device_map( model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer ) if device_map is None and is_sharded: # we load the parameters on the cpu device_map = {"": "cpu"} - force_hook = False try: accelerate.load_checkpoint_and_dispatch( model, @@ -937,7 +939,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, - force_hooks=force_hook, strict=True, ) except AttributeError as e: @@ -967,7 +968,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_folder=offload_folder, offload_state_dict=offload_state_dict, dtype=torch_dtype, - force_hooks=force_hook, strict=True, ) model._undo_temp_convert_self_to_deprecated_attention_blocks() @@ -983,7 +983,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant) + state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( @@ -1214,7 +1214,7 @@ def _get_signature_keys(cls, obj): # Adapted from `transformers` modeling_utils.py def _get_no_split_modules(self, device_map: str): """ - Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + Get the modules of the model that should not be split when using device_map. We iterate through the modules to get the underlying `_no_split_modules`. 
Args: diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py index b3f29e6b6224..b35488a89282 100644 --- a/src/diffusers/models/transformers/auraflow_transformer_2d.py +++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py @@ -20,6 +20,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin from ...utils import is_torch_version, logging from ...utils.torch_utils import maybe_allow_in_graph from ..attention_processor import ( @@ -253,7 +254,7 @@ def forward( return encoder_hidden_states, hidden_states -class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin): +class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/). diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py index b47d439774cc..51634780692d 100644 --- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py +++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py @@ -120,8 +120,10 @@ def forward( encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: text_seq_length = encoder_hidden_states.size(1) + attention_kwargs = attention_kwargs or {} # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( @@ -133,6 +135,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + **attention_kwargs, ) hidden_states = hidden_states + gate_msa * attn_hidden_states @@ -210,6 +213,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"] @register_to_config def __init__( @@ -497,6 +501,7 @@ def custom_forward(*inputs): encoder_hidden_states, emb, image_rotary_emb, + attention_kwargs, **ckpt_kwargs, ) else: @@ -505,6 +510,7 @@ def custom_forward(*inputs): encoder_hidden_states=encoder_hidden_states, temb=emb, image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, ) if not self.config.use_rotary_positional_embeddings: diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py index fe9c7290b063..81039fd49e0d 100644 --- a/src/diffusers/models/transformers/transformer_allegro.py +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -221,6 +221,8 @@ class AllegroTransformer3DModel(ModelMixin, ConfigMixin): Scaling factor to apply in 3D positional embeddings across time dimension. 
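Declaring `_no_split_modules` on the CogVideoX transformer (and, further below, on CogView3Plus and HunyuanVideo) tells accelerate which blocks must stay on a single device when a `device_map` is used. A hedged sketch of what that enables, assuming your installed diffusers/accelerate versions accept a string `device_map` for models:

```py
import torch
from diffusers import CogVideoXTransformer3DModel

# illustrative repo id; with _no_split_modules set, accelerate can shard the
# transformer across GPUs without splitting inside a CogVideoXBlock
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
```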
""" + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py index 94d852f6df4b..369509a3a35e 100644 --- a/src/diffusers/models/transformers/transformer_cogview3plus.py +++ b/src/diffusers/models/transformers/transformer_cogview3plus.py @@ -166,6 +166,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["CogView3PlusTransformerBlock", "CogView3PlusPatchEmbed"] @register_to_config def __init__( diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py index 6cb97af93652..044f2048775f 100644 --- a/src/diffusers/models/transformers/transformer_hunyuan_video.py +++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py @@ -542,6 +542,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, """ _supports_gradient_checkpointing = True + _no_split_modules = [ + "HunyuanVideoTransformerBlock", + "HunyuanVideoSingleTransformerBlock", + "HunyuanVideoPatchEmbed", + "HunyuanVideoTokenRefiner", + ] @register_to_config def __init__( @@ -713,15 +719,15 @@ def forward( condition_sequence_length = encoder_hidden_states.shape[1] sequence_length = latent_sequence_length + condition_sequence_length attention_mask = torch.zeros( - batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool - ) # [B, N, N] + batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool + ) # [B, N] effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,] effective_sequence_length = latent_sequence_length + effective_condition_sequence_length for i in range(batch_size): - attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True - attention_mask = attention_mask.unsqueeze(1) # [B, 1, N, N], for broadcasting across attention heads + attention_mask[i, : effective_sequence_length[i]] = True + attention_mask = attention_mask.unsqueeze(1) # [B, 1, N], for broadcasting across attention heads # 4. Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index bec62ce5cf45..090357237f46 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block types. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): - Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`. up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): @@ -103,6 +103,7 @@ def __init__( freq_shift: int = 0, flip_sin_to_cos: bool = True, down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"), + mid_block_type: Optional[str] = "UNetMidBlock2D", up_block_types: Tuple[str, ...] 
= ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"), block_out_channels: Tuple[int, ...] = (224, 448, 672, 896), layers_per_block: int = 2, @@ -194,19 +195,22 @@ def __init__( self.down_blocks.append(down_block) # mid - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], - resnet_groups=norm_num_groups, - attn_groups=attn_norm_num_groups, - add_attention=add_attention, - ) + if mid_block_type is None: + self.mid_block = None + else: + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1], + resnet_groups=norm_num_groups, + attn_groups=attn_norm_num_groups, + add_attention=add_attention, + ) # up reversed_block_out_channels = list(reversed(block_out_channels)) @@ -322,7 +326,8 @@ def forward( down_block_res_samples += res_samples # 4. mid - sample = self.mid_block(sample, emb) + if self.mid_block is not None: + sample = self.mid_block(sample, emb) # 5. up skip_sample = None diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py index b3650dc6cee1..91aedf2cdbe6 100644 --- a/src/diffusers/pipelines/allegro/pipeline_allegro.py +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -33,6 +33,7 @@ deprecate, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -41,6 +42,14 @@ from .pipeline_output import AllegroPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + logger = logging.get_logger(__name__) if is_bs4_available(): @@ -194,10 +203,10 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 ) self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 ) self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -921,6 +930,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": latents = latents.to(self.vae.dtype) video = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py index a8c24b0aeecc..12f7dc7c59d4 100644 --- a/src/diffusers/pipelines/amused/pipeline_amused.py +++ b/src/diffusers/pipelines/amused/pipeline_amused.py @@ -20,10 +20,18 @@ from ...image_processor import VaeImageProcessor from ...models import UVit2DModel, 
VQModel from ...schedulers import AmusedScheduler -from ...utils import replace_example_docstring +from ...utils import is_torch_xla_available, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -66,7 +74,9 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) @torch.no_grad() @@ -297,6 +307,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, timestep, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": output = latents else: diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py index c74275b414d4..7ac05b39c3a8 100644 --- a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py +++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py @@ -20,10 +20,18 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...models import UVit2DModel, VQModel from ...schedulers import AmusedScheduler -from ...utils import replace_example_docstring +from ...utils import is_torch_xla_available, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -81,7 +89,9 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if getattr(self, "vqvae", None) else 8 + ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) @torch.no_grad() @@ -323,6 +333,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, timestep, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": output = latents else: diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py index 24801e0ef977..d908c32745c2 100644 --- a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py +++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py @@ -21,10 +21,18 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...models import UVit2DModel, VQModel from ...schedulers import AmusedScheduler -from ...utils import replace_example_docstring +from ...utils import is_torch_xla_available, replace_example_docstring from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -89,7 +97,9 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vqvae.config.block_out_channels) - 1) if 
getattr(self, "vqvae", None) else 8 + ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, @@ -354,6 +364,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, timestep, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": output = latents else: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index cb6f50f43c4f..5c1d1e2ae0ba 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -34,6 +34,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -47,8 +48,16 @@ from .pipeline_output import AnimateDiffPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -139,7 +148,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt @@ -844,6 +853,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 9. 
Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py index 626e46acbf7f..90c66e9e1973 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py @@ -32,7 +32,7 @@ from ...models.lora import adjust_lora_scale_text_encoder from ...models.unets.unet_motion_model import MotionAdapter from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin @@ -41,8 +41,16 @@ from .pipeline_output import AnimateDiffPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -180,7 +188,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.control_video_processor = VideoProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -1090,6 +1098,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 9. 
Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 6016917537b9..958eb5fb5134 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -48,6 +48,7 @@ ) from ...utils import ( USE_PEFT_BACKEND, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -60,8 +61,16 @@ from .pipeline_output import AnimateDiffPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -307,10 +316,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt def encode_prompt( @@ -438,7 +451,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -497,8 +512,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1261,6 +1278,9 @@ def __call__( progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py index 6dde7d6686ee..42e0c6632632 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py @@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -42,8 +43,16 @@ from .pipeline_output import AnimateDiffPipelineOutput +if is_torch_xla_available(): + import 
torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -188,7 +197,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -994,6 +1003,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 11. Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index b0adbea77445..edac6bfd9e4e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -31,7 +31,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin @@ -40,8 +40,16 @@ from .pipeline_output import AnimateDiffPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -243,7 +251,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) def encode_prompt( @@ -1037,6 +1045,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 10. 
Post-processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py index 10a27af246f7..1a75d658b3ad 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py @@ -39,7 +39,7 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin @@ -48,8 +48,16 @@ from .pipeline_output import AnimateDiffPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -270,7 +278,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.control_video_processor = VideoProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -1325,6 +1333,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 11. Post-processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py index 105ca40f773f..14c6d44fc586 100644 --- a/src/diffusers/pipelines/audioldm/pipeline_audioldm.py +++ b/src/diffusers/pipelines/audioldm/pipeline_audioldm.py @@ -22,13 +22,21 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline, StableDiffusionMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -94,7 +102,7 @@ def __init__( scheduler=scheduler, vocoder=vocoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def _encode_prompt( self, @@ -530,6 +538,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 8. 
Post-processing mel_spectrogram = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py index b45771d7de74..b8b5d07af529 100644 --- a/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py +++ b/src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py @@ -48,8 +48,20 @@ if is_librosa_available(): import librosa + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -207,7 +219,7 @@ def __init__( scheduler=scheduler, vocoder=vocoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing def enable_vae_slicing(self): @@ -225,7 +237,7 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() - def enable_model_cpu_offload(self, gpu_id=0): + def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` @@ -237,11 +249,23 @@ def enable_model_cpu_offload(self, gpu_id=0): else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - device = torch.device(f"cuda:{gpu_id}") + torch_device = torch.device(device) + device_index = torch_device.index + + if gpu_id is not None and device_index is not None: + raise ValueError( + f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}" + f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}" + ) + + device_type = torch_device.type + device = torch.device(f"{device_type}:{gpu_id or torch_device.index}") if self.device.type != "cpu": self.to("cpu", silence_dtype_warnings=True) - torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + device_mod = getattr(torch, device.type, None) + if hasattr(device_mod, "empty_cache") and device_mod.is_available(): + device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist) model_sequence = [ self.text_encoder.text_model, @@ -1033,6 +1057,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + self.maybe_free_model_hooks() # 8. 
Post-processing diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py index 0bb3fb7368d8..d3326c54973f 100644 --- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py +++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py @@ -146,9 +146,7 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def check_inputs( diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py index ff23247b5f81..cbd8bef67945 100644 --- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py @@ -20,6 +20,7 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -30,8 +31,16 @@ from .modeling_ctx_clip import ContextCLIPTextModel +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -336,6 +345,9 @@ def __call__( latents, )["prev_sample"] + if XLA_AVAILABLE: + xm.mark_step() + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index a1555402ccf6..d78d5508dc7f 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -26,12 +26,19 @@ from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from .pipeline_output import CogVideoXPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -183,14 +190,12 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 ) self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.vae_scaling_factor_image = ( - self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + 
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -755,6 +760,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py index e4c6ca1206fe..46e7b9ee468e 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py @@ -27,12 +27,19 @@ from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import KarrasDiffusionSchedulers -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from .pipeline_output import CogVideoXPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -190,14 +197,12 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 ) self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.vae_scaling_factor_image = ( - self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -810,6 +815,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py index 6842123ff798..58793902345a 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py @@ -29,6 +29,7 @@ from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -37,6 +38,13 @@ from .pipeline_output import CogVideoXPipelineOutput +if is_torch_xla_available(): + import 
torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -203,14 +211,12 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 ) self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.vae_scaling_factor_image = ( - self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -868,6 +874,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # Discard any padding frames that were added for CogVideoX 1.5 latents = latents[:, additional_frames:] diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py index 945f7694caae..333e3418dca2 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py @@ -27,12 +27,19 @@ from ...models.embeddings import get_3d_rotary_pos_embed from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from .pipeline_output import CogVideoXPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -206,14 +213,12 @@ def __init__( ) self.vae_scale_factor_spatial = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 ) self.vae_scale_factor_temporal = ( - self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.vae_scaling_factor_image = ( - self.vae.config.scaling_factor if hasattr(self, "vae") and self.vae is not None else 0.7 + self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4 ) + self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) @@ -836,6 +841,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": video = self.decode_latents(latents) video = self.video_processor.postprocess_video(video=video, output_type=output_type) diff --git 
a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py index 8bed88c275cf..0cd3943fbcd2 100644 --- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py +++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py @@ -24,11 +24,18 @@ from ...models import AutoencoderKL, CogView3PlusTransformer2DModel from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from .pipeline_output import CogView3PipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -153,9 +160,7 @@ def __init__( self.register_modules( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) @@ -656,6 +661,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 diff --git a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py index d2f67a698917..f0c71655e628 100644 --- a/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py +++ b/src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py @@ -19,6 +19,7 @@ from ...models import UNet2DModel from ...schedulers import CMStochasticIterativeScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -26,6 +27,13 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -263,6 +271,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, sample) + if XLA_AVAILABLE: + xm.mark_step() + # 6. 
Post-process image sample image = self.postprocess_image(sample, output_type=output_type) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 1ae4c8d492e5..214835062a05 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -254,7 +254,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py index 86e0ddef663e..88c387d48dd2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py @@ -21,6 +21,7 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -31,8 +32,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -401,6 +410,10 @@ def __call__( t, latents, )["prev_sample"] + + if XLA_AVAILABLE: + xm.mark_step() + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index fbc9844e29a7..73ffeeb5e79c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -30,6 +30,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -41,6 +42,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -224,7 +232,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -1294,6 +1302,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # If we do 
sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 1f3ac038581e..875dbed38c4d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -32,6 +32,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -43,6 +44,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -223,7 +231,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -1476,6 +1484,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 4ec78c5b990f..38e63f56b2f3 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -60,6 +60,16 @@ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -264,7 +274,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -406,7 +416,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = 
prompt_embeds.hidden_states[-2] else: @@ -465,8 +477,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1829,6 +1843,9 @@ def denoising_value_valid(dnv): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 536c00ee361c..77d496cf831d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -62,6 +62,16 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -275,7 +285,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -415,7 +425,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -474,8 +486,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1548,6 +1562,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 0c4b250af6e6..86588a5b3851 100644 --- 
a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -62,6 +62,16 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -267,7 +277,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -408,7 +418,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -467,8 +479,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1608,6 +1622,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index 7012f3b95458..56f6c9149c6e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -60,6 +60,16 @@ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -246,7 +256,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( 
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -388,7 +398,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -447,8 +459,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1755,6 +1769,9 @@ def denoising_value_valid(dnv): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # make sure the VAE is in float32 mode, as it overflows in float16 if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index dcd885f7d604..a2e50d4f3e09 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -60,6 +60,17 @@ if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -257,7 +268,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -397,7 +408,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -456,8 +469,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] 
negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1454,6 +1469,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index 95cf067fce12..d4409c54b01c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -61,6 +61,17 @@ if is_invisible_watermark_available(): from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -281,7 +292,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -422,7 +433,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -481,8 +494,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1573,6 +1588,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py index 075df628d4f1..3d4b19ea552c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -178,7 +178,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** 
(len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def prepare_text_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py index c8464f8108ea..f01c8cc4674d 100644 --- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py @@ -269,9 +269,7 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.default_sample_size = ( diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py index 4e135f9391dd..7f85fcc1d90d 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py @@ -236,9 +236,7 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 @@ -406,9 +404,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
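The hunks above and below repeat two patterns across nearly every pipeline touched by this diff: an optional `torch_xla` import that sets an `XLA_AVAILABLE` flag and calls `xm.mark_step()` once per denoising step, and a `getattr(self, "vae", None)` fallback so `vae_scale_factor` defaults to 8 when a pipeline is constructed without a VAE. The snippet below is a minimal sketch of those two patterns in isolation, not diffusers source; `ToyPipeline`, its arguments, and the `denoise` loop are illustrative stand-ins.

```py
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


class ToyPipeline:
    # Hypothetical class for illustration only, not a diffusers pipeline.
    def __init__(self, vae=None, scheduler=None):
        self.vae = vae
        self.scheduler = scheduler
        # Fall back to the Stable Diffusion default of 8 when no VAE is registered,
        # mirroring the `getattr(self, "vae", None)` guard applied in these hunks.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )

    def denoise(self, latents, timesteps):
        for t in timesteps:
            # ... predict noise and take a scheduler step on `latents` here ...
            if XLA_AVAILABLE:
                # Flush the lazily built XLA graph once per denoising step.
                xm.mark_step()
        return latents
```

Because XLA executes lazily, calling `xm.mark_step()` at the end of each iteration cuts the accumulated graph so it is compiled and dispatched per step instead of growing across the whole sampling loop; on non-XLA backends the flag stays `False` and the call is skipped, so the loop behaves exactly as before.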
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py index 5d5249922f8d..abefb844a8cc 100644 --- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py @@ -230,9 +230,7 @@ def __init__( scheduler=scheduler, controlnet=controlnet, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_resize=True, do_convert_rgb=True, do_normalize=True ) @@ -412,9 +410,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py index ca10e65de8a4..901ca25c576c 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py @@ -30,6 +30,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -41,6 +42,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -178,7 +186,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -884,6 +892,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py 
b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py index 326cfdab7be7..acf1f5489ec1 100644 --- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py @@ -54,6 +54,16 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -196,7 +206,7 @@ def __init__( scheduler=scheduler, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -336,7 +346,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -395,8 +407,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1074,6 +1088,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # manually for max memory savings if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: self.upcast_vae() diff --git a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py index bcd36c412b54..ed342f66804a 100644 --- a/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +++ b/src/diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py @@ -17,11 +17,18 @@ import torch -from ...utils import logging +from ...utils import is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -146,6 +153,9 @@ def __call__( # 2. 
compute previous audio sample: x_t -> t_t-1 audio = self.scheduler.step(model_output, t, audio).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + audio = audio.clamp(-1, 1).float().cpu().numpy() audio = audio[:, :, :original_sample_size] diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index a3b967ed369b..1b424f5742f2 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -17,10 +17,19 @@ import torch from ...schedulers import DDIMScheduler +from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + class DDIMPipeline(DiffusionPipeline): r""" Pipeline for image generation. @@ -143,6 +152,9 @@ def __call__( model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index bb03a8d66758..e58a53b5b7e8 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -17,10 +17,19 @@ import torch +from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + class DDPMPipeline(DiffusionPipeline): r""" Pipeline for image generation. @@ -116,6 +125,9 @@ def __call__( # 2. 
compute previous image: x_t -> x_t-1 image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py index f545b24bec5c..150978de6e5e 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -14,6 +14,7 @@ BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -24,8 +25,16 @@ from .watermark import IFWatermarker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -735,6 +744,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) + if XLA_AVAILABLE: + xm.mark_step() + image = intermediate_images if output_type == "pil": diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index 07017912575d..a92d7be6a11c 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -17,6 +17,7 @@ PIL_INTERPOLATION, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -27,8 +28,16 @@ from .watermark import IFWatermarker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -856,6 +865,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) + if XLA_AVAILABLE: + xm.mark_step() + image = intermediate_images if output_type == "pil": diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index 6685ba6d774a..b23ea39bb292 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -35,6 +35,16 @@ import ftfy +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -174,7 +184,7 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - if unet.config.in_channels != 6: + if unet is not None and unet.config.in_channels != 6: logger.warning( "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." 
) @@ -974,6 +984,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) + if XLA_AVAILABLE: + xm.mark_step() + image = intermediate_images if output_type == "pil": diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index 7fca0bc0443c..030821b789aa 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -17,6 +17,7 @@ PIL_INTERPOLATION, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -27,8 +28,16 @@ from .watermark import IFWatermarker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -975,6 +984,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) + if XLA_AVAILABLE: + xm.mark_step() + image = intermediate_images if output_type == "pil": diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py index 4f04a1de2a6e..bdad9c29b18f 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -35,6 +35,16 @@ import ftfy +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -176,7 +186,7 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - if unet.config.in_channels != 6: + if unet is not None and unet.config.in_channels != 6: logger.warning( "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." ) @@ -1085,6 +1095,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) + if XLA_AVAILABLE: + xm.mark_step() + image = intermediate_images if output_type == "pil": diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py index 891963f2a904..012c4ca6d448 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -34,6 +34,16 @@ import ftfy +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -132,7 +142,7 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - if unet.config.in_channels != 6: + if unet is not None and unet.config.in_channels != 6: logger.warning( "It seems like you have loaded a checkpoint that shall not be used for super resolution from {unet.config._name_or_path} as it accepts {unet.config.in_channels} input channels instead of 6. Please make sure to pass a super resolution checkpoint as the `'unet'`: IFSuperResolutionPipeline.from_pretrained(unet=super_resolution_unet, ...)`." ) @@ -831,6 +841,9 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, intermediate_images) + if XLA_AVAILABLE: + xm.mark_step() + image = intermediate_images if output_type == "pil": diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py index cfd251a72b35..48c0aa4f6d76 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py @@ -253,10 +253,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -284,7 +288,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py index 612e5d57dff2..fa70689d790d 100644 --- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -281,10 +281,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -312,7 +316,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py index 340abcf69c5e..1752540e8f79 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py @@ -213,10 +213,14 @@ def __init__( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -243,7 +247,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py index 5b77920a0c75..f9c9c37c4867 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -183,10 +183,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -213,7 +217,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py index 9e91986896bd..06db871daf62 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py @@ -121,7 +121,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py index be21900ab55a..d486a32f6a4c 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py @@ -143,7 +143,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py index 2978972200c7..509f25620950 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +++ 
b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py @@ -365,7 +365,7 @@ def __init__( caption_generator=caption_generator, inverse_scheduler=inverse_scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py index c8dc18e2e8ac..4fb437958abd 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py @@ -76,7 +76,7 @@ def __init__( vae=vae, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 @torch.no_grad() def image_variation( diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 2212651fbb5b..0065279bc0b1 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -94,7 +94,7 @@ def __init__( vae=vae, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None and ( diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 62d3e83a4790..7dfc7e961825 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -77,7 +77,7 @@ def __init__( vae=vae, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def _encode_prompt(self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt): diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index de4c2ac9b7f4..1d6771793f39 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -82,7 +82,7 @@ def __init__( vae=vae, 
scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None: diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py index 14321b5f33cf..cf5ebbce2ba8 100644 --- a/src/diffusers/pipelines/dit/pipeline_dit.py +++ b/src/diffusers/pipelines/dit/pipeline_dit.py @@ -24,10 +24,19 @@ from ...models import AutoencoderKL, DiTTransformer2DModel from ...schedulers import KarrasDiffusionSchedulers +from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + class DiTPipeline(DiffusionPipeline): r""" Pipeline for image generation based on a Transformer backbone instead of a UNet. @@ -211,6 +220,9 @@ def __call__( # compute previous image: x_t -> x_t-1 latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + if guidance_scale > 1: latents, _ = latent_model_input.chunk(2, dim=0) else: diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 181f0269ce3e..f5716dc9c8ea 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -206,9 +206,7 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) @@ -667,7 +665,16 @@ def __call__( instead. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - will be used instead + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): @@ -711,6 +718,14 @@ def __call__( Pre-generated image embeddings for IP-Adapter. 
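The `negative_prompt`, `negative_prompt_2`, and `true_cfg_scale` arguments documented in the hunk above enable true classifier-free guidance in `FluxPipeline`. A minimal usage sketch, assuming the gated `black-forest-labs/FLUX.1-dev` checkpoint and a CUDA device are available; the prompt strings and numeric values are illustrative only:

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# true_cfg_scale > 1.0 together with a negative prompt runs an extra negative
# forward pass per step, i.e. true classifier-free guidance.
image = pipe(
    prompt="a cat holding a sign that says hello world",
    negative_prompt="blurry, low quality",
    true_cfg_scale=4.0,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_true_cfg.png")
```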
It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -775,7 +790,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, @@ -824,10 +842,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py index ac8474becb78..8aece8527556 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py @@ -82,6 +82,7 @@ """ +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -212,12 +213,8 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - self.vae_latent_channels = ( - self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + self.vae_latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
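The switch from attribute access to `self.scheduler.config.get(...)` in the hunks above makes the `mu` computation tolerant of scheduler configs that predate the dynamic-shifting keys. A rough standalone sketch of how the fallback values feed the shift calculation; the function mirrors the `calculate_shift` helper already defined in these pipelines and the plain dict stands in for a scheduler config:

```py
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
) -> float:
    # Linearly interpolate the timestep shift between base_shift and max_shift
    # according to the length of the packed latent sequence.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b


# Scheduler configs behave like dicts, so .get() supplies defaults when a
# checkpoint's scheduler config omits some of these keys.
config = {"base_shift": 0.5}  # e.g. an older config missing most keys
mu = calculate_shift(
    image_seq_len=1024,
    base_seq_len=config.get("base_image_seq_len", 256),
    max_seq_len=config.get("max_image_seq_len", 4096),
    base_shift=config.get("base_shift", 0.5),
    max_shift=config.get("max_shift", 1.16),
)
print(mu)
```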
So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor( @@ -802,10 +799,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py index 7001b19569f2..c386f41c8827 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py @@ -227,9 +227,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) @@ -809,10 +807,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py index a9ac1c72c6ed..192b690f69e5 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py @@ -258,15 +258,14 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
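These constructors converge on the same defensive idiom: `getattr(self, "vae", None)` covers both a missing attribute and an explicitly `None` component in one expression, replacing the longer `hasattr(self, "vae") and self.vae is not None` check; the same reasoning motivates the `unet is not None` guards added to the DeepFloyd IF constructors earlier in this patch. A small illustrative sketch (the `FakePipeline` class is made up for demonstration, not part of the library):

```py
class FakePipeline:
    """Toy stand-in for a pipeline that may be assembled without a VAE."""

    def __init__(self, vae=None):
        self.vae = vae
        # Falls back to the conventional factor of 8 when no VAE is registered.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )


print(FakePipeline().vae_scale_factor)  # prints 8, since no VAE is attached
```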
So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, @@ -985,10 +984,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 4c2d2a0a3db9..30e244bae000 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -229,9 +229,7 @@ def __init__( scheduler=scheduler, controlnet=controlnet, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) @@ -876,10 +874,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 4c82d73f0379..d8aefc3942e9 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -227,9 +227,7 @@ def __init__( scheduler=scheduler, controlnet=controlnet, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) @@ -864,10 +862,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 85943b278dc6..bfc96eeb8dab 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -230,15 +230,14 @@ def __init__( controlnet=controlnet, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, @@ -1017,10 +1016,10 @@ def __call__( ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 723478ce724d..ed8623e31733 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -221,15 +221,14 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, @@ -882,10 +881,10 @@ def __call__( image_seq_len = latents.shape[1] mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 2b336fbdd472..a63ecdadbd0c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -211,9 +211,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) @@ -746,10 +744,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 15abdb90ebd0..2be8e75973ef 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -208,15 +208,14 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. 
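The inpainting-style Flux pipelines apply the same treatment to the mask processor: the VAE's `latent_channels` is read only when a VAE is present, otherwise the Flux default of 16 is used. A hedged sketch of the resulting construction, assuming `VaeImageProcessor` is imported from `diffusers.image_processor` and using a `None` VAE to show the fallback path:

```py
from diffusers.image_processor import VaeImageProcessor

vae = None  # e.g. the pipeline was assembled without a VAE
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) if vae else 8
latent_channels = vae.config.latent_channels if vae else 16

mask_processor = VaeImageProcessor(
    vae_scale_factor=vae_scale_factor * 2,  # Flux packs latents into 2x2 patches
    vae_latent_channels=latent_channels,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)
```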
So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, @@ -877,10 +876,10 @@ def __call__( image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py index 3b0956a32da3..5c3d6ce611cc 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py @@ -23,15 +23,23 @@ from ...loaders import HunyuanVideoLoraLoaderMixin from ...models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline from .pipeline_output import HunyuanVideoPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -184,12 +192,8 @@ def __init__( tokenizer_2=tokenizer_2, ) - self.vae_scale_factor_temporal = ( - self.vae.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 - ) - self.vae_scale_factor_spatial = ( - self.vae.spatial_compression_ratio if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _get_llama_prompt_embeds( @@ -671,6 +675,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor video = self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index 6f542cb59f46..6a5cf298d2d4 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -240,9 +240,7 @@ def 
__init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.default_sample_size = ( diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index f528b60e6ed7..58d65a190d5b 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -27,6 +27,7 @@ from ...schedulers import DDIMScheduler from ...utils import ( BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -35,8 +36,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -133,7 +142,7 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 # `do_resize=False` as we do custom resizing. self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) @@ -711,6 +720,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 8. 
Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py index b2041e101564..b5f4acf5c05a 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -22,6 +22,7 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDIMScheduler, DDPMScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -30,8 +31,16 @@ from .text_encoder import MultilingualCLIP +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -385,6 +394,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py index ef5241fee5d2..5d56efef9287 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -25,6 +25,7 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDIMScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -33,8 +34,16 @@ from .text_encoder import MultilingualCLIP +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -478,6 +487,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 7. 
post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py index 778b6e314c0d..cce5f0b3d5bc 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -29,6 +29,7 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDIMScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -37,8 +38,16 @@ from .text_encoder import MultilingualCLIP +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -613,6 +622,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py index b5152d71cb6b..a348deef8b29 100644 --- a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -24,6 +24,7 @@ from ...schedulers import UnCLIPScheduler from ...utils import ( BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -31,8 +32,16 @@ from ..pipeline_utils import DiffusionPipeline +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -519,6 +528,9 @@ def __call__( prev_timestep=prev_timestep, ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + latents = self.prior.post_process_latents(latents) image_embeddings = latents diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py index 471db61556f5..a584674540d8 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py @@ -18,13 +18,21 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler -from ...utils import deprecate, logging, replace_example_docstring +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -296,6 +304,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py index 
0130c3951b38..bada59080c7b 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py @@ -19,14 +19,23 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import ( + is_torch_xla_available, logging, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -297,6 +306,10 @@ def __call__( if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py index 12be1534c642..4f6c4188bd48 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py @@ -22,14 +22,23 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler from ...utils import ( + is_torch_xla_available, logging, ) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -358,6 +367,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # post-processing image = self.movq.decode(latents, force_not_quantize=True)["sample"] diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py index 899273a1a736..624748896911 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py @@ -21,13 +21,21 @@ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -372,6 +380,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}" diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py 
b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py index b5ba7a0011a1..482093a4bb29 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py @@ -25,13 +25,21 @@ from ... import __version__ from ...models import UNet2DConditionModel, VQModel from ...schedulers import DDPMScheduler -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -526,6 +534,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # post-processing latents = mask_image[:1] * image[:1] + (1 - mask_image[:1]) * latents diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py index f2134b22b40b..d05a7fbdb1b8 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py @@ -7,6 +7,7 @@ from ...models import PriorTransformer from ...schedulers import UnCLIPScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -15,8 +16,16 @@ from ..pipeline_utils import DiffusionPipeline +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -524,6 +533,9 @@ def __call__( ) text_mask = callback_outputs.pop("text_mask", text_mask) + if XLA_AVAILABLE: + xm.mark_step() + latents = self.prior.post_process_latents(latents) image_embeddings = latents diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py index ec6509bb3cb5..56d326e26e6e 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py @@ -7,6 +7,7 @@ from ...models import PriorTransformer from ...schedulers import UnCLIPScheduler from ...utils import ( + is_torch_xla_available, logging, replace_example_docstring, ) @@ -15,8 +16,16 @@ from ..pipeline_utils import DiffusionPipeline +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -538,6 +547,9 @@ def __call__( prev_timestep=prev_timestep, ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + latents = self.prior.post_process_latents(latents) image_embeddings = latents diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py index 8dbae2a1909a..5309f94a53c8 100644 --- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py @@ -8,6 +8,7 @@ from ...schedulers import 
DDPMScheduler from ...utils import ( deprecate, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -15,8 +16,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -549,6 +558,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # post-processing if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py index 81c45c4fb6f8..fbdad79db445 100644 --- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py @@ -12,6 +12,7 @@ from ...schedulers import DDPMScheduler from ...utils import ( deprecate, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -19,8 +20,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -617,6 +626,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # post-processing if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py index 1d2d07572d68..99a8bf4e4ce9 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -188,12 +188,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) def encode_prompt( self, diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index 6ddda7acf2a8..df94ec3f0f24 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -207,12 +207,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 
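The recurring change across these pipelines is the optional torch_xla hook: when `torch_xla` is importable, `xm.mark_step()` is called once per denoising iteration so the lazily built XLA graph is cut and dispatched at step granularity. A condensed sketch of the guard and loop shape; the loop body is a placeholder, not any particular pipeline's code:

```py
from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def denoising_loop(timesteps, step_fn):
    """Placeholder loop illustrating where the graph break lands."""
    for i, t in enumerate(timesteps):
        step_fn(i, t)  # noise prediction + scheduler.step(...) would happen here
        if XLA_AVAILABLE:
            # On TPU/XLA this flushes the pending graph so each diffusion step
            # compiles and executes instead of accumulating the whole loop.
            xm.mark_step()
```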
- self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) # Copied from diffusers.pipelines.kolors.pipeline_kolors.KolorsPipeline.encode_prompt def encode_prompt( diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index e985648abace..1c59ca7d6d7c 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -30,6 +30,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -40,6 +41,13 @@ from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -226,7 +234,7 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt @@ -952,6 +960,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + denoised = denoised.to(prompt_embeds.dtype) if not output_type == "latent": image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index d110cd464522..a3d9917d3376 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -29,6 +29,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -39,8 +40,16 @@ from ..stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -209,7 +218,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -881,6 +890,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if 
XLA_AVAILABLE: + xm.mark_step() + denoised = denoised.to(prompt_embeds.dtype) if not output_type == "latent": image = self.vae.decode(denoised / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index cd63637b6c2f..d079e71fe38e 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -25,10 +25,19 @@ from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...utils import is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + class LDMTextToImagePipeline(DiffusionPipeline): r""" Pipeline for text-to-image generation using latent diffusion. @@ -202,6 +211,9 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + # scale and decode the image latents with vae latents = 1 / self.vqvae.config.scaling_factor * latents image = self.vqvae.decode(latents).sample diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py index bb72b4d4eb8e..879722e6a0e2 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py @@ -15,11 +15,19 @@ LMSDiscreteScheduler, PNDMScheduler, ) -from ...utils import PIL_INTERPOLATION +from ...utils import PIL_INTERPOLATION, is_torch_xla_available from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + def preprocess(image): w, h = image.size w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 @@ -174,6 +182,9 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + # decode the image latents with the VQVAE image = self.vqvae.decode(latents).sample image = torch.clamp(image, -1.0, 1.0) diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 19c4a6d1ddf9..1b70650dfa11 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -30,8 +30,10 @@ from ...utils import ( BACKENDS_MAPPING, BaseOutput, + deprecate, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -39,8 +41,16 @@ from ...video_processor import VideoProcessor +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -180,7 +190,7 @@ def __init__( tokenizer=tokenizer, 
text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py @@ -836,7 +846,17 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if not output_type == "latents": + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latents": + deprecation_message = ( + "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead." + ) + deprecate("output_type_latents", "1.0.0", deprecation_message, standard_warn=False) + output_type = "latent" + + if not output_type == "latent": video = self.decode_latents(latents, video_length, decode_chunk_size=14) video = self.video_processor.postprocess_video(video=video, output_type=output_type) else: diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py index ab68ffe33646..bdac47c47ade 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py @@ -19,6 +19,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -29,8 +30,16 @@ from .pipeline_output import LEditsPPDiffusionPipelineOutput, LEditsPPInversionPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -359,10 +368,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
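The Latte hunk above also retires the legacy `output_type="latents"` value by routing it through the `deprecate` helper before the usual `"latent"` branch runs. A stripped-down sketch of that remapping outside of any pipeline; the function name is illustrative, while the `deprecate(...)` call mirrors the diff.

```py
# Sketch of the legacy-value remapping used in the Latte hunk above.
from diffusers.utils import deprecate


def normalize_output_type(output_type: str) -> str:
    if output_type == "latents":
        deprecate(
            "output_type_latents",  # deprecation tag
            "1.0.0",                # release in which the old spelling stops being accepted
            "Passing `output_type='latents'` is deprecated. Please pass `output_type='latent'` instead.",
            standard_warn=False,
        )
        output_type = "latent"
    return output_type
```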
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -389,7 +402,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1209,6 +1222,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 8. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ @@ -1378,6 +1394,9 @@ def invert( progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + self.init_latents = xts[-1].expand(self.batch_size, -1, -1, -1) zs = zs.flip(0) self.zs = zs diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py index 137e0c742c09..cad7d8a66a08 100644 --- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py @@ -372,7 +372,7 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if not isinstance(scheduler, DDIMScheduler) and not isinstance(scheduler, DPMSolverMultistepScheduler): @@ -384,7 +384,11 @@ def __init__( "The scheduler has been changed to DPMSolverMultistepScheduler." 
) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py index d65c0b1f6a8b..e04290b45754 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py @@ -677,10 +677,10 @@ def __call__( sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, @@ -769,7 +769,7 @@ def __call__( if not self.vae.config.timestep_conditioning: timestep = None else: - noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype) + noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype) if not isinstance(decode_timestep, list): decode_timestep = [decode_timestep] * batch_size if decode_noise_scale is None: diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py index f8b6d4873a7c..b1dcc41d887e 100644 --- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py +++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py @@ -747,10 +747,10 @@ def __call__( sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) mu = calculate_shift( video_sequence_length, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 0a59d98919f0..52bb6546031d 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -31,6 +31,7 @@ BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -38,8 +39,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -874,6 +883,9 @@ def __call__( progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents, return_dict=False)[0] diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py 
b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py index a602ba611ea5..e5cd62e35773 100644 --- a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py +++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py @@ -37,6 +37,7 @@ ) from ...utils import ( BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -46,6 +47,13 @@ from .marigold_image_processing import MarigoldImageProcessor +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -174,7 +182,7 @@ def __init__( default_processing_resolution=default_processing_resolution, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.scale_invariant = scale_invariant self.shift_invariant = shift_invariant @@ -517,6 +525,9 @@ def __call__( noise, t, batch_pred_latent, generator=generator ).prev_sample # [B,4,h,w] + if XLA_AVAILABLE: + xm.mark_step() + pred_latents.append(batch_pred_latent) pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py index aa9ad36ffc35..22f155f92022 100644 --- a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py +++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py @@ -36,6 +36,7 @@ ) from ...utils import ( BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -44,6 +45,13 @@ from .marigold_image_processing import MarigoldImageProcessor +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -161,7 +169,7 @@ def __init__( default_processing_resolution=default_processing_resolution, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.use_full_z_range = use_full_z_range self.default_denoising_steps = default_denoising_steps @@ -493,6 +501,9 @@ def __call__( noise, t, batch_pred_latent, generator=generator ).prev_sample # [B,4,h,w] + if XLA_AVAILABLE: + xm.mark_step() + pred_latents.append(batch_pred_latent) pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py index aac4e32e33f0..435470064633 100644 --- a/src/diffusers/pipelines/mochi/pipeline_mochi.py +++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py @@ -62,19 +62,6 @@ """ -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.16, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - # from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77 def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None): if linear_steps is None: diff --git a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py index 728635da6d4d..73837af7d429 100644 --- 
a/src/diffusers/pipelines/musicldm/pipeline_musicldm.py +++ b/src/diffusers/pipelines/musicldm/pipeline_musicldm.py @@ -42,8 +42,20 @@ if is_librosa_available(): import librosa + +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -111,7 +123,7 @@ def __init__( scheduler=scheduler, vocoder=vocoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def _encode_prompt( self, @@ -603,6 +615,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + self.maybe_free_model_hooks() # 8. Post-processing diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py index 28c4f3d32b78..bc90073cba77 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py @@ -30,6 +30,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -42,6 +43,13 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -251,7 +259,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -1293,6 +1301,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py index 3ad9cbf45f0d..bc7a4b57affd 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py @@ -31,6 +31,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -43,6 +44,13 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -228,7 +236,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** 
(len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -1505,6 +1513,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py index 15a93357470f..83540885bfb2 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py @@ -62,6 +62,16 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -280,7 +290,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -421,7 +431,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -480,8 +492,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1560,6 +1574,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py index 19c26b98ba37..b84f5d555914 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +++ 
b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py @@ -62,6 +62,16 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -270,7 +280,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False @@ -413,7 +423,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -472,8 +484,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1626,6 +1640,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # If we do sequential model offloading, let's offload unet and controlnet # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index dea1f12696b2..a6a8deb5883c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -245,9 +245,7 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.default_sample_size = ( diff --git a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py index 3e84f44adcf7..62f634312ada 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_kolors.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_kolors.py @@ -202,12 +202,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) self.set_pag_applied_layers(pag_applied_layers) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index b2fbdd683e86..d927a7961a16 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -29,6 +29,7 @@ deprecate, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -43,8 +44,16 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -172,7 +181,7 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.set_pag_applied_layers(pag_applied_layers) @@ -843,6 +852,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] if use_resolution_binning: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 03662bb37158..2cdc1c70cdcc 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -30,6 +30,7 @@ BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -43,8 +44,16 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False 
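Several `__init__` hunks in this part of the diff (Kolors, LEDITS++ SDXL, HunyuanDiT PAG, Kolors PAG) converge on the same guarded defaults so a pipeline can be constructed with `vae=None` or `unet=None`. A condensed sketch of just those two assignments follows; the class is a stand-in for a pipeline instance, and the fallback values 8 and 128 are the ones used in the diff.

```py
# Condensed sketch of the guarded defaults used in the __init__ hunks above.
class _InitDefaultsSketch:
    def __init__(self, vae=None, unet=None):
        self.vae = vae
        self.unet = unet

        # 8 matches the standard Stable Diffusion VAE (4 blocks -> 2 ** 3).
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        )

        # 128 latent pixels is the SDXL default sample size (1024 px with an 8x VAE).
        self.default_sample_size = (
            self.unet.config.sample_size
            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
            else 128
        )
```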
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -162,7 +171,11 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 8 + ) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) self.set_pag_applied_layers( @@ -170,6 +183,35 @@ def __init__( pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()), ) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def encode_prompt( self, prompt: Union[str, List[str]], @@ -863,6 +905,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": image = latents else: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py index 2e2d9afb9096..fc7dc3a83f27 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -27,6 +27,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -39,8 +40,16 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -250,10 +259,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
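The `enable_vae_slicing` / `enable_vae_tiling` methods added to the Sana PAG pipeline above (and to `SanaPipeline` later in this diff) are thin wrappers around the VAE's own toggles. A short usage sketch, assuming the public `SanaPipeline` class; the checkpoint id and prompt are illustrative.

```py
# Usage sketch: toggling VAE memory savings through the new pipeline helpers.
import torch

from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

pipe.enable_vae_slicing()  # decode the batch one image at a time
pipe.enable_vae_tiling()   # decode in spatial tiles for large resolutions

image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]
image.save("sana.png")

pipe.disable_vae_tiling()
pipe.disable_vae_slicing()
```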
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -281,7 +294,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1034,6 +1047,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index d1b96e75574f..fde3e500a573 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -200,9 +200,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 @@ -377,9 +375,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
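The rewritten `is_unet_version_less_0_9_0` / `is_unet_sample_size_less_64` checks above short-circuit when `unet` is None, so the legacy sample-size deprecation probe no longer crashes pipelines built without a UNet. The same predicate as a standalone helper; the function name is illustrative, the conditions mirror the diff.

```py
# Standalone sketch of the None-safe legacy-checkpoint probe from the hunks above.
from packaging import version


def needs_legacy_sample_size_fix(unet) -> bool:
    is_unet_version_less_0_9_0 = (
        unet is not None
        and hasattr(unet.config, "_diffusers_version")
        and version.parse(version.parse(unet.config._diffusers_version).base_version)
        < version.parse("0.9.0.dev0")
    )
    is_unet_sample_size_less_64 = (
        unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
    )
    # Both conditions must hold before the legacy sample_size deprecation fires.
    return is_unet_version_less_0_9_0 and is_unet_sample_size_less_64
```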
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index 24e31fa4cfc7..d64582a26f7a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -216,9 +216,7 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 @@ -393,9 +391,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py index 1e81fa3a158c..d3a015e569c1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py @@ -26,6 +26,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -40,8 +41,16 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -147,7 +156,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) self.set_pag_applied_layers(pag_applied_layers) @@ -847,6 +856,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 9. 
Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py index 81db8caf16f0..d91c02b607a3 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py @@ -30,6 +30,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -42,8 +43,16 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -245,10 +254,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -276,7 +289,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1066,6 +1079,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py index 800f512c061c..33abfb0be89f 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py @@ -28,6 +28,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -40,8 +41,16 @@ from .pag_utils import PAGMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -277,10 +286,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -308,7 +321,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -1318,6 +1331,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": condition_kwargs = {} if isinstance(self.vae, AsymmetricAutoencoderKL): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index c2611164a049..856f6a3e789e 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -275,10 +275,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -415,7 +419,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -474,8 +480,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] 
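The `encode_prompt` hunks above (SDXL PAG text-to-image, img2img, and inpaint) stop blindly treating `prompt_embeds[0]` as the pooled CLIP vector: it is only used when it really is a 2D `[batch, dim]` tensor and no pooled embedding was supplied, so a 3D hidden state can never be mistaken for the pooled output. A simplified sketch of the guard for the `clip_skip=None` path; `encoder_output` is a stand-in for a CLIP text encoder's return value when called with `output_hidden_states=True`.

```py
# Sketch of the pooled-output guard from the encode_prompt hunks above.
def extract_embeds(encoder_output, pooled_prompt_embeds=None):
    # Only accept the first element as the pooled vector when it is 2D
    # ([batch, dim]) and the caller did not already pass one in.
    if pooled_prompt_embeds is None and encoder_output[0].ndim == 2:
        pooled_prompt_embeds = encoder_output[0]

    # Token-level embeddings come from the penultimate hidden state
    # (the clip_skip=None branch in the pipelines above).
    prompt_embeds = encoder_output.hidden_states[-2]
    return prompt_embeds, pooled_prompt_embeds
```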
negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index 6d634d524848..93dcca0ea9d6 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -298,7 +298,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -436,7 +436,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -495,8 +497,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index 7f85c13ac561..fdf3df2f4d6a 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -314,7 +314,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -526,7 +526,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -585,8 +587,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + 
negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index b225fd71edf8..55a9f47145a2 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -23,7 +23,7 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput @@ -31,6 +31,13 @@ from .image_encoder import PaintByExampleImageEncoder +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -209,7 +216,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -604,6 +611,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + self.maybe_free_model_hooks() if not output_type == "latent": diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index b7dfcd39edce..df8499ab900a 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -37,6 +37,7 @@ from ...utils import ( USE_PEFT_BACKEND, BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -48,8 +49,16 @@ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -195,7 +204,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt @@ -928,6 +937,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + # 9. 
Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index be900ca4469b..527724d1de1a 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -411,6 +411,13 @@ def module_is_offloaded(module): pipeline_is_sequentially_offloaded = any( module_is_sequentially_offloaded(module) for _, module in self.components.items() ) + + is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 + if is_pipeline_device_mapped: + raise ValueError( + "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline." + ) + if device and torch.device(device).type == "cuda": if pipeline_is_sequentially_offloaded and not pipeline_has_bnb: raise ValueError( @@ -422,12 +429,6 @@ def module_is_offloaded(module): "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation." ) - is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 - if is_pipeline_device_mapped: - raise ValueError( - "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`." - ) - # Display a warning in this case (the operation succeeds but the benefits are lost) pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items()) if pipeline_is_offloaded and device and torch.device(device).type == "cuda": diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 391b831166d2..46a7337051ef 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -29,6 +29,7 @@ deprecate, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -36,8 +37,16 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -285,7 +294,7 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt @@ -943,6 +952,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] if use_resolution_binning: diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py 
b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index 64e1e5bae06c..356ba3a29af3 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -29,6 +29,7 @@ deprecate, is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -41,8 +42,16 @@ ) +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + if is_bs4_available(): from bs4 import BeautifulSoup @@ -211,7 +220,7 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300 @@ -854,6 +863,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] if use_resolution_binning: diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index c90dec4d41b3..8b318597c12d 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -63,6 +63,49 @@ import ftfy +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": [3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -175,6 +218,35 @@ def __init__( ) self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor) + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. 
+ """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + def encode_prompt( self, prompt: Union[str, List[str]], @@ -619,7 +691,7 @@ def __call__( negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - clean_caption: bool = True, + clean_caption: bool = False, use_resolution_binning: bool = True, attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, @@ -734,7 +806,9 @@ def __call__( # 1. Check inputs. Raise error if not correct if use_resolution_binning: - if self.transformer.config.sample_size == 64: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: aspect_ratio_bin = ASPECT_RATIO_2048_BIN elif self.transformer.config.sample_size == 32: aspect_ratio_bin = ASPECT_RATIO_1024_BIN diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 6f83071f3e85..a8c374259349 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -9,12 +9,19 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import SemanticStableDiffusionPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -87,7 +94,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -701,6 +708,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 8. 
Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py index f87f28e06c4a..ef8a95daefa4 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e.py @@ -25,6 +25,7 @@ from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -33,8 +34,16 @@ from .renderer import ShapERenderer +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -291,6 +300,9 @@ def __call__( sample=latents, ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py index 7cc145e4c3e2..c0d1e38e0994 100644 --- a/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +++ b/src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py @@ -24,6 +24,7 @@ from ...schedulers import HeunDiscreteScheduler from ...utils import ( BaseOutput, + is_torch_xla_available, logging, replace_example_docstring, ) @@ -32,8 +33,16 @@ from .renderer import ShapERenderer +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -278,6 +287,9 @@ def __call__( sample=latents, ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + if output_type not in ["np", "pil", "latent", "mesh"]: raise ValueError( f"Only the output types `pil`, `np`, `latent` and `mesh` are supported not output_type={output_type}" diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index 111ccc40c5a5..e3b9ec44005a 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -19,14 +19,22 @@ from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler -from ...utils import is_torch_version, logging, replace_example_docstring +from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -503,6 +511,9 @@ def __call__( prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + if XLA_AVAILABLE: + xm.mark_step() + if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" diff --git 
a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py index 058dbf6b0797..241c454e103e 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py @@ -23,13 +23,21 @@ from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] EXAMPLE_DOC_STRING = """ @@ -611,6 +619,9 @@ def __call__( prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + if XLA_AVAILABLE: + xm.mark_step() + # Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 9ecae6083eb6..eaeb5f809c47 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -132,10 +132,14 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -162,7 +166,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def prepare_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index ecfb8c16f62c..c2d918156084 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -165,7 +165,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def prepare_inputs(self, prompt: Union[str, List[str]], image: Union[Image.Image, List[Image.Image]]): if not isinstance(prompt, (str, list)): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index 338220ae3940..abcba926160a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -159,10 +159,14 @@ def __init__( " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -189,7 +193,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def prepare_inputs( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 959c8135f73b..6e93c34929de 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -254,12 +254,15 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - self._is_unet_config_sample_size_int = isinstance(unet.config.sample_size, int) + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + self._is_unet_config_sample_size_int = unet is not None and isinstance(unet.config.sample_size, int) is_unet_sample_size_less_64 = ( - hasattr(unet.config, "sample_size") + unet is not None + and hasattr(unet.config, "sample_size") and self._is_unet_config_sample_size_int and unet.config.sample_size < 64 ) @@ -290,7 +293,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 9e758d91cadd..f158c41cac53 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -28,11 +28,26 @@ from ...models import AutoencoderKL, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers 
+from ...utils import ( + PIL_INTERPOLATION, + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -115,10 +130,14 @@ def __init__( ): super().__init__() - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -145,7 +164,7 @@ def __init__( depth_estimator=depth_estimator, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt @@ -861,6 +880,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index fb80bb34b3ba..e0268065a415 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -24,13 +24,20 @@ from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -97,10 +104,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -126,7 +137,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -401,6 +412,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + self.maybe_free_model_hooks() if not output_type == "latent": diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index a1ae273add62..901dcd6db012 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -32,6 +32,7 @@ PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -43,8 +44,16 @@ from .safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -273,10 +282,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -304,7 +317,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1120,6 +1133,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index db4c687f991d..6f4e7f358952 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -27,13 +27,27 @@ from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -215,10 +229,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -237,7 +255,7 @@ def __init__( unet._internal_dict = FrozenDict(new_config) # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 - if unet.config.in_channels != 9: + if unet is not None and unet.config.in_channels != 9: logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( @@ -250,7 +268,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -1303,6 +1321,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": condition_kwargs = {} if isinstance(self.vae, AsymmetricAutoencoderKL): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 76b4f285b50f..7857bc58a8ad 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -165,7 +165,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index ffe02ae679e5..c6967bc393b5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -25,11 +25,18 @@ from ...loaders import FromSingleFileMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import 
randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -116,7 +123,7 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") def _encode_prompt( @@ -640,6 +647,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 4cbbe17531ef..dae4540ebe00 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -30,12 +30,26 @@ ) from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . 
import StableDiffusionPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -149,7 +163,7 @@ def __init__( watermarker=watermarker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") self.register_to_config(max_noise_level=max_noise_level) @@ -769,6 +783,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index 41811f8f2c0e..07d82251d4ba 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -28,6 +28,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -38,8 +39,16 @@ from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -154,7 +163,7 @@ def __init__( vae=vae, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline._encode_prompt with _encode_prompt->_encode_prior_prompt, tokenizer->prior_tokenizer, text_encoder->prior_text_encoder @@ -924,6 +933,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] else: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 2556d5e57b6d..eac9945ff349 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -28,6 +28,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -38,8 +39,16 @@ from .stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -155,7 +164,7 @@ def __init__( vae=vae, ) - self.vae_scale_factor = 2 ** 
(len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt @@ -829,6 +838,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 9. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 4ec0eb829b69..23950f895aae 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -215,9 +215,7 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 @@ -385,9 +383,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
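Aside (not part of the patch): the docstring hunk above corrects the duplicated `negative_prompt_2` entry to `negative_prompt_3`, the argument that is routed to `tokenizer_3`/`text_encoder_3`. A minimal usage sketch, assuming the public SD3 checkpoint id and placeholder prompts:

```py
# Illustrative sketch (not part of this diff): passing a dedicated negative prompt
# to the third (T5) text encoder of an SD3 pipeline. Model id and prompts are
# placeholders; adjust dtype/device to your setup.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a photo of a corgi wearing sunglasses",
    negative_prompt="blurry, low quality",      # used by the CLIP text encoders
    negative_prompt_3="text, watermark, logo",  # routed to tokenizer_3/text_encoder_3
).images[0]
image.save("sd3.png")
```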
@@ -1015,10 +1013,10 @@ def __call__( ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) scheduler_kwargs["mu"] = mu elif mu is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index 77daf5b0b4e0..b6e95844b3bd 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -226,10 +226,8 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels ) @@ -402,9 +400,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
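Aside (not part of the patch): the `__call__` hunk above switches from attribute access to `self.scheduler.config.get(...)` with explicit defaults, so scheduler configs that predate the dynamic-shifting keys no longer fail. A self-contained sketch of that fallback pattern, with a hypothetical `calculate_shift_stub` standing in for the real helper:

```py
# Minimal sketch of the fallback pattern used above: read optional scheduler
# config entries with .get() and sensible defaults instead of attribute access,
# so older configs without these keys keep working. `calculate_shift_stub` and
# the plain-dict config are illustrative stand-ins, not diffusers APIs.
def calculate_shift_stub(image_seq_len, base_seq_len, max_seq_len, base_shift, max_shift):
    # linearly interpolate the timestep shift between the two anchor points
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    return base_shift + slope * (image_seq_len - base_seq_len)


scheduler_config = {}  # e.g. an older scheduler config that lacks the shift keys
mu = calculate_shift_stub(
    image_seq_len=1024,
    base_seq_len=scheduler_config.get("base_image_seq_len", 256),
    max_seq_len=scheduler_config.get("max_image_seq_len", 4096),
    base_shift=scheduler_config.get("base_shift", 0.5),
    max_shift=scheduler_config.get("max_shift", 1.16),
)
print(mu)  # ~0.632 with the defaults
```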
@@ -945,10 +943,10 @@ def __call__( ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) scheduler_kwargs["mu"] = mu elif mu is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index e1cfdb3e6e97..67791c17a74b 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -225,10 +225,8 @@ def __init__( transformer=transformer, scheduler=scheduler, ) - self.vae_scale_factor = ( - 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 - ) - latent_channels = self.vae.config.latent_channels if hasattr(self, "vae") and self.vae is not None else 16 + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, vae_latent_channels=latent_channels ) @@ -408,9 +406,9 @@ def encode_prompt( negative_prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. - negative_prompt_2 (`str` or `List[str]`, *optional*): + negative_prompt_3 (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and - `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders + `text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. 
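Aside (not part of the patch): like many hunks in this diff, the constructors are changed to derive `vae_scale_factor` (and, where relevant, `latent_channels`) through `getattr(self, "vae", None)` with defaults of 8 and 16, so pipelines instantiated without a VAE can still build their image processors. A hedged sketch of the pattern using dummy stand-in classes:

```py
# Illustrative sketch of the recurring fallback above: derive the VAE scale factor
# from the VAE config when a VAE is registered, otherwise fall back to defaults.
# `DummyVAE` and `resolve_vae_params` are stand-ins, not diffusers classes.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DummyVAEConfig:
    block_out_channels: List[int] = field(default_factory=lambda: [128, 256, 512, 512])
    latent_channels: int = 16


@dataclass
class DummyVAE:
    config: DummyVAEConfig = field(default_factory=DummyVAEConfig)


def resolve_vae_params(vae: Optional[DummyVAE]):
    # mirrors: 2 ** (len(vae.config.block_out_channels) - 1) if a VAE exists, else 8
    scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) if vae is not None else 8
    latent_channels = vae.config.latent_channels if vae is not None else 16
    return scale_factor, latent_channels


print(resolve_vae_params(DummyVAE()))  # (8, 16) computed from the 4-level dummy config
print(resolve_vae_params(None))        # (8, 16) from the fallback defaults
```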
@@ -1055,10 +1053,10 @@ def __call__( ) mu = calculate_shift( image_seq_len, - self.scheduler.config.base_image_seq_len, - self.scheduler.config.max_image_seq_len, - self.scheduler.config.base_shift, - self.scheduler.config.max_shift, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), ) scheduler_kwargs["mu"] = mu elif mu is not None: diff --git a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py index 2147d42a9f38..351b146fb423 100644 --- a/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py @@ -30,6 +30,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -41,6 +42,14 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + logger = logging.get_logger(__name__) EXAMPLE_DOC_STRING = """ @@ -242,7 +251,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1008,6 +1017,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 8. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index 978ab165f891..4b999662a6e7 100644 --- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -33,6 +33,7 @@ USE_PEFT_BACKEND, BaseOutput, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -44,6 +45,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -336,10 +344,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -367,7 +379,7 @@ def __init__( feature_extractor=feature_extractor, inverse_scheduler=inverse_scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1508,6 +1520,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py index ce34691eba7c..4bbb93e44a83 100644 --- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py @@ -29,6 +29,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -40,8 +41,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -168,7 +177,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -828,6 +837,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py 
index 3c147b64898d..86ef01784057 100644 --- a/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +++ b/src/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py @@ -32,7 +32,14 @@ from ...models.attention import GatedSelfAttentionDense from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from ..stable_diffusion import StableDiffusionPipelineOutput @@ -40,8 +47,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -226,7 +241,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1010,6 +1025,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py index 664c0810d8cf..24e11bff3052 100755 --- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -125,7 +125,7 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(requires_safety_checker=requires_safety_checker) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) model = ModelWrapper(unet, scheduler.alphas_cumprod) diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py index 45f814fd538f..c7c5bd9cff67 100644 --- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py @@ -170,10 +170,14 @@ def __init__( scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) 
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) model = ModelWrapper(unet, scheduler.alphas_cumprod) if scheduler.config.prediction_type == "v_prediction": @@ -321,7 +325,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -380,8 +386,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index a42c865317a9..702f3eda5816 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -30,6 +30,7 @@ USE_PEFT_BACKEND, BaseOutput, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -40,8 +41,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```python @@ -254,7 +263,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1002,6 +1011,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index e200a85f4b55..ccee6d47b47a 100644 --- 
a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -26,6 +26,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -37,8 +38,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -230,7 +239,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -1155,6 +1164,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type != "latent": if circular_padding: image = self.decode_latents_with_padding(latents) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index dc94ea960c8f..deae82eb8813 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -12,13 +12,20 @@ from ...loaders import IPAdapterMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers -from ...utils import deprecate, logging +from ...utils import deprecate, is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import StableDiffusionSafePipelineOutput from .safety_checker import SafeStableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -117,10 +124,14 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + is_unet_version_less_0_9_0 = ( + unet is not None + and hasattr(unet.config, "_diffusers_version") + and version.parse(version.parse(unet.config._diffusers_version).base_version) < version.parse("0.9.0.dev0") + ) + is_unet_sample_size_less_64 = ( + unet is not None and hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + ) if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" @@ -149,7 +160,7 @@ def __init__( image_encoder=image_encoder, ) self._safety_text_concept = safety_concept - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.register_to_config(requires_safety_checker=requires_safety_checker) @property @@ -739,6 +750,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 8. Post-processing image = self.decode_latents(latents) diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py index 06d463c98f6b..e96422073b19 100644 --- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py @@ -27,6 +27,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -38,8 +39,16 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -157,7 +166,7 @@ def __init__( feature_extractor=feature_extractor, image_encoder=image_encoder, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -840,6 +849,9 @@ def get_map_size(module, input, output): step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 77363b2546d7..eb1030f3bb9d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -65,7 +65,7 @@ def __init__( unet=unet, 
scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 def prepare_inputs(self, prompt: Union[str, List[str]]): if not isinstance(prompt, (str, list)): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index d83fa6201117..9c69fe65fbdb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -269,10 +269,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -406,7 +410,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -465,8 +471,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 126f25a41adc..08d0b44d613d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -291,7 +291,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -427,7 +427,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the 
final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -486,8 +488,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index a378ae65eb30..920caf4d24a1 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -321,7 +321,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.register_to_config(requires_aesthetics_score=requires_aesthetics_score) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True @@ -531,7 +531,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -590,8 +592,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b59f2312726d..aaffe8efa730 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -199,9 +199,13 @@ def __init__( scheduler=scheduler, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) self.is_cosxl_edit = is_cosxl_edit add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -333,7 +337,9 @@ def encode_prompt( ) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] prompt_embeds_list.append(prompt_embeds) @@ -385,7 +391,8 @@ def encode_prompt( output_hidden_states=True, ) # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index fb986075aeea..8c1af7863e63 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -24,14 +24,22 @@ from ...image_processor import PipelineImageInput from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...schedulers import EulerDiscreteScheduler -from ...utils import BaseOutput, logging, replace_example_docstring +from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import is_compiled_module, randn_tensor from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -177,7 +185,7 @@ def __init__( scheduler=scheduler, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor) def _encode_image( @@ -600,6 +608,9 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py index ea7e99dafd51..8520a2e2b741 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py @@ -31,6 +31,7 @@ USE_PEFT_BACKEND, BaseOutput, deprecate, + is_torch_xla_available, 
logging, replace_example_docstring, scale_lora_layers, @@ -41,6 +42,14 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + @dataclass class StableDiffusionAdapterPipelineOutput(BaseOutput): """ @@ -59,6 +68,7 @@ class StableDiffusionAdapterPipelineOutput(BaseOutput): logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -260,7 +270,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -915,6 +925,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type == "latent": image = latents has_nsfw_concept = None diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index b51bedf7ee56..5eacb64d01e3 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -43,6 +43,7 @@ from ...utils import ( PIL_INTERPOLATION, USE_PEFT_BACKEND, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -53,8 +54,16 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -293,9 +302,13 @@ def __init__( image_encoder=image_encoder, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt def encode_prompt( @@ -423,7 +436,9 @@ def encode_prompt( prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -482,8 +497,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if 
negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -1262,6 +1279,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": # make sure the VAE is in float32 mode, as it overflows in float16 needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index cdd72b97f86b..5c63d66e3133 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -25,6 +25,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -36,8 +37,16 @@ from . import TextToVideoSDPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -105,7 +114,7 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt @@ -627,6 +636,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 8. Post processing if output_type == "latent": video = latents diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 92bf1d388c13..006c7a79ce0d 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -26,6 +26,7 @@ from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -37,8 +38,16 @@ from . 
import TextToVideoSDPipelineOutput +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -140,7 +149,7 @@ def __init__( unet=unet, scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt @@ -679,6 +688,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # manually for max memory savings if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.unet.to("cpu") diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py index 11fef4f16c90..df85f470a80b 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py @@ -358,7 +358,7 @@ def __init__( " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def forward_loop(self, x_t0, t0, t1, generator): diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py index 9ff473cc3a38..339d5b3a6019 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py @@ -42,6 +42,16 @@ from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from ...utils import is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -409,10 +419,14 @@ def __init__( feature_extractor=feature_extractor, ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.default_sample_size = self.unet.config.sample_size + self.default_sample_size = ( + self.unet.config.sample_size + if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size") + else 128 + ) add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() @@ -705,7 +719,9 @@ def encode_prompt( prompt_embeds = 
text_encoder(text_input_ids.to(device), output_hidden_states=True) # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] + if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2: + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: prompt_embeds = prompt_embeds.hidden_states[-2] else: @@ -764,8 +780,10 @@ def encode_prompt( uncond_input.input_ids.to(device), output_hidden_states=True, ) + # We are only ALWAYS interested in the pooled output of the final text encoder - negative_pooled_prompt_embeds = negative_prompt_embeds[0] + if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2: + negative_pooled_prompt_embeds = negative_prompt_embeds[0] negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] negative_prompt_embeds_list.append(negative_prompt_embeds) @@ -922,6 +940,10 @@ def backward_loop( progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + return latents.clone().detach() @torch.no_grad() diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip.py b/src/diffusers/pipelines/unclip/pipeline_unclip.py index 25c6739d8720..bf42d44f74c1 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip.py @@ -22,12 +22,19 @@ from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel from ...schedulers import UnCLIPScheduler -from ...utils import logging +from ...utils import is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_proj import UnCLIPTextProjModel +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -474,6 +481,9 @@ def __call__( noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + image = super_res_latents # done super res diff --git a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py index 2a0e7e90e4d2..8fa0a848f7e7 100644 --- a/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +++ b/src/diffusers/pipelines/unclip/pipeline_unclip_image_variation.py @@ -27,12 +27,19 @@ from ...models import UNet2DConditionModel, UNet2DModel from ...schedulers import UnCLIPScheduler -from ...utils import logging +from ...utils import is_torch_xla_available, logging from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .text_proj import UnCLIPTextProjModel +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -400,6 +407,9 @@ def __call__( noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator ).prev_sample + if XLA_AVAILABLE: + xm.mark_step() + image = super_res_latents # done super res diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py index 4f65caf4e610..66d7404fb9a5 100644 --- a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +++ 
b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -18,7 +18,14 @@ from ...models import AutoencoderKL from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_torch_xla_available, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.outputs import BaseOutput from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -26,6 +33,13 @@ from .modeling_uvit import UniDiffuserModel +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -117,7 +131,7 @@ def __init__( scheduler=scheduler, ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.num_channels_latents = vae.config.latent_channels @@ -1378,6 +1392,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 9. Post-processing image = None text = None diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py index b08421415b23..edc01f0d5c75 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py @@ -19,15 +19,23 @@ from transformers import CLIPTextModel, CLIPTokenizer from ...schedulers import DDPMWuerstchenScheduler -from ...utils import deprecate, logging, replace_example_docstring +from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_paella_vq_model import PaellaVQModel from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -413,6 +421,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if output_type not in ["pt", "np", "pil", "latent"]: raise ValueError( f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}" diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py index 92223ce993a6..8f6ba419721d 100644 --- a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py @@ -22,14 +22,22 @@ from ...loaders import StableDiffusionLoraLoaderMixin from ...schedulers import DDPMWuerstchenScheduler -from ...utils import BaseOutput, deprecate, logging, replace_example_docstring +from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline from 
.modeling_wuerstchen_prior import WuerstchenPrior +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:] EXAMPLE_DOC_STRING = """ @@ -502,6 +510,9 @@ def __call__( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + # 10. Denormalize the latents latents = latents * self.config.latent_mean - self.config.latent_std diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py index 35e5743fbcf0..9bbb5e4ca266 100644 --- a/src/diffusers/quantizers/gguf/utils.py +++ b/src/diffusers/quantizers/gguf/utils.py @@ -450,7 +450,7 @@ def __init__( def forward(self, inputs): weight = dequantize_gguf_tensor(self.weight) weight = weight.to(self.compute_dtype) - bias = self.bias.to(self.compute_dtype) + bias = self.bias.to(self.compute_dtype) if self.bias is not None else None output = torch.nn.functional.linear(inputs, weight, bias) return output diff --git a/src/diffusers/schedulers/scheduling_ddim_inverse.py b/src/diffusers/schedulers/scheduling_ddim_inverse.py index 6c2352f2c828..d9d9ae683ad0 100644 --- a/src/diffusers/schedulers/scheduling_ddim_inverse.py +++ b/src/diffusers/schedulers/scheduling_ddim_inverse.py @@ -266,7 +266,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps - # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + # "leading" and "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 if self.config.timestep_spacing == "leading": step_ratio = self.config.num_train_timesteps // self.num_inference_steps # creates integer timesteps by multiplying by ratio diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 64b702bc0e32..f534637161fc 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -136,8 +136,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): sampling, and `solver_order=3` for unconditional sampling. prediction_type (`str`, defaults to `epsilon`, *optional*): Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), - `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen - Video](https://imagen.research.google/video/paper.pdf) paper). + `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`. thresholding (`bool`, defaults to `False`): Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such as Stable Diffusion. @@ -174,6 +174,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of `lambda(t)`. 
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use flow sigmas for step sizes in the noise schedule during the sampling process. + flow_shift (`float`, *optional*, defaults to 1.0): + The shift value for the timestep schedule for flow matching. final_sigmas_type (`str`, defaults to `"zero"`): The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py index c7474d56c708..185c9fbabb89 100644 --- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py @@ -54,11 +54,30 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin): Args: num_train_timesteps (`int`, defaults to 1000): The number of diffusion steps to train the model. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. shift (`float`, defaults to 1.0): The shift value for the timestep schedule. + use_dynamic_shifting (`bool`, defaults to False): + Whether to apply timestep shifting on-the-fly based on the image resolution. + base_shift (`float`, defaults to 0.5): + Value to stabilize image generation. Increasing `base_shift` reduces variation and image is more consistent + with desired output. + max_shift (`float`, defaults to 1.15): + Value change allowed to latent vectors. Increasing `max_shift` encourages more variation and image may be + more exaggerated or stylized. + base_image_seq_len (`int`, defaults to 256): + The base image sequence length. + max_image_seq_len (`int`, defaults to 4096): + The maximum image sequence length. + invert_sigmas (`bool`, defaults to False): + Whether to invert the sigmas. + shift_terminal (`float`, defaults to None): + The end value of the shifted timestep schedule. + use_karras_sigmas (`bool`, defaults to False): + Whether to use Karras sigmas for step sizes in the noise schedule during sampling. + use_exponential_sigmas (`bool`, defaults to False): + Whether to use exponential sigmas for step sizes in the noise schedule during sampling. + use_beta_sigmas (`bool`, defaults to False): + Whether to use beta sigmas for step sizes in the noise schedule during sampling. """ _compatibles = [] diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py index 3ee753e43a33..69bd520e1e6f 100644 --- a/tests/lora/test_lora_layers_flux.py +++ b/tests/lora/test_lora_layers_flux.py @@ -795,8 +795,8 @@ def test_modify_padding_mode(self): @nightly @require_torch_gpu @require_peft_backend -@unittest.skip("We cannot run inference on this model with the current CI hardware") -# TODO (DN6, sayakpaul): move these tests to a beefier GPU +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class FluxLoRAIntegrationTests(unittest.TestCase): """internal note: The integration slices were obtained on audace. 
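The docstring entries added above for `FlowMatchEulerDiscreteScheduler` document the flow-matching knobs (`shift`, `use_dynamic_shifting`, `base_shift`, `max_shift`, and the image-sequence-length bounds). A minimal sketch of how those options are exercised — the concrete values below are illustrative, not recommendations:

```py
from diffusers import FlowMatchEulerDiscreteScheduler

# Static shift: sigmas are warped once by `shift` when set_timesteps() is called.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
scheduler.set_timesteps(num_inference_steps=28)
print(scheduler.sigmas[:5])

# Resolution-dependent shifting: with `use_dynamic_shifting=True`, set_timesteps()
# expects a `mu` value, which pipelines typically derive from the latent sequence
# length using `base_shift`/`max_shift` and `base_image_seq_len`/`max_image_seq_len`.
dyn_scheduler = FlowMatchEulerDiscreteScheduler(
    num_train_timesteps=1000,
    use_dynamic_shifting=True,
    base_shift=0.5,
    max_shift=1.15,
)
dyn_scheduler.set_timesteps(num_inference_steps=28, mu=1.0)  # mu=1.0 is an illustrative placeholder
print(dyn_scheduler.timesteps[:5])
```

The same schedule-shifting idea is what the new `use_flow_sigmas` / `flow_shift` options expose on `DPMSolverMultistepScheduler` in the hunk above.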
@@ -818,6 +818,7 @@ def setUp(self): def tearDown(self): super().tearDown() + del self.pipeline gc.collect() torch.cuda.empty_cache() @@ -825,7 +826,10 @@ def test_flux_the_last_ben(self): self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() + # Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI + # run supports it. We have about 34GB RAM in the CI runner which kills the test when run with + # `enable_model_cpu_offload()`. We repeat this for the other tests, too. + self.pipeline = self.pipeline.to(torch_device) prompt = "jon snow eating pizza with ketchup" @@ -847,7 +851,7 @@ def test_flux_kohya(self): self.pipeline.load_lora_weights("Norod78/brain-slug-flux") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() + self.pipeline = self.pipeline.to(torch_device) prompt = "The cat with a brain slug earring" out = self.pipeline( @@ -869,7 +873,7 @@ def test_flux_kohya_with_text_encoder(self): self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() + self.pipeline = self.pipeline.to(torch_device) prompt = "optimus is cleaning the house with broomstick" out = self.pipeline( @@ -891,7 +895,7 @@ def test_flux_xlabs(self): self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors") self.pipeline.fuse_lora() self.pipeline.unload_lora_weights() - self.pipeline.enable_model_cpu_offload() + self.pipeline = self.pipeline.to(torch_device) prompt = "A blue jay standing on a large basket of rainbow macarons, disney style" diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py index 40383e3f1ee3..a789221e79a0 100644 --- a/tests/lora/test_lora_layers_sd3.py +++ b/tests/lora/test_lora_layers_sd3.py @@ -17,6 +17,7 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel @@ -31,9 +32,9 @@ from diffusers.utils.testing_utils import ( nightly, numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, require_peft_backend, require_torch_gpu, - slow, torch_device, ) @@ -128,11 +129,12 @@ def test_modify_padding_mode(self): pass -@slow @nightly @require_torch_gpu @require_peft_backend -class LoraSD3IntegrationTests(unittest.TestCase): +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class SD3LoraIntegrationTests(unittest.TestCase): pipeline_class = StableDiffusion3Img2ImgPipeline repo_id = "stabilityai/stable-diffusion-3-medium-diffusers" @@ -166,14 +168,17 @@ def get_inputs(self, device, seed=0): def test_sd3_img2img_lora(self): pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16) - pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2", weight_name="pytorch_lora_weights.safetensors") - pipe.enable_sequential_cpu_offload() + pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2") + pipe.fuse_lora() + pipe.unload_lora_weights() + pipe = pipe.to(torch_device) inputs = self.get_inputs(torch_device) image = pipe(**inputs).images[0] image_slice = image[0, -3:, -3:] - expected_slice = np.array([0.5396, 0.5776, 0.7432, 0.5151, 0.5586, 0.7383, 0.5537, 0.5933, 0.7153]) + expected_slice = 
np.array([0.5649, 0.5405, 0.5488, 0.5688, 0.5449, 0.5513, 0.5337, 0.5107, 0.5059]) + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}" diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py index c61ae1bdf0ff..77abe139d785 100644 --- a/tests/models/autoencoders/test_models_vq.py +++ b/tests/models/autoencoders/test_models_vq.py @@ -65,9 +65,11 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict + @unittest.skip("Test not supported.") def test_forward_signature(self): pass + @unittest.skip("Test not supported.") def test_training(self): pass diff --git a/tests/models/transformers/test_models_prior.py b/tests/models/transformers/test_models_prior.py index d2ed10dfa1f6..471c1084c00c 100644 --- a/tests/models/transformers/test_models_prior.py +++ b/tests/models/transformers/test_models_prior.py @@ -132,7 +132,6 @@ def test_output_pretrained(self): output = model(**input)[0] output_slice = output[0, :5].flatten().cpu() - print(output_slice) # Since the VAE Gaussian prior's generator is seeded on the appropriate device, # the expected output slices are not the same for CPU and GPU. @@ -182,7 +181,6 @@ def test_kandinsky_prior(self, seed, expected_slice): assert list(sample.shape) == [1, 768] output_slice = sample[0, :8].flatten().cpu() - print(output_slice) expected_output_slice = torch.tensor(expected_slice) assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index 4c13b54e0620..73b83b9eb514 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self): "out_channels": 4, "time_embed_dim": 2, "text_embed_dim": 8, - "num_layers": 1, + "num_layers": 2, "sample_width": 8, "sample_height": 8, "sample_frames": 8, @@ -130,7 +130,7 @@ def prepare_init_args_and_inputs_for_common(self): "out_channels": 4, "time_embed_dim": 2, "text_embed_dim": 8, - "num_layers": 1, + "num_layers": 2, "sample_width": 8, "sample_height": 8, "sample_frames": 8, diff --git a/tests/models/transformers/test_models_transformer_cogview3plus.py b/tests/models/transformers/test_models_transformer_cogview3plus.py index eda9813808e9..ec6c58a6734c 100644 --- a/tests/models/transformers/test_models_transformer_cogview3plus.py +++ b/tests/models/transformers/test_models_transformer_cogview3plus.py @@ -71,7 +71,7 @@ def prepare_init_args_and_inputs_for_common(self): init_dict = { "patch_size": 2, "in_channels": 4, - "num_layers": 1, + "num_layers": 2, "attention_head_dim": 4, "num_attention_heads": 2, "out_channels": 4, diff --git a/tests/models/unets/test_models_unet_1d.py b/tests/models/unets/test_models_unet_1d.py index 9f7ef3bca085..6eb7d3485c8b 100644 --- a/tests/models/unets/test_models_unet_1d.py +++ b/tests/models/unets/test_models_unet_1d.py @@ -51,9 +51,11 @@ def input_shape(self): def output_shape(self): return (4, 14, 16) + @unittest.skip("Test not supported.") def test_ema_training(self): pass + @unittest.skip("Test not supported.") def test_training(self): pass @@ -126,6 +128,7 @@ def test_output_pretrained(self): # fmt: on self.assertTrue(torch.allclose(output_slice, expected_output_slice, 
rtol=1e-3)) + @unittest.skip("Test not supported.") def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass @@ -205,9 +208,11 @@ def test_output(self): expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + @unittest.skip("Test not supported.") def test_ema_training(self): pass + @unittest.skip("Test not supported.") def test_training(self): pass @@ -265,6 +270,7 @@ def test_output_pretrained(self): # fmt: on self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + @unittest.skip("Test not supported.") def test_forward_with_norm_groups(self): # Not implemented yet for this UNet pass diff --git a/tests/models/unets/test_models_unet_2d.py b/tests/models/unets/test_models_unet_2d.py index ddf5f53511f7..05bece23efd6 100644 --- a/tests/models/unets/test_models_unet_2d.py +++ b/tests/models/unets/test_models_unet_2d.py @@ -105,6 +105,35 @@ def test_mid_block_attn_groups(self): expected_shape = inputs_dict["sample"].shape self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + def test_mid_block_none(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common() + mid_none_init_dict["mid_block_type"] = None + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + mid_none_model = self.model_class(**mid_none_init_dict) + mid_none_model.to(torch_device) + mid_none_model.eval() + + self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.") + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.to_tuple()[0] + + with torch.no_grad(): + mid_none_output = mid_none_model(**mid_none_inputs_dict) + + if isinstance(mid_none_output, dict): + mid_none_output = mid_none_output.to_tuple()[0] + + self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.") + def test_gradient_checkpointing_is_applied(self): expected_set = { "AttnUpBlock2D", @@ -354,6 +383,7 @@ def test_output_pretrained_ve_large(self): self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) + @unittest.skip("Test not supported.") def test_forward_with_norm_groups(self): # not required for this model pass diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 8ec5b6e9a5e4..57f6e4ee440b 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -175,8 +175,7 @@ def create_ip_adapter_plus_state_dict(model): ) ip_image_projection_state_dict = OrderedDict() - keys = [k for k in image_projection.state_dict() if "layers." 
in k] - print(keys) + for k, v in image_projection.state_dict().items(): if "2.to" in k: k = k.replace("2.to", "0.to") diff --git a/tests/models/unets/test_models_unet_controlnetxs.py b/tests/models/unets/test_models_unet_controlnetxs.py index 3025d7117f35..9431e810280f 100644 --- a/tests/models/unets/test_models_unet_controlnetxs.py +++ b/tests/models/unets/test_models_unet_controlnetxs.py @@ -320,6 +320,7 @@ def test_time_embedding_mixing(self): assert output.shape == output_mix_time.shape + @unittest.skip("Test not supported.") def test_forward_with_norm_groups(self): # UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups. pass diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index fb550dd3219d..bf3ce2542d4e 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -469,8 +469,8 @@ def test_xformers_attention_forwardGenerator_pass(self): pass def test_dict_tuple_outputs_equivalent(self): - # increase tolerance from 1e-4 -> 2e-4 to account for large composite model - super().test_dict_tuple_outputs_equivalent(expected_max_difference=2e-4) + # increase tolerance from 1e-4 -> 3e-4 to account for large composite model + super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-4) def test_inference_batch_single_identical(self): # increase tolerance from 1e-4 -> 2e-4 to account for large composite model diff --git a/tests/pipelines/controlnet/test_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py index bf5564e810ef..c71116dc7927 100644 --- a/tests/pipelines/controlnet/test_flax_controlnet.py +++ b/tests/pipelines/controlnet/test_flax_controlnet.py @@ -78,7 +78,7 @@ def test_canny(self): expected_slice = jnp.array( [0.167969, 0.116699, 0.081543, 0.154297, 0.132812, 0.108887, 0.169922, 0.169922, 0.205078] ) - print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 def test_pose(self): @@ -123,5 +123,5 @@ def test_pose(self): expected_slice = jnp.array( [[0.271484, 0.261719, 0.275391, 0.277344, 0.279297, 0.291016, 0.294922, 0.302734, 0.302734]] ) - print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 8202424e7f15..5e856b125f32 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -32,9 +32,9 @@ from diffusers.utils import load_image from diffusers.utils.testing_utils import ( enable_full_determinism, + nightly, numpy_cosine_similarity_distance, require_big_gpu_with_torch_cuda, - slow, torch_device, ) from diffusers.utils.torch_utils import randn_tensor @@ -204,7 +204,7 @@ def test_flux_image_output_shape(self): assert (output_height, output_width) == (expected_height, expected_width) -@slow +@nightly @require_big_gpu_with_torch_cuda @pytest.mark.big_gpu_with_torch_cuda class FluxControlNetPipelineSlowTests(unittest.TestCase): @@ -230,8 +230,7 @@ def test_canny(self): text_encoder_2=None, controlnet=controlnet, torch_dtype=torch.bfloat16, - ) - pipe.enable_model_cpu_offload() + ).to(torch_device) pipe.set_progress_bar_config(disable=None) generator = torch.Generator(device="cpu").manual_seed(0) @@ -241,12 +240,12 @@ def 
test_canny(self): prompt_embeds = torch.load( hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") - ) + ).to(torch_device) pooled_prompt_embeds = torch.load( hf_hub_download( repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" ) - ) + ).to(torch_device) output = pipe( prompt_embeds=prompt_embeds, diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 7981e6c2a93b..addc29e14670 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -9,6 +9,7 @@ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel from diffusers.utils.testing_utils import ( + nightly, numpy_cosine_similarity_distance, require_big_gpu_with_torch_cuda, slow, @@ -208,8 +209,19 @@ def test_flux_image_output_shape(self): output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + def test_flux_true_cfg(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("generator") -@slow + no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + inputs["negative_prompt"] = "bad quality" + inputs["true_cfg_scale"] = 2.0 + true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + assert not np.allclose(no_true_cfg_out, true_cfg_out) + + +@nightly @require_big_gpu_with_torch_cuda @pytest.mark.big_gpu_with_torch_cuda class FluxPipelineSlowTests(unittest.TestCase): @@ -227,19 +239,16 @@ def tearDown(self): torch.cuda.empty_cache() def get_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = torch.Generator(device="cpu").manual_seed(seed) + generator = torch.Generator(device="cpu").manual_seed(seed) prompt_embeds = torch.load( hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt") - ) + ).to(torch_device) pooled_prompt_embeds = torch.load( hf_hub_download( repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt" ) - ) + ).to(torch_device) return { "prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, @@ -253,8 +262,7 @@ def get_inputs(self, device, seed=0): def test_flux_inference(self): pipe = self.pipeline_class.from_pretrained( self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None - ) - pipe.enable_model_cpu_offload() + ).to(torch_device) inputs = self.get_inputs(torch_device) diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py index 607a47e08e58..a7f861565cc9 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py @@ -308,8 +308,6 @@ def test_kandinsky(self): image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] - print(image_from_tuple_slice) - assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.0320, 0.0860, 0.4013, 0.0518, 0.2484, 0.5847, 0.4411, 0.2321, 0.4593]) diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py index effea2619749..4aa48a920fad 100644 --- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py +++ 
b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py @@ -146,7 +146,7 @@ def test_ledits_pp_inversion(self): ) latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device) - print(latent_slice.flatten()) + expected_slice = np.array([-0.9084, -0.0367, 0.2940, 0.0839, 0.6890, 0.2651, -0.7104, 2.1090, -0.7822]) assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3 @@ -167,12 +167,12 @@ def test_ledits_pp_inversion_batch(self): ) latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device) - print(latent_slice.flatten()) + expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5657, -1.0286, -0.9961, 0.5933, 1.1173]) assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3 latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device) - print(latent_slice.flatten()) + expected_slice = np.array([-0.0796, 2.0583, 0.5501, 0.5358, 0.0282, -0.2803, -1.0470, 0.7023, -0.0072]) assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3 diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py index fcfd0aa51b9f..da694175a9f1 100644 --- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py +++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion_xl.py @@ -216,14 +216,14 @@ def test_ledits_pp_inversion_batch(self): ) latent_slice = sd_pipe.init_latents[0, -1, -3:, -3:].to(device) - print(latent_slice.flatten()) + expected_slice = np.array([0.2528, 0.1458, -0.2166, 0.4565, -0.5656, -1.0286, -0.9961, 0.5933, 1.1172]) assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3 latent_slice = sd_pipe.init_latents[1, -1, -3:, -3:].to(device) - print(latent_slice.flatten()) + expected_slice = np.array([-0.0796, 2.0583, 0.5500, 0.5358, 0.0282, -0.2803, -1.0470, 0.7024, -0.0072]) - print(latent_slice.flatten()) + assert np.abs(latent_slice.flatten() - expected_slice).max() < 1e-3 def test_ledits_pp_warmup_steps(self): diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py index bbcf6d210ce5..c9df5785897c 100644 --- a/tests/pipelines/mochi/test_mochi.py +++ b/tests/pipelines/mochi/test_mochi.py @@ -17,15 +17,17 @@ import unittest import numpy as np +import pytest import torch from transformers import AutoTokenizer, T5EncoderModel from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel from diffusers.utils.testing_utils import ( enable_full_determinism, + nightly, numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, require_torch_gpu, - slow, torch_device, ) @@ -260,8 +262,10 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2): ) -@slow +@nightly @require_torch_gpu +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda class MochiPipelineIntegrationTests(unittest.TestCase): prompt = "A painting of a squirrel eating a burger." @@ -293,7 +297,7 @@ def test_mochi(self): ).frames video = videos[0] - expected_video = torch.randn(1, 16, 480, 848, 3).numpy() + expected_video = torch.randn(1, 19, 480, 848, 3).numpy() max_diff = numpy_cosine_similarity_distance(video, expected_video) assert max_diff < 1e-3, f"Max diff is too high. 
got {video}" diff --git a/tests/pipelines/pag/test_pag_sd.py b/tests/pipelines/pag/test_pag_sd.py index 3979bb170e0b..17e3f7038439 100644 --- a/tests/pipelines/pag/test_pag_sd.py +++ b/tests/pipelines/pag/test_pag_sd.py @@ -318,7 +318,7 @@ def test_pag_cfg(self): image_slice = image[0, -3:, -3:, -1].flatten() assert image.shape == (1, 512, 512, 3) - print(image_slice.flatten()) + expected_slice = np.array( [0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484] ) @@ -339,7 +339,6 @@ def test_pag_uncond(self): expected_slice = np.array( [0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867] ) - print(image_slice.flatten()) assert ( np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 ), f"output is different from expected, {image_slice.flatten()}" diff --git a/tests/pipelines/pag/test_pag_sd_img2img.py b/tests/pipelines/pag/test_pag_sd_img2img.py index ec8cde23c31d..f44204f82486 100644 --- a/tests/pipelines/pag/test_pag_sd_img2img.py +++ b/tests/pipelines/pag/test_pag_sd_img2img.py @@ -255,7 +255,7 @@ def test_pag_cfg(self): image_slice = image[0, -3:, -3:, -1].flatten() assert image.shape == (1, 512, 512, 3) - print(image_slice.flatten()) + expected_slice = np.array( [0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484] ) @@ -276,7 +276,7 @@ def test_pag_uncond(self): expected_slice = np.array( [0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867] ) - print(image_slice.flatten()) + assert ( np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 ), f"output is different from expected, {image_slice.flatten()}" diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py index cd175c600d47..a528b66cc72a 100644 --- a/tests/pipelines/pag/test_pag_sd_inpaint.py +++ b/tests/pipelines/pag/test_pag_sd_inpaint.py @@ -292,7 +292,7 @@ def test_pag_cfg(self): image_slice = image[0, -3:, -3:, -1].flatten() assert image.shape == (1, 512, 512, 3) - print(image_slice.flatten()) + expected_slice = np.array( [0.38793945, 0.4111328, 0.47924805, 0.39208984, 0.4165039, 0.41674805, 0.37060547, 0.36791992, 0.40625] ) diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py index 21de4e04437a..7109a700403c 100644 --- a/tests/pipelines/sana/test_sana.py +++ b/tests/pipelines/sana/test_sana.py @@ -254,6 +254,36 @@ def test_attention_slicing_forward_pass( "Attention slicing should not affect the inference results", ) + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_sample_stride_height=64, + tile_sample_stride_width=64, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + # TODO(aryan): Create a dummy gemma model with smol vocab size 
@unittest.skip( "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error." diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index b9b061c060c0..5690caa257b7 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -206,9 +206,6 @@ def test_stable_diffusion_pix2pix_euler(self): image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - slice = [round(x, 4) for x in image_slice.flatten().tolist()] - print(",".join([str(x) for x in slice])) - assert image.shape == (1, 32, 32, 3) expected_slice = np.array([0.7417, 0.3842, 0.4732, 0.5776, 0.5891, 0.5139, 0.4052, 0.5673, 0.4986]) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py index dc855f44b817..9e4fa767085f 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax.py @@ -62,7 +62,7 @@ def test_stable_diffusion_flax(self): output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) expected_slice = jnp.array([0.4238, 0.4414, 0.4395, 0.4453, 0.4629, 0.4590, 0.4531, 0.45508, 0.4512]) - print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 @@ -104,5 +104,5 @@ def test_stable_diffusion_dpm_flax(self): output_slice = jnp.asarray(jax.device_get(image_slice.flatten())) expected_slice = jnp.array([0.4336, 0.42969, 0.4453, 0.4199, 0.4297, 0.4531, 0.4434, 0.4434, 0.4297]) - print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py index 8f039980ec24..eeec52dab51d 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_flax_inpaint.py @@ -78,5 +78,5 @@ def test_stable_diffusion_inpaint_pipeline(self): expected_slice = jnp.array( [0.3611307, 0.37649736, 0.3757408, 0.38213953, 0.39295167, 0.3841631, 0.41554978, 0.4137475, 0.4217084] ) - print(f"output_slice: {output_slice}") + assert jnp.abs(output_slice - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 2091af9c0383..7c7b03786563 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -642,9 +642,6 @@ def test_adapter_sdxl_lcm(self): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) - debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] - print(",".join(debug)) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_adapter_sdxl_lcm_custom_timesteps(self): @@ -667,7 +664,4 @@ def test_adapter_sdxl_lcm_custom_timesteps(self): assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5313, 0.5375, 0.4942, 
0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) - debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()] - print(",".join(debug)) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 764be1890cc5..f5494fbade2e 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1192,7 +1192,6 @@ def _test_inference_batch_consistent( logger.setLevel(level=diffusers.logging.WARNING) for batch_size, batched_input in zip(batch_sizes, batched_inputs): - print(batch_size, batched_input) output = pipe(**batched_input) assert len(output[0]) == batch_size diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py index 0caed159100a..a0e6e1417e67 100644 --- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py +++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py @@ -232,8 +232,10 @@ def test_inference_batch_single_identical(self): def test_float16_inference(self): super().test_float16_inference() + @unittest.skip(reason="Test not supported.") def test_callback_inputs(self): pass + @unittest.skip(reason="Test not supported.") def test_callback_cfg(self): pass diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index f474a1d4f4d0..b223c71cb5ce 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -372,7 +372,7 @@ def test_quality(self): output_type="np", ).images out_slice = output[0, -3:, -3:, -1].flatten() - expected_slice = np.array([0.0376, 0.0359, 0.0015, 0.0449, 0.0479, 0.0098, 0.0083, 0.0295, 0.0295]) + expected_slice = np.array([0.0674, 0.0623, 0.0364, 0.0632, 0.0671, 0.0430, 0.0317, 0.0493, 0.0583]) max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-2) diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py index 8ac4c9915c27..8f768b10e846 100644 --- a/tests/quantization/gguf/test_gguf.py +++ b/tests/quantization/gguf/test_gguf.py @@ -6,6 +6,8 @@ import torch.nn as nn from diffusers import ( + AuraFlowPipeline, + AuraFlowTransformer2DModel, FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig, @@ -54,7 +56,8 @@ def test_gguf_linear_layers(self): for name, module in model.named_modules(): if isinstance(module, torch.nn.Linear) and hasattr(module.weight, "quant_type"): assert module.weight.dtype == torch.uint8 - assert module.bias.dtype == torch.float32 + if module.bias is not None: + assert module.bias.dtype == torch.float32 def test_gguf_memory_usage(self): quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) @@ -377,3 +380,79 @@ def test_pipeline_inference(self): ) max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) assert max_diff < 1e-4 + + +class AuraFlowGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase): + ckpt_path = "https://huggingface.co/city96/AuraFlow-v0.3-gguf/blob/main/aura_flow_0.3-Q2_K.gguf" + torch_dtype = torch.bfloat16 + model_cls = AuraFlowTransformer2DModel + expected_memory_use_in_gb = 4 + + def setUp(self): + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + gc.collect() + torch.cuda.empty_cache() + + def get_dummy_inputs(self): + return { + "hidden_states": torch.randn((1, 4, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to( + torch_device, 
self.torch_dtype + ), + "encoder_hidden_states": torch.randn( + (1, 512, 2048), + generator=torch.Generator("cpu").manual_seed(0), + ).to(torch_device, self.torch_dtype), + "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype), + } + + def test_pipeline_inference(self): + quantization_config = GGUFQuantizationConfig(compute_dtype=self.torch_dtype) + transformer = self.model_cls.from_single_file( + self.ckpt_path, quantization_config=quantization_config, torch_dtype=self.torch_dtype + ) + pipe = AuraFlowPipeline.from_pretrained( + "fal/AuraFlow-v0.3", transformer=transformer, torch_dtype=self.torch_dtype + ) + pipe.enable_model_cpu_offload() + + prompt = "a pony holding a sign that says hello" + output = pipe( + prompt=prompt, num_inference_steps=2, generator=torch.Generator("cpu").manual_seed(0), output_type="np" + ).images[0] + output_slice = output[:3, :3, :].flatten() + expected_slice = np.array( + [ + 0.46484375, + 0.546875, + 0.64453125, + 0.48242188, + 0.53515625, + 0.59765625, + 0.47070312, + 0.5078125, + 0.5703125, + 0.42773438, + 0.50390625, + 0.5703125, + 0.47070312, + 0.515625, + 0.57421875, + 0.45898438, + 0.48632812, + 0.53515625, + 0.4453125, + 0.5078125, + 0.56640625, + 0.47851562, + 0.5234375, + 0.57421875, + 0.48632812, + 0.5234375, + 0.56640625, + ] + ) + max_diff = numpy_cosine_similarity_distance(expected_slice, output_slice) + assert max_diff < 1e-4 diff --git a/tests/schedulers/test_scheduler_ddim_inverse.py b/tests/schedulers/test_scheduler_ddim_inverse.py index 696f57644a83..81d53f1b4778 100644 --- a/tests/schedulers/test_scheduler_ddim_inverse.py +++ b/tests/schedulers/test_scheduler_ddim_inverse.py @@ -1,3 +1,5 @@ +import unittest + import torch from diffusers import DDIMInverseScheduler @@ -95,6 +97,7 @@ def test_inference_steps(self): for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]): self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps) + @unittest.skip("Test not supported.") def test_add_noise_device(self): pass diff --git a/tests/schedulers/test_scheduler_deis.py b/tests/schedulers/test_scheduler_deis.py index 986a8f6a44cf..048bde51c366 100644 --- a/tests/schedulers/test_scheduler_deis.py +++ b/tests/schedulers/test_scheduler_deis.py @@ -1,4 +1,5 @@ import tempfile +import unittest import torch @@ -57,6 +58,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index 0b50538ae6a1..55b3202ad0be 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -1,4 +1,5 @@ import tempfile +import unittest import torch @@ -67,6 +68,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 393f544d9639..7cbaa5cc5e8d 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -1,4 +1,5 @@ import tempfile +import unittest import torch @@ -65,6 +66,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 
1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass diff --git a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py index b5522f5991f7..e97d64ec5f1d 100644 --- a/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py +++ b/tests/schedulers/test_scheduler_edm_dpmsolver_multistep.py @@ -3,9 +3,7 @@ import torch -from diffusers import ( - EDMDPMSolverMultistepScheduler, -) +from diffusers import EDMDPMSolverMultistepScheduler from .test_schedulers import SchedulerCommonTest @@ -63,6 +61,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass @@ -258,5 +257,6 @@ def test_duplicated_timesteps(self, **config): scheduler.set_timesteps(scheduler.config.num_train_timesteps) assert len(scheduler.timesteps) == scheduler.num_inference_steps + @unittest.skip("Test not supported.") def test_trained_betas(self): pass diff --git a/tests/schedulers/test_scheduler_flax.py b/tests/schedulers/test_scheduler_flax.py index d2ee7e13146d..fefad06fcf91 100644 --- a/tests/schedulers/test_scheduler_flax.py +++ b/tests/schedulers/test_scheduler_flax.py @@ -675,6 +675,7 @@ def check_over_configs(self, time_step=0, **config): assert jnp.sum(jnp.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass diff --git a/tests/schedulers/test_scheduler_ipndm.py b/tests/schedulers/test_scheduler_ipndm.py index 87c8da3ee3c1..ac7973c58295 100644 --- a/tests/schedulers/test_scheduler_ipndm.py +++ b/tests/schedulers/test_scheduler_ipndm.py @@ -1,4 +1,5 @@ import tempfile +import unittest import torch @@ -50,6 +51,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass diff --git a/tests/schedulers/test_scheduler_pndm.py b/tests/schedulers/test_scheduler_pndm.py index c1519f7c7e8e..13c690468222 100644 --- a/tests/schedulers/test_scheduler_pndm.py +++ b/tests/schedulers/test_scheduler_pndm.py @@ -1,4 +1,5 @@ import tempfile +import unittest import torch @@ -53,6 +54,7 @@ def check_over_configs(self, time_step=0, **config): assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" + @unittest.skip("Test not supported.") def test_from_save_pretrained(self): pass diff --git a/tests/schedulers/test_scheduler_sasolver.py b/tests/schedulers/test_scheduler_sasolver.py index d6d7c029b019..baa2736b2fcc 100644 --- a/tests/schedulers/test_scheduler_sasolver.py +++ b/tests/schedulers/test_scheduler_sasolver.py @@ -103,8 +103,6 @@ def test_full_loop_no_noise(self): elif torch_device in ["cuda"]: assert abs(result_sum.item() - 329.1999816894531) < 1e-2 assert abs(result_mean.item() - 0.4286458194255829) < 1e-3 - else: - print("None") def test_full_loop_with_v_prediction(self): scheduler_class = self.scheduler_classes[0] @@ -135,8 +133,6 @@ def test_full_loop_with_v_prediction(self): elif torch_device in ["cuda"]: assert abs(result_sum.item() - 193.4154052734375) < 1e-2 assert abs(result_mean.item() - 0.2518429756164551) < 1e-3 - else: - print("None") def test_full_loop_device(self): scheduler_class = 
self.scheduler_classes[0] @@ -166,8 +162,6 @@ def test_full_loop_device(self): elif torch_device in ["cuda"]: assert abs(result_sum.item() - 337.394287109375) < 1e-2 assert abs(result_mean.item() - 0.4393154978752136) < 1e-3 - else: - print("None") def test_full_loop_device_karras_sigmas(self): scheduler_class = self.scheduler_classes[0] @@ -198,8 +192,6 @@ def test_full_loop_device_karras_sigmas(self): elif torch_device in ["cuda"]: assert abs(result_sum.item() - 837.25537109375) < 1e-2 assert abs(result_mean.item() - 1.0901763439178467) < 1e-2 - else: - print("None") def test_beta_sigmas(self): self.check_over_configs(use_beta_sigmas=True) diff --git a/tests/schedulers/test_scheduler_unclip.py b/tests/schedulers/test_scheduler_unclip.py index b0ce1312e79f..9e66a328f42e 100644 --- a/tests/schedulers/test_scheduler_unclip.py +++ b/tests/schedulers/test_scheduler_unclip.py @@ -1,3 +1,5 @@ +import unittest + import torch from diffusers import UnCLIPScheduler @@ -130,8 +132,10 @@ def test_full_loop_skip_timesteps(self): assert abs(result_sum.item() - 258.2044983) < 1e-2 assert abs(result_mean.item() - 0.3362038) < 1e-3 + @unittest.skip("Test not supported.") def test_trained_betas(self): pass + @unittest.skip("Test not supported.") def test_add_noise_device(self): pass diff --git a/tests/schedulers/test_scheduler_vq_diffusion.py b/tests/schedulers/test_scheduler_vq_diffusion.py index 74437ad45480..c12825ba2e62 100644 --- a/tests/schedulers/test_scheduler_vq_diffusion.py +++ b/tests/schedulers/test_scheduler_vq_diffusion.py @@ -1,3 +1,5 @@ +import unittest + import torch import torch.nn.functional as F @@ -52,5 +54,6 @@ def test_time_indices(self): for t in [0, 50, 99]: self.check_over_forward(time_step=t) + @unittest.skip("Test not supported.") def test_add_noise_device(self): pass diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py index 0917bbe2b0d7..4e7bc0af6842 100644 --- a/tests/single_file/single_file_testing_utils.py +++ b/tests/single_file/single_file_testing_utils.py @@ -47,6 +47,8 @@ def download_diffusers_config(repo_id, tmpdir): class SDSingleFileTesterMixin: + single_file_kwargs = {} + def _compare_component_configs(self, pipe, single_file_pipe): for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items(): if param_name in ["torch_dtype", "architectures", "_name_or_path"]: @@ -154,7 +156,7 @@ def test_single_file_components_with_original_config_local_files_only( self._compare_component_configs(pipe, single_file_pipe) def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4): - sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None) + sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, **self.single_file_kwargs) sf_pipe.unet.set_attn_processor(AttnProcessor()) sf_pipe.enable_model_cpu_offload(device=torch_device) @@ -170,7 +172,7 @@ def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_d max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten()) - assert max_diff < expected_max_diff + assert max_diff < expected_max_diff, f"{image.flatten()} != {image_single_file.flatten()}" def test_single_file_components_with_diffusers_config( self, diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py index dd15a5c7c071..78baeb94929c 100644 --- 
a/tests/single_file/test_stable_diffusion_single_file.py +++ b/tests/single_file/test_stable_diffusion_single_file.py @@ -132,6 +132,7 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml" ) repo_id = "timbrooks/instruct-pix2pix" + single_file_kwargs = {"extract_ema": True} def setUp(self): super().setUp()