From b2c9b5469a06c577651811d2dcffabcd10b256ac Mon Sep 17 00:00:00 2001
From: Suraj Patil
Date: Fri, 7 Oct 2022 17:01:51 +0200
Subject: [PATCH] [img2img, inpainting] fix fp16 inference (#769)

* handle dtype in vae and image2image pipeline

* fix inpaint in fp16

* dtype should be handled in add_noise

* style

* address review comments

* add simple fast tests to check fp16

* fix test name

* put mask in fp16
---
 src/diffusers/models/vae.py                   |   8 +-
 .../pipeline_stable_diffusion_img2img.py      |  46 +++----
 .../pipeline_stable_diffusion_inpaint.py      |  74 +++++------
 tests/test_pipelines.py                       | 118 ++++++++++++++++++
 4 files changed, 186 insertions(+), 60 deletions(-)

diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index fe89b41c074e..7ce2f98eee27 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False):
         self.std = torch.exp(0.5 * self.logvar)
         self.var = torch.exp(self.logvar)
         if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )
 
     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         device = self.parameters.device
         sample_device = "cpu" if device.type == "mps" else device
-        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = sample.to(device=device, dtype=self.parameters.dtype)
         x = self.mean + self.std * sample
         return x
 
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 15bdd0208825..72e15f4f904b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -217,26 +217,6 @@ def __call__(
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)
 
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
-
-        # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -297,6 +277,28 @@ def __call__(
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # expand init_latents for batch_size
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -341,7 +343,9 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+        )
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 24f4bc99bddc..30a588e754b3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -234,43 +234,6 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)
 
-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-        init_image = init_image.to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -335,6 +298,43 @@ def __call__(
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 567699986eb3..c18ebe6754c9 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1005,6 +1005,124 @@ def test_stable_diffusion_inpaint_num_images_per_prompt(self):
 
         assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
 
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_fp16(self):
+        """Test that stable diffusion works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
+
+        assert image.shape == (1, 128, 128, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_img2img_fp16(self):
+        """Test that stable diffusion img2img works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(torch_device)
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_inpaint_fp16(self):
+        """Test that stable diffusion inpaint works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
 
 class PipelineTesterMixin(unittest.TestCase):
     def tearDown(self):
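
Usage sketch: with this fix applied, the img2img and inpainting pipelines can run end to end in half precision roughly as below. This is not part of the patch; the model id, the revision="fp16" weight branch, the image paths, and the CUDA device are assumptions chosen for illustration, and gated checkpoints may additionally require use_auth_token=True when loading.

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline

device = "cuda"

# img2img in half precision: the pipeline now casts the init image and the added
# noise to the text-embedding dtype, so loading the weights in fp16 is enough.
img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to(device)
init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))  # placeholder path
image = img2img(prompt="a fantasy landscape", init_image=init_image, strength=0.75).images[0]

# inpainting in half precision: the mask is likewise moved to the latents dtype internally.
inpaint = StableDiffusionInpaintPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to(device)
mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))  # placeholder path
image = inpaint(prompt="a cat on a park bench", init_image=init_image, mask_image=mask_image).images[0]
image.save("out.png")

The point the diff encodes is that both pipelines derive latents_dtype from text_embeddings.dtype and cast init_image, the sampled noise, and (for inpainting) the mask to it, so callers only need to put the models themselves in half precision.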