Commit 61c8c70

support system prompt and cfg renorm in Lumina2 (#6795)
* support system prompt and cfg renorm in Lumina2

* fix issues with the ruff style check
lzyhha authored Feb 16, 2025
1 parent d0399f4 commit 61c8c70
Showing 2 changed files with 105 additions and 0 deletions.
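The commit adds two nodes: CLIPTextEncodeLumina2, which prepends one of two Lumina2 system prompts to the user text before encoding, and RenormCFG, which patches the model's CFG sampling function with truncation and norm renormalization. As a rough orientation before the diff, here is a hypothetical sketch of invoking both nodes directly in Python, bypassing the node graph; `model` and `clip` are assumed to come from the usual loader nodes, and the prompt text is made up:

    # hypothetical direct use; in practice these run as graph nodes
    model_patched = RenormCFG().patch(model, cfg_trunc=100.0, renorm_cfg=1.0)[0]
    conditioning = CLIPTextEncodeLumina2().encode(clip, "a red fox in snow", "alignment")[0]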
104 changes: 104 additions & 0 deletions comfy_extras/nodes_lumina2.py
@@ -0,0 +1,104 @@
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict
import torch


class RenormCFG:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, cfg_trunc, renorm_cfg):
        def renorm_cfg_func(args):
            cond_denoised = args["cond_denoised"]
            uncond_denoised = args["uncond_denoised"]
            cond_scale = args["cond_scale"]
            timestep = args["timestep"]
            x_orig = args["input"]
            in_channels = model.model.diffusion_model.in_channels

            if timestep[0] < cfg_trunc:
                # Apply CFG only to the first in_channels channels; the remaining
                # channels pass through from the conditional prediction unchanged.
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)
                half_rest = cond_rest

                if float(renorm_cfg) > 0.0:
                    # Cap the norm of the guided prediction at renorm_cfg times the
                    # norm of the conditional prediction, so large guidance scales
                    # cannot blow up the magnitude of the result.
                    ori_pos_norm = torch.linalg.vector_norm(
                        cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True
                    )
                    max_new_norm = ori_pos_norm * float(renorm_cfg)
                    new_pos_norm = torch.linalg.vector_norm(
                        half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True
                    )
                    if new_pos_norm >= max_new_norm:
                        half_eps = half_eps * (max_new_norm / new_pos_norm)
            else:
                # Past the truncation threshold, skip guidance entirely and keep
                # the conditional prediction as-is.
                cond_eps, uncond_eps = cond_denoised[:, :in_channels], uncond_denoised[:, :in_channels]
                cond_rest, _ = cond_denoised[:, in_channels:], uncond_denoised[:, in_channels:]
                half_eps = cond_eps
                half_rest = cond_rest

            cfg_result = torch.cat([half_eps, half_rest], dim=1)

            # cfg_result = uncond_denoised + (cond_denoised - uncond_denoised) * cond_scale

            return x_orig - cfg_result

        m = model.clone()
        m.set_model_sampler_cfg_function(renorm_cfg_func)
        return (m, )


class CLIPTextEncodeLumina2(ComfyNodeABC):
    SYSTEM_PROMPT = {
        "superior": "You are an assistant designed to generate superior images with the superior "
                    "degree of image-text alignment based on textual prompts or user prompts.",
        "alignment": "You are an assistant designed to generate high-quality images with the "
                     "highest degree of image-text alignment based on textual prompts."
    }
    SYSTEM_PROMPT_TIP = "Lumina2 provides two types of system prompts: " \
        "Superior: You are an assistant designed to generate superior images with the superior " \
        "degree of image-text alignment based on textual prompts or user prompts. " \
        "Alignment: You are an assistant designed to generate high-quality images with the highest " \
        "degree of image-text alignment based on textual prompts."

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
                "user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
            }
        }
    RETURN_TYPES = (IO.CONDITIONING,)
    OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"
    DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."

    def encode(self, clip, user_prompt, system_prompt):
        if clip is None:
            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
        system_prompt = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
        prompt = f'{system_prompt} <Prompt Start> {user_prompt}'
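        # With the "alignment" preset and hypothetical user text "a red fox in
        # snow", the string passed to the tokenizer becomes:
        # "You are an assistant designed to generate high-quality images with
        #  the highest degree of image-text alignment based on textual prompts.
        #  <Prompt Start> a red fox in snow"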
        tokens = clip.tokenize(prompt)
        return (clip.encode_from_tokens_scheduled(tokens), )


NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
    "RenormCFG": RenormCFG
}


NODE_DISPLAY_NAME_MAPPINGS = {
    "CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
}
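
To see the renormalization arithmetic from RenormCFG in isolation, here is a minimal self-contained sketch with made-up shapes and values (batch size 1, so the tensor comparison below is well defined); only the math mirrors the node above, everything else is illustrative:

    import torch

    cond_eps = torch.randn(1, 16, 32, 32)    # conditional prediction (illustrative shape)
    uncond_eps = torch.randn(1, 16, 32, 32)  # unconditional prediction
    cond_scale, renorm_cfg = 4.0, 1.0

    half_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)  # standard CFG combine
    dims = tuple(range(1, half_eps.ndim))                         # reduce over all but batch
    max_norm = torch.linalg.vector_norm(cond_eps, dim=dims, keepdim=True) * renorm_cfg
    new_norm = torch.linalg.vector_norm(half_eps, dim=dims, keepdim=True)
    if new_norm >= max_norm:  # guidance inflated the prediction: scale it back down
        half_eps = half_eps * (max_norm / new_norm)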
1 change: 1 addition & 0 deletions nodes.py
@@ -2233,6 +2233,7 @@ def init_builtin_extra_nodes():
"nodes_hooks.py",
"nodes_load_3d.py",
"nodes_cosmos.py",
"nodes_lumina2.py",
]

import_failed = []
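
For context, registering a built-in extras module amounts to adding its filename to the list above and exposing the mapping dicts from the module, as nodes_lumina2.py does. A minimal hypothetical skeleton (all names illustrative, not part of the commit):

    # comfy_extras/nodes_example.py -- hypothetical minimal extras module
    class ExampleNode:
        @classmethod
        def INPUT_TYPES(s):
            return {"required": {}}
        RETURN_TYPES = ()
        FUNCTION = "run"
        CATEGORY = "advanced/example"

        def run(self):
            return ()

    NODE_CLASS_MAPPINGS = {"ExampleNode": ExampleNode}            # required
    NODE_DISPLAY_NAME_MAPPINGS = {"ExampleNode": "Example Node"}  # optional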