diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 28a12f2d1364..bc7b54afa5f9 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -18,7 +18,7 @@
 import numpy as np
 import PIL.Image
 import torch
-from PIL import Image
+from PIL import Image, ImageFilter, ImageOps
 
 from .configuration_utils import ConfigMixin, register_to_config
 from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
@@ -157,6 +157,26 @@ def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
 
         return image
 
+    @staticmethod
+    def fill_mask(image: PIL.Image.Image, mask: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Fills the masked regions with blurred colors from the rest of the image. A simple heuristic; not extremely effective.
+        """
+
+        image_mod = Image.new("RGBA", (image.width, image.height))
+
+        image_masked = Image.new("RGBa", (image.width, image.height))
+        image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
+
+        image_masked = image_masked.convert("RGBa")
+
+        for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
+            blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert("RGBA")
+            for _ in range(repeats):
+                image_mod.alpha_composite(blurred)
+
+        return image_mod.convert("RGB")
+
     def get_default_height_width(
         self,
         image: [PIL.Image.Image, np.ndarray, torch.Tensor],
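For context, the new helper composites progressively smaller Gaussian blurs so colors from the unmasked area bleed into the masked region. A minimal standalone sketch, assuming the method lands on `VaeImageProcessor` (the class that also hosts `convert_to_grayscale`); the file names are placeholders:

```python
from PIL import Image

from diffusers.image_processor import VaeImageProcessor

# Hypothetical input paths, for illustration only.
image = Image.open("dog.png").convert("RGB")
mask = Image.open("dog_mask.png").convert("L")  # white = region to inpaint

# Blur-fill the masked region so the initial latents carry no sharp
# details from the original masked content.
filled = VaeImageProcessor.fill_mask(image, mask)
filled.save("dog_filled.png")
```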
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 2065343fe06c..a8d62f390a7d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -982,6 +982,7 @@ def __call__(
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
         clip_skip: Optional[int] = None,
+        masked_content: str = "original",  # original, blank, fill
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1077,6 +1078,13 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            masked_content (`str`, *optional*, defaults to `"original"`):
+                Determines how the masked content of the original image affects the generation process. Choose from
+                `"original"`, `"blank"`, or `"fill"`. If `"original"`, the entire image is used to create the initial
+                latents, so the masked content will influence the result. If `"blank"`, the masked image is used to
+                create the initial latents, so the masked content will have no influence on the result. If `"fill"`,
+                the masked regions are first filled with blurred colors from the image itself. This option only
+                applies when the inpainting pipeline is used with a text-to-image UNet model.
 
         Examples:
 
@@ -1196,6 +1204,8 @@ def __call__(
             assert False
 
         # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
+        if masked_content == "fill":
+            image = self.image_processor.fill_mask(image, mask_image)
         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)
 
@@ -1214,10 +1224,27 @@ def __call__(
         # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
         is_strength_max = strength == 1.0
 
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_unet = self.unet.config.in_channels
         return_image_latents = num_channels_unet == 4
+
+        if num_channels_unet == 4 and masked_content == "blank":
+            init_image = masked_image_latents.chunk(2)[0]
+
         latents_outputs = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -1239,19 +1266,6 @@ def __call__(
         else:
             latents, noise = latents_outputs
 
-        # 7. Prepare mask latent variables
-        mask, masked_image_latents = self.prepare_mask_latents(
-            mask,
-            masked_image,
-            batch_size * num_images_per_prompt,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            do_classifier_free_guidance,
-        )
-
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
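A usage sketch for the new argument on this pipeline. The checkpoints and image URLs are illustrative, and the control image preparation is deliberately simplified (a real inpaint ControlNet usually expects a dedicated conditioning image):

```python
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.utils import load_image

# Illustrative checkpoints; any SD 1.5-compatible pair works.
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint")
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)

image = load_image("https://example.com/dog.png")  # placeholder URL
mask_image = load_image("https://example.com/dog_mask.png")  # placeholder URL

# "fill" blur-fills the masked area before encoding, so the original
# masked pixels do not bleed into the result.
result = pipe(
    prompt="a red panda sitting on a bench",
    image=image,
    mask_image=mask_image,
    control_image=image,  # simplified; normally a prepared inpaint condition
    masked_content="fill",
).images[0]
```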
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
index 118fc0230e46..40c95333b898 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py
@@ -976,6 +976,7 @@ def __call__(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
+        masked_content: str = "original",  # original, blank
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1105,6 +1106,13 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            masked_content (`str`, *optional*, defaults to `"original"`):
+                Determines how the masked content of the original image affects the generation process. Choose from
+                `"original"` or `"blank"`. If `"original"`, the entire image is used to create the initial latents,
+                so the masked content will influence the result. If `"blank"`, the masked image is used to create
+                the initial latents, so the masked content will have no influence on the result. This option only
+                applies when the inpainting pipeline is used with a text-to-image UNet model, i.e. when the UNet
+                expects 4 latent channels rather than 9.
 
         Examples:
 
@@ -1268,10 +1276,25 @@ def denoising_value_valid(dnv):
             masked_image = init_image * (mask < 0.5)
             _, _, height, width = init_image.shape
 
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_unet = self.unet.config.in_channels
         return_image_latents = num_channels_unet == 4
+        if num_channels_unet == 4 and masked_content == "blank":
+            init_image = masked_image_latents.chunk(2)[0]
 
         add_noise = True if denoising_start is None else False
         latents_outputs = self.prepare_latents(
@@ -1296,19 +1319,6 @@ def denoising_value_valid(dnv):
         else:
             latents, noise = latents_outputs
 
-        # 7. Prepare mask latent variables
-        mask, masked_image_latents = self.prepare_mask_latents(
-            mask,
-            masked_image,
-            batch_size * num_images_per_prompt,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            do_classifier_free_guidance,
-        )
-
         # 8. Check that sizes of mask, masked image and latents match
         if num_channels_unet == 9:
             # default case for runwayml/stable-diffusion-inpainting
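A note on the `masked_image_latents.chunk(2)[0]` line that appears in each pipeline: `prepare_mask_latents` duplicates the conditioning tensors along the batch dimension when classifier-free guidance is enabled, so taking the first chunk recovers the single-batch copy. A toy illustration of the shape arithmetic:

```python
import torch

batch_size = 2
latents = torch.randn(batch_size, 4, 64, 64)

# With classifier-free guidance the conditioning tensors are duplicated
# along the batch dimension (one copy per guidance branch).
duplicated = torch.cat([latents] * 2)
assert duplicated.shape == (2 * batch_size, 4, 64, 64)

# chunk(2)[0] undoes the duplication and recovers the original batch,
# which is what the "blank" branch feeds to prepare_latents as init_image.
recovered = duplicated.chunk(2)[0]
assert torch.equal(recovered, latents)
```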
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 01000d8f37c9..fe5f708d605e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -738,6 +738,7 @@ def __call__(
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: int = None,
+        masked_content: str = "original",  # original, blank
     ):
         r"""
         The call function to the pipeline for generation.
@@ -813,6 +814,14 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            masked_content (`str`, *optional*, defaults to `"original"`):
+                Determines how the masked content of the original image affects the generation process. Choose from
+                `"original"` or `"blank"`. If `"original"`, the entire image is used to create the initial latents,
+                so the masked content will influence the result. If `"blank"`, the masked image is used to create
+                the initial latents, so the masked content will have no influence on the result. This option only
+                applies when the inpainting pipeline is used with a text-to-image UNet model, i.e. when the UNet
+                expects 4 latent channels rather than 9.
+
         Examples:
 
         ```py
@@ -923,11 +932,34 @@ def __call__(
         init_image = self.image_processor.preprocess(image, height=height, width=width)
         init_image = init_image.to(dtype=torch.float32)
 
+        # 7. Prepare mask latent variables
+        mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        if masked_image_latents is None:
+            masked_image = init_image * (mask_condition < 0.5)
+        else:
+            masked_image = masked_image_latents
+
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask_condition,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_unet = self.unet.config.in_channels
         return_image_latents = num_channels_unet == 4
 
+        if num_channels_unet == 4 and masked_content == "blank":
+            init_image = masked_image_latents.chunk(2)[0]
+
         latents_outputs = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
@@ -949,26 +981,6 @@ def __call__(
         else:
             latents, noise = latents_outputs
 
-        # 7. Prepare mask latent variables
-        mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width)
-
-        if masked_image_latents is None:
-            masked_image = init_image * (mask_condition < 0.5)
-        else:
-            masked_image = masked_image_latents
-
-        mask, masked_image_latents = self.prepare_mask_latents(
-            mask_condition,
-            masked_image,
-            batch_size * num_images_per_prompt,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            do_classifier_free_guidance,
-        )
-
         # 8. Check that sizes of mask, masked image and latents match
         if num_channels_unet == 9:
             # default case for runwayml/stable-diffusion-inpainting
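And a sketch of the `"blank"` mode on the plain inpainting pipeline. The checkpoint and inputs are illustrative; note the option only takes effect when a 4-channel text-to-image UNet is loaded, not the 9-channel inpainting UNet:

```python
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

# A text-to-image checkpoint (4-channel UNet), where masked_content applies.
# Assumes a CUDA device is available.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = load_image("https://example.com/room.png")  # placeholder URL
mask_image = load_image("https://example.com/room_mask.png")  # placeholder URL

# "blank": the initial latents come from the masked image, so the original
# masked content cannot steer the generation.
result = pipe(
    prompt="a modern armchair",
    image=image,
    mask_image=mask_image,
    strength=1.0,
    masked_content="blank",
).images[0]
```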
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index fc942ca5227b..d7440fec8025 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -916,6 +916,7 @@ def __call__(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
+        masked_content: str = "original",  # original, blank
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1066,6 +1067,13 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            masked_content (`str`, *optional*, defaults to `"original"`):
+                Determines how the masked content of the original image affects the generation process. Choose from
+                `"original"` or `"blank"`. If `"original"`, the entire image is used to create the initial latents,
+                so the masked content will influence the result. If `"blank"`, the masked image is used to create
+                the initial latents, so the masked content will have no influence on the result. This option only
+                applies when the inpainting pipeline is used with a text-to-image UNet model, i.e. when the UNet
+                expects 4 latent channels rather than 9.
 
         Examples:
 
@@ -1165,10 +1173,25 @@ def denoising_value_valid(dnv):
         else:
             masked_image = init_image * (mask < 0.5)
 
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_unet = self.unet.config.in_channels
         return_image_latents = num_channels_unet == 4
+        if num_channels_unet == 4 and masked_content == "blank":
+            init_image = masked_image_latents.chunk(2)[0]
 
         add_noise = True if denoising_start is None else False
         latents_outputs = self.prepare_latents(
@@ -1193,19 +1216,6 @@ def denoising_value_valid(dnv):
         else:
             latents, noise = latents_outputs
 
-        # 7. Prepare mask latent variables
-        mask, masked_image_latents = self.prepare_mask_latents(
-            mask,
-            masked_image,
-            batch_size * num_images_per_prompt,
-            height,
-            width,
-            prompt_embeds.dtype,
-            device,
-            generator,
-            do_classifier_free_guidance,
-        )
-
         # 8. Check that sizes of mask, masked image and latents match
         if num_channels_unet == 9:
             # default case for runwayml/stable-diffusion-inpainting
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
index 0ac8996fe0ef..a5d79a166fa0 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py
@@ -290,7 +290,7 @@ def test_controlnet_sdxl_guess(self):
         output = sd_pipe(**inputs)
         image_slice = output.images[0, -3:, -3:, -1]
         expected_slice = np.array(
-            [0.5381963, 0.4836803, 0.45821992, 0.5577731, 0.51210403, 0.4794795, 0.59282357, 0.5647199, 0.43100584]
+            [0.5381064, 0.39685374, 0.4803775, 0.6868999, 0.5960682, 0.58983135, 0.59488124, 0.5560177, 0.5095338]
         )
 
         # make sure that it's equal
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index e485bc9123b0..a2a579d859a1 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -197,7 +197,9 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.4703, 0.5697, 0.3879, 0.5470, 0.6042, 0.4413, 0.5078, 0.4728, 0.4469])
+        expected_slice = np.array(
+            [0.7116542, 0.5356421, 0.5738175, 0.5997549, 0.6192689, 0.6549013, 0.60130644, 0.57260084, 0.5453409]
+        )
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -257,10 +259,12 @@ def test_stable_diffusion_inpaint_mask_latents(self):
         masked_image = image * (mask < 0.5)
 
         generator = torch.Generator(device=device).manual_seed(0)
+        torch.randn((1, 4, 32, 32), generator=generator)
         image_latents = (
             sd_pipe.vae.encode(image).latent_dist.sample(generator=generator) * sd_pipe.vae.config.scaling_factor
         )
-        torch.randn((1, 4, 32, 32), generator=generator)
+
+        generator = torch.Generator(device=device).manual_seed(0)
         mask_latents = (
             sd_pipe.vae.encode(masked_image).latent_dist.sample(generator=generator)
             * sd_pipe.vae.config.scaling_factor
@@ -271,6 +275,7 @@ def test_stable_diffusion_inpaint_mask_latents(self):
         inputs["strength"] = 0.9
         generator = torch.Generator(device=device).manual_seed(0)
         torch.randn((1, 4, 32, 32), generator=generator)
+        torch.randn((1, 4, 32, 32), generator=generator)
         inputs["generator"] = generator
         out_1 = sd_pipe(**inputs).images
         assert np.abs(out_0 - out_1).max() < 1e-2
@@ -372,7 +377,9 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.6584, 0.5424, 0.5649, 0.5449, 0.5897, 0.6111, 0.5404, 0.5463, 0.5214])
+        expected_slice = np.array(
+            [0.41372567, 0.41823727, 0.40892008, 0.5120787, 0.456928, 0.514932, 0.4214989, 0.45367384, 0.48013452]
+        )
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
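The test churn above is about generator state: since the pipelines now sample the mask latents before the image latents, the tests mirror that consumption order by "burning" a `torch.randn` draw before encoding, and re-seed before the second encode. A small sketch of the principle:

```python
import torch

# Two generators with the same seed produce the same stream...
g1 = torch.Generator().manual_seed(0)
g2 = torch.Generator().manual_seed(0)

# ...but consuming a draw advances the state, so later samples differ.
_ = torch.randn((1, 4, 32, 32), generator=g1)  # burn one draw on g1 only
a = torch.randn((1, 4, 32, 32), generator=g1)
b = torch.randn((1, 4, 32, 32), generator=g2)
assert not torch.equal(a, b)

# Re-seeding restores the original stream, which is why the tests call
# manual_seed(0) again before encoding the masked image.
g1 = torch.Generator().manual_seed(0)
assert torch.equal(torch.randn((1, 4, 32, 32), generator=g1), b)
```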
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
index 1e726b95960f..f08139502e12 100644
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py
@@ -139,8 +139,9 @@ def test_stable_diffusion_inpaint(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
-        expected_slice = np.array([0.4727, 0.5735, 0.3941, 0.5446, 0.5926, 0.4394, 0.5062, 0.4654, 0.4476])
-
+        expected_slice = np.array(
+            [0.71828955, 0.5354332, 0.58167696, 0.5995904, 0.6197585, 0.6600398, 0.6049546, 0.5725968, 0.54803497]
+        )
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_inference_batch_single_identical(self):
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index 7e3698d8ca16..7390319f94c1 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -193,7 +193,9 @@ def test_stable_diffusion_xl_inpaint_euler(self):
 
         assert image.shape == (1, 64, 64, 3)
 
-        expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])
+        expected_slice = np.array(
+            [0.47297692, 0.43896025, 0.40697068, 0.45582378, 0.4125126, 0.5327985, 0.38261384, 0.47744697, 0.47403562]
+        )
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -290,7 +292,9 @@ def test_stable_diffusion_xl_refiner(self):
 
         assert image.shape == (1, 64, 64, 3)
 
-        expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])
+        expected_slice = np.array(
+            [0.51198494, 0.49192587, 0.414692, 0.4513318, 0.44658783, 0.5387383, 0.39105427, 0.49726686, 0.46108854]
+        )
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
@@ -547,8 +551,10 @@ def test_stable_diffusion_xl_inpaint_mask_latents(self):
         masked_image = image * (mask < 0.5)
 
         generator = torch.Generator(device=device).manual_seed(0)
-        image_latents = sd_pipe._encode_vae_image(image, generator=generator)
         torch.randn((1, 4, 32, 32), generator=generator)
+        image_latents = sd_pipe._encode_vae_image(image, generator=generator)
+
+        generator = torch.Generator(device=device).manual_seed(0)
         mask_latents = sd_pipe._encode_vae_image(masked_image, generator=generator)
         inputs["image"] = image_latents
         inputs["masked_image_latents"] = mask_latents
@@ -556,6 +562,7 @@ def test_stable_diffusion_xl_inpaint_mask_latents(self):
         inputs["strength"] = 0.9
         generator = torch.Generator(device=device).manual_seed(0)
         torch.randn((1, 4, 32, 32), generator=generator)
+        torch.randn((1, 4, 32, 32), generator=generator)
         inputs["generator"] = generator
         out_1 = sd_pipe(**inputs).images
         assert np.abs(out_0 - out_1).max() < 1e-2