diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py index 8ff340cdb7ca..3ae53a101707 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py @@ -225,9 +225,10 @@ def __init__( # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible # by the patch size. So the vae scale factor is multiplied by the patch size to account for this self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16 self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor * 2, - vae_latent_channels=self.vae.config.latent_channels, + vae_latent_channels=latent_channels, do_normalize=False, do_binarize=True, do_convert_grayscale=True, @@ -656,7 +657,7 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + # Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents def prepare_latents( self, image, @@ -670,20 +671,24 @@ def prepare_latents( generator, latents=None, ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. 
height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, num_channels_latents, height, width) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + if latents is not None: + return latents.to(device=device, dtype=dtype), latent_image_ids - # if latents is not None: image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) - - latent_image_ids = self._prepare_latent_image_ids( - batch_size, height // 2, width // 2, device, dtype - ) if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // image_latents.shape[0] @@ -695,19 +700,10 @@ def prepare_latents( else: image_latents = torch.cat([image_latents], dim=0) - if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.scale_noise(image_latents, timestep, noise) - else: - noise = latents.to(device) - latents = noise - - noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width) - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, height, width - ) + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, noise, image_latents, latent_image_ids + return latents, latent_image_ids @property def guidance_scale(self): @@ -866,7 +862,6 @@ def __call__( self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False - original_image = image init_image = self.image_processor.preprocess(image, height=height, width=width) init_image = init_image.to(dtype=torch.float32) @@ -935,7 +930,7 @@ def 
__call__( # 5. Prepare latent variables num_channels_latents = self.vae.config.latent_channels - latents, noise, image_latents, latent_image_ids = self.prepare_latents( + latents, latent_image_ids = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt,