Skip to content

Commit

Permalink
[SDXL Refiner] Fix refiner forward pass for batched input (#4327)
Browse files Browse the repository at this point in the history
* fix_batch_xl

* Fix other pipelines as well

* up

* up

* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py

* sort

* up

* Finish it all up Co-authored-by: Bagheera <[email protected]>

* Co-authored-by: Bagheera [email protected]

* Co-authored-by: Bagheera <[email protected]>

* Finish it all up Co-authored-by: Bagheera <[email protected]>
  • Loading branch information
patrickvonplaten authored and sayakpaul committed Jul 28, 2023
1 parent aa4634a commit c63d7cd
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -906,15 +906,17 @@ def denoising_value_valid(dnv):
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = add_time_ids.to(device)

# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1168,15 +1168,17 @@ def denoising_value_valid(dnv):
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = add_time_ids.to(device)

# 11. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def __call__(
negative_aesthetic_score,
dtype=prompt_embeds.dtype,
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)

original_prompt_embeds_len = len(prompt_embeds)
original_add_text_embeds_len = len(add_text_embeds)
Expand All @@ -819,6 +820,7 @@ def __call__(
if do_classifier_free_guidance:
prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)

# Make dimensions consistent
Expand All @@ -828,7 +830,7 @@ def __call__(

prompt_embeds = prompt_embeds.to(device).to(torch.float32)
add_text_embeds = add_text_embeds.to(device).to(torch.float32)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
add_time_ids = add_time_ids.to(device)

# 11. Denoising loop
self.unet = self.unet.to(torch.float32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32,
)
scheduler = EulerDiscreteScheduler(
Expand Down Expand Up @@ -113,9 +113,18 @@ def get_dummy_components(self, skip_first_text_encoder=False):
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
}
return components

def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)

self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))

def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image / 2 + 0.5
Expand Down Expand Up @@ -147,7 +156,7 @@ def test_stable_diffusion_xl_img2img_euler(self):

assert image.shape == (1, 32, 32, 3)

expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165])
expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

Expand All @@ -165,7 +174,7 @@ def test_stable_diffusion_xl_refiner(self):

assert image.shape == (1, 32, 32, 3)

expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198])
expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64 if not skip_first_text_encoder else 32,
)
scheduler = EulerDiscreteScheduler(
Expand Down Expand Up @@ -115,6 +115,7 @@ def get_dummy_components(self, skip_first_text_encoder=False):
"tokenizer": tokenizer if not skip_first_text_encoder else None,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"requires_aesthetics_score": True,
}
return components

Expand Down Expand Up @@ -142,6 +143,14 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs

def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)

self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))

def test_stable_diffusion_xl_inpaint_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
Expand All @@ -155,7 +164,7 @@ def test_stable_diffusion_xl_inpaint_euler(self):

assert image.shape == (1, 64, 64, 3)

expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319])
expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

Expand Down Expand Up @@ -250,10 +259,9 @@ def test_stable_diffusion_xl_refiner(self):
image = sd_pipe(**inputs).images
image_slice = image[0, -3:, -3:, -1]

print(torch.from_numpy(image_slice).flatten())
assert image.shape == (1, 64, 64, 3)

expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418])
expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922])

assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_dummy_components(self):
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
projection_class_embeddings_input_dim=72, # 5 * 8 + 32
cross_attention_dim=64,
)

Expand Down Expand Up @@ -118,8 +118,7 @@ def get_dummy_components(self):
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
# "safety_checker": None,
# "feature_extractor": None,
"requires_aesthetics_score": True,
}
return components

Expand All @@ -141,6 +140,14 @@ def get_dummy_inputs(self, device, seed=0):
}
return inputs

def test_components_function(self):
init_components = self.get_dummy_components()
init_components.pop("requires_aesthetics_score")
pipe = self.pipeline_class(**init_components)

self.assertTrue(hasattr(pipe, "components"))
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))

def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

Expand Down

0 comments on commit c63d7cd

Please sign in to comment.