LEditsPP - examples, check height/width, add tiling/slicing (#10471)
* LEditsPP - examples, check height/width, add tiling/slicing

* make style
hlky authored Jan 6, 2025
1 parent 6da6406 commit 2f25156
Showing 2 changed files with 95 additions and 19 deletions.
File 1 of 2: LEditsPPPipelineStableDiffusion pipeline
@@ -34,21 +34,19 @@
EXAMPLE_DOC_STRING = """
Examples:
```py
- >>> import PIL
- >>> import requests
  >>> import torch
- >>> from io import BytesIO
  >>> from diffusers import LEditsPPPipelineStableDiffusion
  >>> from diffusers.utils import load_image
  >>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
- ...     "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ...     "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
  ... )
+ >>> pipe.enable_vae_tiling()
  >>> pipe = pipe.to("cuda")
  >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/cherry_blossom.png"
- >>> image = load_image(img_url).convert("RGB")
+ >>> image = load_image(img_url).resize((512, 512))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
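The updated docstring example still stops at the inversion call. For orientation, a typical editing call that follows it would look roughly like the sketch below; it is not part of this diff, and the concept and values are illustrative (only `editing_prompt`, `edit_guidance_scale`, and `edit_threshold` are standard LEDITS++ call arguments).

```py
>>> # Sketch only (not in this commit): add a concept to the inverted image.
>>> edited_image = pipe(
...     editing_prompt=["cherry blossom"],
...     edit_guidance_scale=10.0,
...     edit_threshold=0.75,
... ).images[0]
```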
@@ -152,7 +150,7 @@ def __init__(self, device):

# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
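The only functional change in this hunk is the explicit `indexing="ij"`, which keeps `torch.meshgrid`'s long-standing row-major behaviour while silencing the deprecation warning recent PyTorch versions emit when the argument is omitted. A minimal standalone check (illustrative, not from the diff):

```py
import torch

# With indexing="ij", the first grid varies along dim 0 and the second along dim 1,
# which is what the separable Gaussian product above relies on.
ys, xs = torch.meshgrid(torch.arange(3.0), torch.arange(3.0), indexing="ij")
assert ys[1, 0] == 1.0 and xs[0, 1] == 1.0
```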
@@ -706,6 +704,35 @@ def clip_skip(self):
def cross_attention_kwargs(self):
return self._cross_attention_kwargs

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()

def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
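Both pipelines gain the same four toggles, which simply forward to the underlying VAE's `enable_slicing`/`enable_tiling` methods. A hedged usage sketch (model id and dtype are taken from the docstring example above; nothing here is new API beyond the methods added in this hunk):

```py
import torch
from diffusers import LEditsPPPipelineStableDiffusion

pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
    "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
).to("cuda")

pipe.enable_vae_tiling()   # decode/encode large images tile by tile
pipe.enable_vae_slicing()  # decode batched latents one slice at a time

# ... run inversion and editing here ...

pipe.disable_vae_tiling()  # back to single-pass decoding
pipe.disable_vae_slicing()
```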
@@ -1271,6 +1298,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
raise ValueError("height and width must be a factor of 32.")
# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())

@@ -1360,6 +1389,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
height, width = image.shape[-2:]
if height % 32 != 0 or width % 32 != 0:
raise ValueError(
"Image height and width must be a factor of 32. "
"Consider down-sampling the input using the `height` and `width` parameters"
)
resized = self.image_processor.postprocess(image=image, output_type="pil")

if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
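Taken together, the new checks in `invert` and `encode_image` mean the preprocessed image must have both sides divisible by 32, otherwise a `ValueError` is raised. A hedged caller-side sketch (the rounding is illustrative and not part of the diff; `pipe` and `img_url` are as in the docstring example):

```py
from diffusers.utils import load_image

image = load_image(img_url)
w, h = image.size
# Round both sides down to multiples of 32 before inversion.
image = image.resize((w // 32 * 32, h // 32 * 32))
_ = pipe.invert(image=image, num_inversion_steps=50, skip=0.1)
```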
File 2 of 2: LEditsPPPipelineStableDiffusionXL pipeline
@@ -72,25 +72,18 @@
Examples:
```py
>>> import torch
- >>> import PIL
- >>> import requests
- >>> from io import BytesIO
  >>> from diffusers import LEditsPPPipelineStableDiffusionXL
+ >>> from diffusers.utils import load_image
  >>> pipe = LEditsPPPipelineStableDiffusionXL.from_pretrained(
- ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ...     "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", torch_dtype=torch.float16
  ... )
+ >>> pipe.enable_vae_tiling()
  >>> pipe = pipe.to("cuda")
- >>> def download_image(url):
- ...     response = requests.get(url)
- ...     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
  >>> img_url = "https://www.aiml.informatik.tu-darmstadt.de/people/mbrack/tennis.jpg"
- >>> image = download_image(img_url)
+ >>> image = load_image(img_url).resize((1024, 1024))
>>> _ = pipe.invert(image=image, num_inversion_steps=50, skip=0.2)
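As in the Stable Diffusion example, the SDXL docstring stops at inversion. A typical follow-up edit for the tennis image would look roughly like this sketch; it is not part of the diff, and the concepts and per-concept values are illustrative (the parameter names are standard LEDITS++ arguments):

```py
>>> # Sketch only (not in this commit): swap one concept for another.
>>> edited_image = pipe(
...     editing_prompt=["tennis ball", "tomato"],
...     reverse_editing_direction=[True, False],
...     edit_guidance_scale=[5.0, 10.0],
...     edit_threshold=[0.9, 0.85],
... ).images[0]
```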
@@ -197,7 +190,7 @@ def __init__(self, device):

# The gaussian kernel is the product of the gaussian function of each dimension.
kernel = 1
- meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
+ meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size], indexing="ij")
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
mean = (size - 1) / 2
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
@@ -768,6 +761,35 @@ def denoising_end(self):
def num_timesteps(self):
return self._num_timesteps

def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()

def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()

def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()

def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()

# Copied from diffusers.pipelines.ledits_pp.pipeline_leditspp_stable_diffusion.LEditsPPPipelineStableDiffusion.prepare_unet
def prepare_unet(self, attention_store, PnP: bool = False):
attn_procs = {}
@@ -1401,6 +1423,12 @@ def encode_image(self, image, dtype=None, height=None, width=None, resize_mode="
image = self.image_processor.preprocess(
image=image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
height, width = image.shape[-2:]
if height % 32 != 0 or width % 32 != 0:
raise ValueError(
"Image height and width must be a factor of 32. "
"Consider down-sampling the input using the `height` and `width` parameters"
)
resized = self.image_processor.postprocess(image=image, output_type="pil")

if max(image.shape[-2:]) > self.vae.config["sample_size"] * 1.5:
@@ -1439,6 +1467,10 @@ def invert(
crops_coords_top_left: Tuple[int, int] = (0, 0),
num_zero_noise_steps: int = 3,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
resize_mode: Optional[str] = "default",
crops_coords: Optional[Tuple[int, int, int, int]] = None,
):
r"""
The function to the pipeline for image inversion as described by the [LEDITS++
@@ -1486,6 +1518,8 @@ def invert(
[`~pipelines.ledits_pp.LEditsPPInversionPipelineOutput`]: Output will contain the resized input image(s)
and respective VAE reconstruction(s).
"""
if height is not None and height % 32 != 0 or width is not None and width % 32 != 0:
raise ValueError("height and width must be a factor of 32.")

# Reset attn processor, we do not want to store attn maps during inversion
self.unet.set_attn_processor(AttnProcessor())
@@ -1510,7 +1544,14 @@ def invert(
do_classifier_free_guidance = source_guidance_scale > 1.0

# 1. prepare image
- x0, resized = self.encode_image(image, dtype=self.text_encoder_2.dtype)
+ x0, resized = self.encode_image(
+     image,
+     dtype=self.text_encoder_2.dtype,
+     height=height,
+     width=width,
+     resize_mode=resize_mode,
+     crops_coords=crops_coords,
+ )
width = x0.shape[2] * self.vae_scale_factor
height = x0.shape[3] * self.vae_scale_factor
self.size = (height, width)
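Putting the SDXL-side `invert` changes together: the new keyword arguments are validated and passed straight through to `encode_image`. A hedged call sketch (values are illustrative; `resize_mode` follows the usual `VaeImageProcessor` options):

```py
_ = pipe.invert(
    image=image,
    num_inversion_steps=50,
    skip=0.2,
    height=1024,           # must be divisible by 32, or a ValueError is raised
    width=1024,
    resize_mode="default",
    crops_coords=None,
)
```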
