Skip to content

Commit

Permalink
Add torch_xla and from_single_file to instruct-pix2pix (#10444)
Browse files Browse the repository at this point in the history
* Add torch_xla and from_single_file to instruct-pix2pix

* StableDiffusionInstructPix2PixPipelineSingleFileSlowTests

* StableDiffusionInstructPix2PixPipelineSingleFileSlowTests

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
  • Loading branch information
3 people authored Jan 6, 2025
1 parent 7747b58 commit 8f2253c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 3 deletions.
8 changes: 8 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -165,6 +166,7 @@
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
}

# Use to configure model sample size when original config is provided
Expand Down Expand Up @@ -633,6 +635,12 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["hunyuan-video"] in checkpoint:
model_type = "hunyuan-video"

elif (
CHECKPOINT_KEY_NAMES["instruct-pix2pix"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["instruct-pix2pix"]].shape[1] == 8
):
model_type = "instruct-pix2pix"

else:
model_type = "v1"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,23 @@

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging
from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker


# torch_xla is an optional dependency: record whether it is available, and only
# import the XLA model helpers (`xm`) when it is. Code that calls into `xm`
# must check `XLA_AVAILABLE` first, because `xm` is undefined otherwise.
XLA_AVAILABLE = bool(is_torch_xla_available())
if XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


Expand Down Expand Up @@ -79,6 +86,7 @@ class StableDiffusionInstructPix2PixPipeline(
TextualInversionLoaderMixin,
StableDiffusionLoraLoaderMixin,
IPAdapterMixin,
FromSingleFileMixin,
):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
Expand Down Expand Up @@ -457,6 +465,9 @@ def __call__(
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)

if XLA_AVAILABLE:
xm.mark_step()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
Expand Down
45 changes: 44 additions & 1 deletion tests/single_file/test_stable_diffusion_single_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

import torch

from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline
from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
nightly,
require_torch_accelerator,
slow,
torch_device,
Expand Down Expand Up @@ -118,3 +120,44 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0

def test_single_file_format_inference_is_same_as_pretrained(self):
super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)


@nightly
@slow
@require_torch_accelerator
class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
    """Single-file loading tests for `StableDiffusionInstructPix2PixPipeline`.

    The class attributes below name the pipeline class, the single-file
    checkpoint URL, the original config URL and the Hub repo id; the shared
    checks themselves come from `SDSingleFileTesterMixin`.
    """

    pipeline_class = StableDiffusionInstructPix2PixPipeline
    ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
    original_config = (
        "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml"
    )
    repo_id = "timbrooks/instruct-pix2pix"

    def setUp(self):
        """Run the parent setup, then collect garbage and empty the backend cache."""
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        """Run the parent teardown, then collect garbage and empty the backend cache."""
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        """Build the keyword arguments passed to the pipeline call.

        Args:
            device: Accepted as part of the signature; not read by this method.
            generator_device: Device on which the random generator is created.
            dtype: Accepted as part of the signature; not read by this method.
            seed: Seed applied to the random generator.

        Returns:
            A dict with the prompt, input image, seeded generator and sampling
            settings.
        """
        rng = torch.Generator(device=generator_device)
        rng.manual_seed(seed)
        source_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_pix2pix/example.jpg"
        )
        return {
            "prompt": "turn him into a cyborg",
            "image": source_image,
            "generator": rng,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "image_guidance_scale": 1.0,
            "output_type": "np",
        }

    def test_single_file_format_inference_is_same_as_pretrained(self):
        # Delegate to the mixin's comparison, passing an explicit tolerance of 1e-3.
        super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3)

0 comments on commit 8f2253c

Please sign in to comment.