Skip to content

Commit

Permalink
Add FluxDisableGuidance node to disable using the guidance embed.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Jan 20, 2025
1 parent d8a7a32 commit fb2ad64
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 10 deletions.
7 changes: 3 additions & 4 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ def forward_orig(
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
Expand Down Expand Up @@ -186,7 +185,7 @@ def block_wrap(args):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img

def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
Expand Down
7 changes: 3 additions & 4 deletions comfy/ldm/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,9 +240,8 @@ def forward_orig(
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
Expand Down Expand Up @@ -314,7 +313,7 @@ def block_wrap(args):
img = img.reshape(initial_shape)
return img

def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
Expand Down
10 changes: 8 additions & 2 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,10 @@ def extra_conds(self, **kwargs):
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))

guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out

class GenmoMochi(BaseModel):
Expand Down Expand Up @@ -869,7 +872,10 @@ def extra_conds(self, **kwargs):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))

guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out

class CosmosVideo(BaseModel):
Expand Down
19 changes: 19 additions & 0 deletions comfy_extras/nodes_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,26 @@ def append(self, conditioning, guidance):
return (c, )


class FluxDisableGuidance:
    """Node that completely disables the guidance embed on Flux and Flux-like models.

    It sets the "guidance" key of the conditioning to None; downstream model code
    that checks ``if guidance is not None`` will then skip the guidance embedding
    entirely (rather than applying a default guidance strength).
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
                    "conditioning": ("CONDITIONING", ),
                    }}

    RETURN_TYPES = ("CONDITIONING",)
    # Name of the entry-point method, looked up on the instance at execution time.
    FUNCTION = "disable"

    CATEGORY = "advanced/conditioning/flux"
    DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"

    def disable(self, conditioning):
        """Return *conditioning* with its "guidance" value cleared (set to None)."""
        c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
        return (c, )

    # Backward-compatible alias: the method was originally (misleadingly) named
    # "append", copy-pasted from FluxGuidance. Kept so any direct callers survive.
    append = disable


# Registry mapping node type names to their implementing classes — presumably
# consumed by the node loader to expose these nodes. NOTE(review): keys look like
# stable identifiers referenced by saved workflows; do not rename them.
NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
    "FluxGuidance": FluxGuidance,
    "FluxDisableGuidance": FluxDisableGuidance,
}

0 comments on commit fb2ad64

Please sign in to comment.