From edb8c1bce67e81f0de90a7e4c16b2f6537d39f2d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sun, 12 Jan 2025 18:33:34 +0530 Subject: [PATCH] [Flux] Improve true cfg condition (#10539) * improve flux true cfg condition * add test --- src/diffusers/pipelines/flux/pipeline_flux.py | 5 ++++- tests/pipelines/flux/test_pipeline_flux.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 33154db54c73..f5716dc9c8ea 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -790,7 +790,10 @@ def __call__( lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) - do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt ( prompt_embeds, pooled_prompt_embeds, diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index ab36333c4056..addc29e14670 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -209,6 +209,17 @@ def test_flux_image_output_shape(self): output_height, output_width, _ = image.shape assert (output_height, output_width) == (expected_height, expected_width) + def test_flux_true_cfg(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + inputs.pop("generator") + + no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + inputs["negative_prompt"] = "bad quality" + inputs["true_cfg_scale"] = 2.0 + true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0] + assert not np.allclose(no_true_cfg_out, true_cfg_out) + @nightly @require_big_gpu_with_torch_cuda