Skip to content

Commit

Permalink
[refactor] refactor after review
Browse files Browse the repository at this point in the history
  • Loading branch information
Suprhimp committed Jan 17, 2025
1 parent 25fa97c commit 5d6b78c
Showing 1 changed file with 17 additions and 22 deletions.
39 changes: 17 additions & 22 deletions src/diffusers/pipelines/flux/pipeline_flux_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,10 @@ def __init__(
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=self.vae.config.latent_channels,
vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
Expand Down Expand Up @@ -656,7 +657,7 @@ def disable_vae_tiling(self):
"""
self.vae.disable_tiling()

# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
image,
Expand All @@ -670,20 +671,24 @@ def prepare_latents(
generator,
latents=None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))

shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids

# if latents is not None:
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)

latent_image_ids = self._prepare_latent_image_ids(
batch_size, height // 2, width // 2, device, dtype
)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
Expand All @@ -695,19 +700,10 @@ def prepare_latents(
else:
image_latents = torch.cat([image_latents], dim=0)

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise

noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
image_latents = self._pack_latents(
image_latents, batch_size, num_channels_latents, height, width
)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, noise, image_latents, latent_image_ids
return latents, latent_image_ids

@property
def guidance_scale(self):
Expand Down Expand Up @@ -866,7 +862,6 @@ def __call__(
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False

original_image = image
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)

Expand Down Expand Up @@ -935,7 +930,7 @@ def __call__(

# 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, noise, image_latents, latent_image_ids = self.prepare_latents(
latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
Expand Down

0 comments on commit 5d6b78c

Please sign in to comment.