ControlNet Inpainting: use masked_image to create initial latents #5498

Closed
wants to merge 16 commits
22 changes: 21 additions & 1 deletion src/diffusers/image_processor.py
@@ -18,7 +18,7 @@
import numpy as np
import PIL.Image
import torch
-from PIL import Image
+from PIL import Image, ImageFilter, ImageOps

from .configuration_utils import ConfigMixin, register_to_config
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
@@ -157,6 +157,26 @@ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:

return image

@staticmethod
def fill_mask(image: PIL.Image.Image, mask: PIL.Image.Image) -> PIL.Image.Image:
"""
fills masked regions with colors from image using blur. Not extremely effective.
"""

image_mod = Image.new("RGBA", (image.width, image.height))

image_masked = Image.new("RGBa", (image.width, image.height))
image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))

image_masked = image_masked.convert("RGBa")

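# Composite progressively sharper blurs: large radii bleed color into big
# masked areas, small radii restore detail near the mask edges.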
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert("RGBA")
for _ in range(repeats):
image_mod.alpha_composite(blurred)

return image_mod.convert("RGB")

def get_default_height_width(
self,
image: [PIL.Image.Image, np.ndarray, torch.Tensor],
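For reference, here is a minimal sketch of exercising the new helper on its own; it assumes the method lands on `VaeImageProcessor` (the class that holds `convert_to_grayscale`) and that `input.png` and `mask.png` exist locally:

```python
from PIL import Image
from diffusers.image_processor import VaeImageProcessor

image = Image.open("input.png").convert("RGB")
mask = Image.open("mask.png").convert("L")  # white marks the region to repaint

# Masked pixels are replaced with progressively blurred colors pulled from the
# unmasked area, giving the sampler a neutral, low-frequency starting point.
filled = VaeImageProcessor.fill_mask(image, mask)
filled.save("filled.png")
```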
40 changes: 27 additions & 13 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -982,6 +982,7 @@ def __call__(
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
masked_content: str = "original", # original, blank, fill
):
r"""
The call function to the pipeline for generation.
@@ -1077,6 +1078,13 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
masked_content (`str`, *optional*, defaults to `"original"`):
This option determines how the masked content of the original image affects the generation
process. Choose from `"original"`, `"blank"`, or `"fill"`. If `"original"`, the entire image is
used to create the initial latents, so the masked content influences the result. If `"blank"`,
the masked image is used to create the initial latents, so the masked content has no influence
on the result. If `"fill"`, the masked region is first filled with blurred colors drawn from the
rest of the image. This option only applies when the inpainting pipeline is used with a
text-to-image UNet model.

Examples:

@@ -1196,6 +1204,8 @@ def __call__(
assert False

# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
if masked_content == "fill":
image = self.image_processor.fill_mask(image, mask_image)
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

@@ -1214,10 +1224,27 @@
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0

# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4

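# NOTE: prepare_mask_latents duplicates masked_image_latents for classifier-free
# guidance, so chunk(2)[0] recovers a single copy to seed the initial latents of
# a 4-channel (text-to-image) UNet.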
if num_channels_unet == 4 and masked_content == "blank":
init_image = masked_image_latents.chunk(2)[0]

latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -1239,19 +1266,6 @@
else:
latents, noise = latents_outputs

# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -976,6 +976,7 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
masked_content="original", # original, blank
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1105,6 +1106,13 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
masked_content (`str`, *optional*, defaults to `"original"`):
This option determines how the masked content of the original image affects the generation
process. Choose from `"original"` or `"blank"`. If `"original"`, the entire image is used to
create the initial latents, so the masked content influences the result. If `"blank"`, the
masked image is used to create the initial latents, so the masked content has no influence on
the result. This option only applies when the inpainting pipeline is used with a text-to-image
UNet model.

Examples:

@@ -1268,10 +1276,25 @@ def denoising_value_valid(dnv):
masked_image = init_image * (mask < 0.5)
_, _, height, width = init_image.shape

# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
if num_channels_unet == 4 and masked_content == "blank":
init_image = masked_image_latents.chunk(2)[0]

add_noise = True if denoising_start is None else False
latents_outputs = self.prepare_latents(
@@ -1296,19 +1319,6 @@
else:
latents, noise = latents_outputs

# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
@@ -738,6 +738,7 @@ def __call__(
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
masked_content="original", # original, blank
):
r"""
The call function to the pipeline for generation.
@@ -813,6 +814,14 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
masked_content (`str`, *optional*, defaults to `"original"`):
This option determines how the masked content of the original image affects the generation
process. Choose from `"original"` or `"blank"`. If `"original"`, the entire image is used to
create the initial latents, so the masked content influences the result. If `"blank"`, the
masked image is used to create the initial latents, so the masked content has no influence on
the result. This option only applies when the inpainting pipeline is used with a text-to-image
UNet model.

Examples:

@@ -923,11 +932,34 @@ def __call__(
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

# 7. Prepare mask latent variables
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)

if masked_image_latents is None:
masked_image = init_image * (mask_condition < 0.5)
else:
masked_image = masked_image_latents

mask, masked_image_latents = self.prepare_mask_latents(
mask_condition,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

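For orientation, `prepare_mask_latents` roughly does the following before the denoising loop; this is a simplified sketch of its behavior, not the literal implementation:

```python
# Downsample the binary mask to latent resolution.
mask = torch.nn.functional.interpolate(
    mask_condition,
    size=(height // self.vae_scale_factor, width // self.vae_scale_factor),
)
# Encode the masked image into the VAE latent space.
masked_image_latents = (
    self.vae.encode(masked_image).latent_dist.sample(generator=generator)
    * self.vae.config.scaling_factor
)
# Duplicate both so the conditional and unconditional branches share them
# under classifier-free guidance.
if do_classifier_free_guidance:
    mask = torch.cat([mask] * 2)
    masked_image_latents = torch.cat([masked_image_latents] * 2)
```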
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4

if num_channels_unet == 4 and masked_content == "blank":
init_image = masked_image_latents.chunk(2)[0]

latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
@@ -949,26 +981,6 @@
else:
latents, noise = latents_outputs

# 7. Prepare mask latent variables
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)

if masked_image_latents is None:
masked_image = init_image * (mask_condition < 0.5)
else:
masked_image = masked_image_latents

mask, masked_image_latents = self.prepare_mask_latents(
mask_condition,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
@@ -916,6 +916,7 @@ def __call__(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
masked_content="original", # original, blank
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1066,6 +1067,13 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
masked_content (`str`, *optional*, defaults to `"original"`):
This option determines how the masked content of the original image affects the generation
process. Choose from `"original"` or `"blank"`. If `"original"`, the entire image is used to
create the initial latents, so the masked content influences the result. If `"blank"`, the
masked image is used to create the initial latents, so the masked content has no influence on
the result. This option only applies when the inpainting pipeline is used with a text-to-image
UNet model.

Examples:

@@ -1165,10 +1173,25 @@ def denoising_value_valid(dnv):
else:
masked_image = init_image * (mask < 0.5)

# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
if num_channels_unet == 4 and masked_content == "blank":
init_image = masked_image_latents.chunk(2)[0]

add_noise = True if denoising_start is None else False
latents_outputs = self.prepare_latents(
@@ -1193,19 +1216,6 @@
else:
latents, noise = latents_outputs

# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)

# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
# default case for runwayml/stable-diffusion-inpainting
2 changes: 1 addition & 1 deletion tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -290,7 +290,7 @@ def test_controlnet_sdxl_guess(self):
output = sd_pipe(**inputs)
image_slice = output.images[0, -3:, -3:, -1]
expected_slice = np.array(
-[0.5381963, 0.4836803, 0.45821992, 0.5577731, 0.51210403, 0.4794795, 0.59282357, 0.5647199, 0.43100584]
+[0.5381064, 0.39685374, 0.4803775, 0.6868999, 0.5960682, 0.58983135, 0.59488124, 0.5560177, 0.5095338]
)

# make sure that it's equal
13 changes: 10 additions & 3 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -197,7 +197,9 @@ def test_stable_diffusion_inpaint(self):
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)
-expected_slice = np.array([0.4703, 0.5697, 0.3879, 0.5470, 0.6042, 0.4413, 0.5078, 0.4728, 0.4469])
+expected_slice = np.array(
+    [0.7116542, 0.5356421, 0.5738175, 0.5997549, 0.6192689, 0.6549013, 0.60130644, 0.57260084, 0.5453409]
+)

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

@@ -257,10 +259,12 @@ def test_stable_diffusion_inpaint_mask_latents(self):
masked_image = image * (mask < 0.5)

generator = torch.Generator(device=device).manual_seed(0)
torch.randn((1, 4, 32, 32), generator=generator)
image_latents = (
sd_pipe.vae.encode(image).latent_dist.sample(generator=generator) * sd_pipe.vae.config.scaling_factor
)
torch.randn((1, 4, 32, 32), generator=generator)

generator = torch.Generator(device=device).manual_seed(0)
mask_latents = (
sd_pipe.vae.encode(masked_image).latent_dist.sample(generator=generator)
* sd_pipe.vae.config.scaling_factor
Expand All @@ -271,6 +275,7 @@ def test_stable_diffusion_inpaint_mask_latents(self):
inputs["strength"] = 0.9
generator = torch.Generator(device=device).manual_seed(0)
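# The bare torch.randn calls below advance the generator so the manually
# prepared latents consume RNG in the same order as the pipeline, which now
# draws the mask latents before the initial noise.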
torch.randn((1, 4, 32, 32), generator=generator)
torch.randn((1, 4, 32, 32), generator=generator)
inputs["generator"] = generator
out_1 = sd_pipe(**inputs).images
assert np.abs(out_0 - out_1).max() < 1e-2
@@ -372,7 +377,9 @@ def test_stable_diffusion_inpaint(self):
image_slice = image[0, -3:, -3:, -1]

assert image.shape == (1, 64, 64, 3)
-expected_slice = np.array([0.6584, 0.5424, 0.5649, 0.5449, 0.5897, 0.6111, 0.5404, 0.5463, 0.5214])
+expected_slice = np.array(
+    [0.41372567, 0.41823727, 0.40892008, 0.5120787, 0.456928, 0.514932, 0.4214989, 0.45367384, 0.48013452]
+)

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
