-
Notifications
You must be signed in to change notification settings - Fork 7.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
support system prompt and cfg renorm in Lumina2 (#6795)
* support system prompt and cfg renorm in Lumina2 * fix issues with the ruff style check
- Loading branch information
Showing
2 changed files
with
105 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict | ||
import torch | ||
|
||
|
||
class RenormCFG:
    """Node that patches a model with a truncated, norm-limited CFG function.

    The installed sampler CFG function combines the conditional and
    unconditional predictions with the usual guidance formula, but only
    while the first timestep entry is below ``cfg_trunc``; otherwise the
    conditional prediction is used unchanged. When ``renorm_cfg`` is
    positive, the guided result is rescaled per sample so that its norm
    never exceeds ``renorm_cfg`` times the norm of the conditional
    prediction.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "cfg_trunc": ("FLOAT", {"default": 100, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "renorm_cfg": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model, cfg_trunc, renorm_cfg):
        """Return a clone of ``model`` with the renormalising CFG function set.

        Args:
            model: Object providing ``clone()`` and
                ``model.diffusion_model.in_channels``; the clone must provide
                ``set_model_sampler_cfg_function``.
            cfg_trunc: Guidance is applied only while ``timestep[0]`` is
                strictly below this value.
            renorm_cfg: Upper bound on the guided norm, expressed as a
                multiple of the conditional norm. Values <= 0 disable the
                limit.

        Returns:
            A one-element tuple holding the patched clone.
        """
        renorm_factor = float(renorm_cfg)

        def renorm_cfg_func(args):
            cond_denoised = args["cond_denoised"]
            uncond_denoised = args["uncond_denoised"]
            cond_scale = args["cond_scale"]
            timestep = args["timestep"]
            x_orig = args["input"]
            # Read at call time so the value always reflects the wrapped model.
            in_channels = model.model.diffusion_model.in_channels

            # Only the first `in_channels` channels take part in guidance;
            # any remaining channels are passed through from the cond branch.
            cond_eps = cond_denoised[:, :in_channels]
            cond_rest = cond_denoised[:, in_channels:]

            if timestep[0] < cfg_trunc:
                uncond_eps = uncond_denoised[:, :in_channels]
                guided_eps = uncond_eps + cond_scale * (cond_eps - uncond_eps)

                if renorm_factor > 0.0:
                    # Norms are taken over every non-batch dimension and kept
                    # broadcastable so each sample is limited independently.
                    reduce_dims = tuple(range(1, cond_eps.dim()))
                    max_norm = torch.linalg.vector_norm(
                        cond_eps, dim=reduce_dims, keepdim=True
                    ) * renorm_factor
                    guided_norm = torch.linalg.vector_norm(
                        guided_eps, dim=reduce_dims, keepdim=True
                    )
                    # A Python `if` on these tensors fails for batch sizes > 1,
                    # so select the scale element-wise. The `> 0` guard avoids
                    # a 0/0 rescale (NaN) when the guided prediction is zero.
                    over_limit = (guided_norm >= max_norm) & (guided_norm > 0)
                    scale = torch.where(
                        over_limit,
                        max_norm / guided_norm,
                        torch.ones_like(guided_norm),
                    )
                    guided_eps = guided_eps * scale
            else:
                # Past the truncation point guidance is switched off.
                guided_eps = cond_eps

            cfg_result = torch.cat([guided_eps, cond_rest], dim=1)
            return x_orig - cfg_result

        m = model.clone()
        m.set_model_sampler_cfg_function(renorm_cfg_func)
        return (m, )
|
||
|
||
class CLIPTextEncodeLumina2(ComfyNodeABC):
    """Text-encode node that prefixes the user prompt with a system prompt.

    The selected system prompt and the user prompt are joined with the
    ``<Prompt Start>`` marker, tokenized with the supplied clip object and
    encoded into a conditioning.
    """

    # Selectable system prompts, keyed by the name shown in the node's dropdown.
    SYSTEM_PROMPT = {
        "superior": (
            "You are an assistant designed to generate superior images with the superior "
            "degree of image-text alignment based on textual prompts or user prompts."
        ),
        "alignment": (
            "You are an assistant designed to generate high-quality images with the "
            "highest degree of image-text alignment based on textual prompts."
        ),
    }
    # Tooltip for the system prompt dropdown. The pieces are joined by implicit
    # concatenation, so each one must carry its own trailing space.
    SYSTEM_PROMPT_TIP = (
        "Lumina2 provides two types of system prompts: "
        "Superior: You are an assistant designed to generate superior images with the superior "
        "degree of image-text alignment based on textual prompts or user prompts. "
        "Alignment: You are an assistant designed to generate high-quality images with the highest "
        "degree of image-text alignment based on textual prompts."
    )

    @classmethod
    def INPUT_TYPES(s) -> InputTypeDict:
        return {
            "required": {
                "system_prompt": (list(CLIPTextEncodeLumina2.SYSTEM_PROMPT.keys()), {"tooltip": CLIPTextEncodeLumina2.SYSTEM_PROMPT_TIP}),
                "user_prompt": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
                "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
            }
        }
    RETURN_TYPES = (IO.CONDITIONING,)
    OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",)
    FUNCTION = "encode"

    CATEGORY = "conditioning"
    DESCRIPTION = "Encodes a system prompt and a user prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."

    def encode(self, clip, user_prompt, system_prompt):
        """Encode ``user_prompt`` prefixed by the chosen system prompt.

        Args:
            clip: Object providing ``tokenize`` and
                ``encode_from_tokens_scheduled``.
            user_prompt: The text supplied by the user.
            system_prompt: A key of ``SYSTEM_PROMPT``.

        Returns:
            A one-element tuple holding the result of
            ``clip.encode_from_tokens_scheduled``.

        Raises:
            RuntimeError: If ``clip`` is ``None``.
            KeyError: If ``system_prompt`` is not a key of ``SYSTEM_PROMPT``.
        """
        if clip is None:
            raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.")
        system_text = CLIPTextEncodeLumina2.SYSTEM_PROMPT[system_prompt]
        prompt = f'{system_text} <Prompt Start> {user_prompt}'
        tokens = clip.tokenize(prompt)
        return (clip.encode_from_tokens_scheduled(tokens), )
|
||
|
||
# Registry of the node classes exported by this module, keyed by node name.
NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeLumina2": CLIPTextEncodeLumina2,
    "RenormCFG": RenormCFG,
}
|
||
|
||
# Display names keyed by node name; nodes without an entry here have no
# display-name override defined in this module.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CLIPTextEncodeLumina2": "CLIP Text Encode for Lumina2",
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters