Skip to content

Commit

Permalink
[Flux] Improve true cfg condition (#10539)
Browse files Browse the repository at this point in the history
* improve flux true cfg condition

* add test
  • Loading branch information
sayakpaul authored Jan 12, 2025
1 parent 0785dba commit edb8c1b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,10 @@ def __call__(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
Expand Down
11 changes: 11 additions & 0 deletions tests/pipelines/flux/test_pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ def test_flux_image_output_shape(self):
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)

def test_flux_true_cfg(self):
    """Enabling true CFG (negative prompt + true_cfg_scale > 1) must alter the output image."""
    components = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(torch_device)
    inputs = self.get_dummy_inputs(torch_device)
    # Drop the fixture generator so each call below can re-seed identically.
    inputs.pop("generator")

    baseline_image = pipe(**inputs, generator=torch.manual_seed(0)).images[0]

    # Same seed, but now with the true-CFG branch active.
    inputs.update({"negative_prompt": "bad quality", "true_cfg_scale": 2.0})
    cfg_image = pipe(**inputs, generator=torch.manual_seed(0)).images[0]

    assert not np.allclose(baseline_image, cfg_image)


@nightly
@require_big_gpu_with_torch_cuda
Expand Down

0 comments on commit edb8c1b

Please sign in to comment.